From fc9bd620ce94b64af46737e25a524336967782a1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 10:39:15 -0400 Subject: Add a relations handler to avoid duplication. (#12227) Adds a handler layer between the REST and datastore layers for relations. --- synapse/rest/client/relations.py | 75 +++++----------------------------------- 1 file changed, 8 insertions(+), 67 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d9a6be43f7..c16078b187 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True - ) - - # This gets the original event and checks that a) the event exists and - # b) the user is allowed to view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - limit = parse_integer(request, "limit", default=5) direction = parse_string( request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"] @@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - pagination_chunk = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet): to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) - - now = self.clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. - original_event = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None - ) - # The relations returned for the requested event do include their - # bundled aggregations. - aggregations = await self.store.get_bundled_aggregations( - events, requester.user.to_string() - ) - serialized_events = self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations - ) - - return_value = await pagination_chunk.to_dict(self.store) - return_value["chunk"] = serialized_events - return_value["original_event"] = original_event - - return 200, return_value + return 200, result class RelationAggregationPaginationServlet(RestServlet): @@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - result = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] - ) - - now = self.clock.time_msec() - serialized_events = self._event_serializer.serialize_events(events, now) - - return_value = await result.to_dict(self.store) - return_value["chunk"] = serialized_events - - return 200, return_value + return 200, result def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: -- cgit 1.5.1 From 872dbb0181714e201be082c4e8bd9b727c73f177 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 18 Mar 2022 13:51:41 +0000 Subject: Correct `check_username_for_spam` annotations and docs (#12246) * Formally type the UserProfile in user searches * export UserProfile in synapse.module_api * Update docs Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12246.doc | 1 + docs/modules/spam_checker_callbacks.md | 10 ++++++---- synapse/events/spamcheck.py | 7 +++---- synapse/handlers/user_directory.py | 4 ++-- synapse/module_api/__init__.py | 2 ++ synapse/rest/client/user_directory.py | 4 ++-- synapse/storage/databases/main/user_directory.py | 23 +++++++++++++++++++---- synapse/types.py | 11 +++++++++++ 8 files changed, 46 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12246.doc (limited to 'synapse/rest/client') diff --git a/changelog.d/12246.doc b/changelog.d/12246.doc new file mode 100644 index 0000000000..e7fcc1b99c --- /dev/null +++ b/changelog.d/12246.doc @@ -0,0 +1 @@ +Correct `check_username_for_spam` annotations and docs. \ No newline at end of file diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2b672b78f9..472d957180 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -172,7 +172,7 @@ any of the subsequent implementations of this callback. _First introduced in Synapse v1.37.0_ ```python -async def check_username_for_spam(user_profile: Dict[str, str]) -> bool +async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool ``` Called when computing search results in the user directory. The module must return a @@ -182,9 +182,11 @@ search results; otherwise return `False`. The profile is represented as a dictionary with the following keys: -* `user_id`: The Matrix ID for this user. -* `display_name`: The user's display name. -* `avatar_url`: The `mxc://` URL to the user's avatar. +* `user_id: str`. The Matrix ID for this user. +* `display_name: Optional[str]`. The user's display name, or `None` if this user + has not set a display name. +* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None` + if this user has not set an avatar. The module is given a copy of the original dictionary, so modifying it from within the module cannot modify a user's profile when included in user directory search results. diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 60904a55f5..cd80fcf9d1 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -21,7 +21,6 @@ from typing import ( Awaitable, Callable, Collection, - Dict, List, Optional, Tuple, @@ -31,7 +30,7 @@ from typing import ( from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserProfile from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: @@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] -CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] +CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]] LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ Optional[dict], @@ -383,7 +382,7 @@ class SpamChecker: return True - async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: UserProfile) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d27ed2be6a..048fd4bb82 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -19,8 +19,8 @@ import synapse.metrics from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.roommember import ProfileInfo -from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler): async def search_users( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d735c1d461..aa8256b36f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,6 +111,7 @@ from synapse.types import ( StateMap, UserID, UserInfo, + UserProfile, create_requester, ) from synapse.util import Clock @@ -150,6 +151,7 @@ __all__ = [ "EventBase", "StateMap", "ProfileInfo", + "UserProfile", ] logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index a47d9bd01d..116c982ce6 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict +from synapse.types import JsonMapping from ._base import client_patterns @@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]: """Searches for users in directory Returns: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index e7fddd2426..55cc9178f0 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,8 @@ from typing import ( cast, ) +from typing_extensions import TypedDict + from synapse.api.errors import StoreError if TYPE_CHECKING: @@ -40,7 +42,12 @@ from synapse.storage.database import ( from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id +from synapse.types import ( + JsonDict, + UserProfile, + get_domain_from_id, + get_localpart_from_id, +) from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) +class SearchResult(TypedDict): + limited: bool + results: List[UserProfile] + + class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): async def search_user_dir( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: @@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = await self.db_pool.execute( - "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + results = cast( + List[UserProfile], + await self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + ), ) limited = len(results) > limit diff --git a/synapse/types.py b/synapse/types.py index 53be3583a0..5ce2a5b0a5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -34,6 +34,7 @@ from typing import ( import attr from frozendict import frozendict from signedjson.key import decode_verify_key_bytes +from typing_extensions import TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. # A JSON-serialisable dict. JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] # A JSON-serialisable object. JsonSerializable = object @@ -791,3 +796,9 @@ class UserInfo: is_deactivated: bool is_guest: bool is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] -- cgit 1.5.1 From 8fe930c215f69913fbcd96d609ec6950644e4ec4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:49:32 -0400 Subject: Move get_bundled_aggregations to relations handler. (#12237) The get_bundled_aggregations code is fairly high-level and uses a lot of store methods, we move it into the handler as that seems like a better fit. --- changelog.d/12237.misc | 1 + synapse/events/utils.py | 2 +- synapse/handlers/pagination.py | 5 +- synapse/handlers/relations.py | 151 +++++++++++++++++++++++++++- synapse/handlers/room.py | 5 +- synapse/handlers/search.py | 3 +- synapse/handlers/sync.py | 9 +- synapse/rest/client/room.py | 3 +- synapse/storage/databases/main/relations.py | 151 +--------------------------- 9 files changed, 173 insertions(+), 157 deletions(-) create mode 100644 changelog.d/12237.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc new file mode 100644 index 0000000000..41c9dcbd37 --- /dev/null +++ b/changelog.d/12237.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0520068e0..7120062127 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations from synapse.server import HomeServer - from synapse.storage.databases.main.relations import BundledAggregations # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 41679f7f86..876b879483 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -134,6 +134,7 @@ class PaginationHandler: self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() + self._relations_handler = hs.get_relations_handler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. @@ -539,7 +540,9 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() - aggregations = await self.store.get_bundled_aggregations(events, user_id) + aggregations = await self._relations_handler.get_bundled_aggregations( + events, user_id + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e475475ad..57135d4519 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,18 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +import attr +from frozendict import frozendict + +from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.types import JsonDict, Requester, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + # The latest event in the thread. + latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. + count: int + # True if the current user has sent an event to the thread. + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main @@ -103,7 +138,7 @@ class RelationsHandler: ) # The relations returned for the requested event do include their # bundled aggregations. - aggregations = await self._main_store.get_bundled_aggregations( + aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) serialized_events = self._event_serializer.serialize_events( @@ -115,3 +150,115 @@ class RelationsHandler: return_value["original_event"] = original_event return return_value + + async def _get_bundled_aggregation_for_event( + self, event: EventBase, user_id: str + ) -> Optional[BundledAggregations]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations = BundledAggregations() + + annotations = await self._main_store.get_aggregation_groups_for_event( + event_id, room_id + ) + if annotations.chunk: + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) + + references = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations.references = await references.to_dict(cast("DataStore", self)) + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase], user_id: str + ) -> Dict[str, BundledAggregations]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. + for event in events_by_id.values(): + event_result = await self._get_bundled_aggregation_for_event(event, user_id) + if event_result: + results[event.event_id] = event_result + + # Fetch any edits (but not for redacted events). + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + summaries = await self._main_store.get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._main_store.get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + + return results diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9735631fc..092e185c99 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,8 +60,8 @@ from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state +from synapse.handlers.relations import BundledAggregations from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -1118,6 +1118,7 @@ class RoomContextHandler: self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state + self._relations_handler = hs.get_relations_handler() async def get_event_context( self, @@ -1190,7 +1191,7 @@ class RoomContextHandler: event = filtered[0] # Fetch the aggregations. - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( itertools.chain(events_before, (event,), events_after), user.to_string(), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index aa16e417eb..30eddda65f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,6 +54,7 @@ class SearchHandler: self.clock = hs.get_clock() self.hs = hs self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.storage = hs.get_storage() self.state_store = self.storage.state self.auth = hs.get_auth() @@ -354,7 +355,7 @@ class SearchHandler: aggregations = None if self._msc3666_enabled: - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( # Generate an iterable of EventBase for all the events that will be # returned, including contextual events. itertools.chain( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c9d6a18bd7..6c569cfb1c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -269,6 +269,7 @@ class SyncHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() + self._relations_handler = hs.get_relations_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -638,8 +639,10 @@ class SyncHandler: # as clients will have all the necessary information. bundled_aggregations = None if limited or newly_joined_room: - bundled_aggregations = await self.store.get_bundled_aggregations( - recents, sync_config.user.to_string() + bundled_aggregations = ( + await self._relations_handler.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) ) return TimelineBatch( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 8a06ab8c5f..47e152c8cc 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet): self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.auth = hs.get_auth() async def on_GET( @@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet): if event: # Ensure there are bundled aggregations available. - aggregations = await self._store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( [event], requester.user.to_string() ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index af2334a65e..b2295fd51f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.events import EventBase @@ -41,45 +40,15 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. - """ - - annotations: Optional[JsonDict] = None - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None - - def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) - - class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") - async def _get_applicable_edits( + async def get_applicable_edits( self, event_ids: Collection[str] ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given @@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") - async def _get_thread_summaries( + async def get_thread_summaries( self, event_ids: Collection[str] ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. @@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] # Check to see if any of those events are edited. - latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + latest_edits = await self.get_applicable_edits(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") - async def _get_threads_participated( + async def get_threads_participated( self, event_ids: Collection[str], user_id: str ) -> Dict[str, bool]: """Get whether the requesting user participated in the given threads. @@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. - aggregations = BundledAggregations() - - annotations = await self.get_aggregation_groups_for_event(event_id, room_id) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) - - references = await self.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - - async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str - ) -> Dict[str, BundledAggregations]: - """Generate bundled aggregations for events. - - Args: - events: The iterable of events to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. - """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } - - # event ID -> bundled aggregation in non-serialized form. - results: Dict[str, BundledAggregations] = {} - - # Fetch other relations per event. - for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result: - results[event.event_id] = event_result - - # Fetch any edits (but not for redacted events). - edits = await self._get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] - ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit - - # Fetch thread summaries. - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - [event_id for event_id, summary in summaries.items() if summary], user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) - - return results - class RelationsStore(RelationsWorkerStore): pass -- cgit 1.5.1 From 516d092ff95d02c0bb2133c9316a1fb4ff2f5072 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 23 Mar 2022 12:19:20 +0100 Subject: Rename shared_rooms to mutual_rooms (#12036) Co-authored-by: reivilibre --- changelog.d/12036.misc | 1 + synapse/rest/__init__.py | 4 +- synapse/rest/client/mutual_rooms.py | 76 ++++++++++++ synapse/rest/client/shared_rooms.py | 75 ------------ synapse/storage/databases/main/user_directory.py | 6 +- tests/rest/client/test_mutual_rooms.py | 146 +++++++++++++++++++++++ tests/rest/client/test_shared_rooms.py | 146 ----------------------- 7 files changed, 228 insertions(+), 226 deletions(-) create mode 100644 changelog.d/12036.misc create mode 100644 synapse/rest/client/mutual_rooms.py delete mode 100644 synapse/rest/client/shared_rooms.py create mode 100644 tests/rest/client/test_mutual_rooms.py delete mode 100644 tests/rest/client/test_shared_rooms.py (limited to 'synapse/rest/client') diff --git a/changelog.d/12036.misc b/changelog.d/12036.misc new file mode 100644 index 0000000000..d2996730cc --- /dev/null +++ b/changelog.d/12036.misc @@ -0,0 +1 @@ +Rename `shared_rooms` to `mutual_rooms` (MSC2666), as per proposal changes. \ No newline at end of file diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 762808a571..57c4773edc 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -32,6 +32,7 @@ from synapse.rest.client import ( knock, login as v1_login, logout, + mutual_rooms, notifications, openid, password_policy, @@ -49,7 +50,6 @@ from synapse.rest.client import ( room_keys, room_upgrade_rest_servlet, sendtodevice, - shared_rooms, sync, tags, thirdparty, @@ -132,4 +132,4 @@ class ClientRestResource(JsonResource): admin.register_servlets_for_client_rest_resource(hs, client_resource) # unstable - shared_rooms.register_servlets(hs, client_resource) + mutual_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py new file mode 100644 index 0000000000..d3872a76c8 --- /dev/null +++ b/synapse/rest/client/mutual_rooms.py @@ -0,0 +1,76 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UserMutualRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/mutual_rooms/(?P[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.user_directory_active = hs.config.server.update_user_directory + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + + rooms = await self.store.get_mutual_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + UserMutualRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py deleted file mode 100644 index e669fa7890..0000000000 --- a/synapse/rest/client/shared_rooms.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import TYPE_CHECKING, Tuple - -from synapse.api.errors import Codes, SynapseError -from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class UserSharedRoomsServlet(RestServlet): - """ - GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 - """ - - PATTERNS = client_patterns( - "/uk.half-shot.msc2666/user/shared_rooms/(?P[^/]*)", - releases=(), # This is an unstable feature - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.user_directory_active = hs.config.server.update_user_directory - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - - if not self.user_directory_active: - raise SynapseError( - code=400, - msg="The user directory is disabled on this server. Cannot determine shared rooms.", - errcode=Codes.FORBIDDEN, - ) - - UserID.from_string(user_id) - - requester = await self.auth.get_user_by_req(request) - if user_id == requester.user.to_string(): - raise SynapseError( - code=400, - msg="You cannot request a list of shared rooms with yourself", - errcode=Codes.FORBIDDEN, - ) - rooms = await self.store.get_shared_rooms_for_users( - requester.user.to_string(), user_id - ) - - return 200, {"joined": list(rooms)} - - -def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 55cc9178f0..0595df01d3 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -730,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_shared_rooms_for_users( + async def get_mutual_rooms_for_users( self, user_id: str, other_user_id: str ) -> Set[str]: """ @@ -744,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): A set of room ID's that the users share. """ - def _get_shared_rooms_for_users_txn( + def _get_mutual_rooms_for_users_txn( txn: LoggingTransaction, ) -> List[Dict[str, str]]: txn.execute( @@ -768,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return rows rows = await self.db_pool.runInteraction( - "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn + "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn ) return {row["room_id"] for row in rows} diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py new file mode 100644 index 0000000000..7b7d283bb6 --- /dev/null +++ b/tests/rest/client/test_mutual_rooms.py @@ -0,0 +1,146 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, mutual_rooms, room +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.server import FakeChannel + + +class UserMutualRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserMutualRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + mutual_rooms.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.handler = hs.get_user_directory_handler() + + def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel: + return self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s" + % other_user, + access_token=token, + ) + + def test_shared_room_list_public(self) -> None: + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True) + + def test_shared_room_list_private(self) -> None: + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + self._check_mutual_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) + + def test_shared_room_list_mixed(self) -> None: + """ + The shared room list between two users should contain both public and private + rooms. + """ + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_mutual_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ) -> None: + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 2) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) + + def test_shared_room_list_after_leave(self) -> None: + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + # Check user1's view of shared rooms with user2 + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) + + # Check user2's view of shared rooms with user1 + channel = self._get_mutual_rooms(u2_token, u1) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py deleted file mode 100644 index 3818b7b14b..0000000000 --- a/tests/rest/client/test_shared_rooms.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.rest.client import login, room, shared_rooms -from synapse.server import HomeServer -from synapse.util import Clock - -from tests import unittest -from tests.server import FakeChannel - - -class UserSharedRoomsTest(unittest.HomeserverTestCase): - """ - Tests the UserSharedRoomsServlet. - """ - - servlets = [ - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - shared_rooms.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.handler = hs.get_user_directory_handler() - - def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: - return self.make_request( - "GET", - "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" - % other_user, - access_token=token, - ) - - def test_shared_room_list_public(self) -> None: - """ - A room should show up in the shared list of rooms between two users - if it is public. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - - def test_shared_room_list_private(self) -> None: - """ - A room should show up in the shared list of rooms between two users - if it is private. - """ - self._check_shared_rooms_with( - room_one_is_public=False, room_two_is_public=False - ) - - def test_shared_room_list_mixed(self) -> None: - """ - The shared room list between two users should contain both public and private - rooms. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) - - def _check_shared_rooms_with( - self, room_one_is_public: bool, room_two_is_public: bool - ) -> None: - """Checks that shared public or private rooms between two users appear in - their shared room lists - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - # Create a room. user1 invites user2, who joins - room_id_one = self.helper.create_room_as( - u1, is_public=room_one_is_public, tok=u1_token - ) - self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_one, user=u2, tok=u2_token) - - # Check shared rooms from user1's perspective. - # We should see the one room in common - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 1) - self.assertEqual(channel.json_body["joined"][0], room_id_one) - - # Create another room and invite user2 to it - room_id_two = self.helper.create_room_as( - u1, is_public=room_two_is_public, tok=u1_token - ) - self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_two, user=u2, tok=u2_token) - - # Check shared rooms again. We should now see both rooms. - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 2) - for room_id_id in channel.json_body["joined"]: - self.assertIn(room_id_id, [room_id_one, room_id_two]) - - def test_shared_room_list_after_leave(self) -> None: - """ - A room should no longer be considered shared if the other - user has left it. - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - # Assert user directory is not empty - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 1) - self.assertEqual(channel.json_body["joined"][0], room) - - self.helper.leave(room, user=u1, tok=u1_token) - - # Check user1's view of shared rooms with user2 - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 0) - - # Check user2's view of shared rooms with user1 - channel = self._get_shared_rooms(u2_token, u1) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 0) -- cgit 1.5.1 From c5776780f0621eb9440277cfd9ffd41b1443f34e Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 23 Mar 2022 13:47:07 +0100 Subject: Remove mutual_rooms `update_user_directory` check, and add extra documentation (#12038) Resolves #10339 --- changelog.d/12038.misc | 1 + docs/workers.md | 11 ++++++++++- synapse/rest/client/mutual_rooms.py | 10 ++++++---- 3 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 changelog.d/12038.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/12038.misc b/changelog.d/12038.misc new file mode 100644 index 0000000000..e2a65726b6 --- /dev/null +++ b/changelog.d/12038.misc @@ -0,0 +1 @@ +Remove check on `update_user_directory` for shared rooms handler (MSC2666), and update/expand documentation. \ No newline at end of file diff --git a/docs/workers.md b/docs/workers.md index 9eb4194e4d..8ac95e39bb 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -528,10 +528,19 @@ the following regular expressions: ^/_matrix/client/(r0|v3|unstable)/user_directory/search$ -When using this worker you must also set `update_user_directory: False` in the +When using this worker you must also set `update_user_directory: false` in the shared configuration file to stop the main synapse running background jobs related to updating the user directory. +Above endpoint is not *required* to be routed to this worker. By default, +`update_user_directory` is set to `true`, which means the main process +will handle updates. All workers configured with `client` can handle the above +endpoint as long as either this worker or the main process are configured to +handle it, and are online. + +If `update_user_directory` is set to `false`, and this worker is not running, +the above endpoint may give outdated results. + ### `synapse.app.frontend_proxy` Proxies some frequently-requested client endpoints to add caching and remove diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py index d3872a76c8..27bfaf0b29 100644 --- a/synapse/rest/client/mutual_rooms.py +++ b/synapse/rest/client/mutual_rooms.py @@ -42,17 +42,19 @@ class UserMutualRoomsServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.user_directory_active = hs.config.server.update_user_directory + self.user_directory_search_enabled = ( + hs.config.userdirectory.user_directory_search_enabled + ) async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - if not self.user_directory_active: + if not self.user_directory_search_enabled: raise SynapseError( code=400, - msg="The user directory is disabled on this server. Cannot determine shared rooms.", - errcode=Codes.FORBIDDEN, + msg="User directory searching is disabled. Cannot determine shared rooms.", + errcode=Codes.UNKNOWN, ) UserID.from_string(user_id) -- cgit 1.5.1 From 14662d3c18217ba9e865b56203829e88d2ed4532 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 25 Mar 2022 09:21:06 -0500 Subject: Refactor `create_new_client_event` to use a new parameter, `state_event_ids`, which accurately describes the usage with MSC2716 instead of abusing `auth_event_ids` (#12083) Spawned from https://github.com/matrix-org/synapse/pull/10975#discussion_r813183430 Part of [MSC2716](https://github.com/matrix-org/matrix-spec-proposals/pull/2716) --- changelog.d/12083.misc | 1 + synapse/handlers/message.py | 63 +++++++++++++++------ synapse/handlers/room_batch.py | 112 ++++++++++++++++++++++++-------------- synapse/handlers/room_member.py | 31 +++++++++++ synapse/rest/client/room_batch.py | 23 +++++--- 5 files changed, 165 insertions(+), 65 deletions(-) create mode 100644 changelog.d/12083.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/12083.misc b/changelog.d/12083.misc new file mode 100644 index 0000000000..88fd6b92ee --- /dev/null +++ b/changelog.d/12083.misc @@ -0,0 +1 @@ +Refactor `create_new_client_event` to use a new parameter, `state_event_ids`, which accurately describes the usage with [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) instead of abusing `auth_event_ids`. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f9544fe7fb..1c4fb4360a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -493,6 +493,7 @@ class EventCreationHandler: allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, require_consent: bool = True, outlier: bool = False, historical: bool = False, @@ -527,6 +528,15 @@ class EventCreationHandler: If non-None, prev_event_ids must also be provided. + state_event_ids: + The full state at a given event. This is used particularly by the MSC2716 + /batch_send endpoint. One use case is with insertion events which float at + the beginning of a historical batch and don't have any `prev_events` to + derive from; we add all of these state events as the explicit state so the + rest of the historical batch can inherit the same state and state_group. + This should normally be left as None, which will cause the auth_event_ids + to be calculated based on the room state at the prev_events. + require_consent: Whether to check if the requester has consented to the privacy policy. @@ -612,6 +622,7 @@ class EventCreationHandler: allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + state_event_ids=state_event_ids, depth=depth, ) @@ -772,6 +783,7 @@ class EventCreationHandler: allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, ratelimit: bool = True, txn_id: Optional[str] = None, ignore_shadow_ban: bool = False, @@ -801,6 +813,14 @@ class EventCreationHandler: based on the room state at the prev_events. If non-None, prev_event_ids must also be provided. + state_event_ids: + The full state at a given event. This is used particularly by the MSC2716 + /batch_send endpoint. One use case is with insertion events which float at + the beginning of a historical batch and don't have any `prev_events` to + derive from; we add all of these state events as the explicit state so the + rest of the historical batch can inherit the same state and state_group. + This should normally be left as None, which will cause the auth_event_ids + to be calculated based on the room state at the prev_events. ratelimit: Whether to rate limit this send. txn_id: The transaction ID. ignore_shadow_ban: True if shadow-banned users should be allowed to @@ -856,8 +876,10 @@ class EventCreationHandler: requester, event_dict, txn_id=txn_id, + allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + state_event_ids=state_event_ids, outlier=outlier, historical=historical, depth=depth, @@ -893,6 +915,7 @@ class EventCreationHandler: allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -915,6 +938,15 @@ class EventCreationHandler: Should normally be left as None, which will cause them to be calculated based on the room state at the prev_events. + state_event_ids: + The full state at a given event. This is used particularly by the MSC2716 + /batch_send endpoint. One use case is with insertion events which float at + the beginning of a historical batch and don't have any `prev_events` to + derive from; we add all of these state events as the explicit state so the + rest of the historical batch can inherit the same state and state_group. + This should normally be left as None, which will cause the auth_event_ids + to be calculated based on the room state at the prev_events. + depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -922,31 +954,26 @@ class EventCreationHandler: Returns: Tuple of created event, context """ - # Strip down the auth_event_ids to only what we need to auth the event. + # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender - full_state_ids_at_event = None - if auth_event_ids is not None: - # If auth events are provided, prev events must be also. + if state_event_ids is not None: + # Do a quick check to make sure that prev_event_ids is present to + # make the type-checking around `builder.build` happy. # prev_event_ids could be an empty array though. assert prev_event_ids is not None - # Copy the full auth state before it stripped down - full_state_ids_at_event = auth_event_ids.copy() - temp_event = await builder.build( prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, + auth_event_ids=state_event_ids, depth=depth, ) - auth_events = await self.store.get_events_as_list(auth_event_ids) + state_events = await self.store.get_events_as_list(state_event_ids) # Create a StateMap[str] - auth_event_state_map = { - (e.type, e.state_key): e.event_id for e in auth_events - } - # Actually strip down and use the necessary auth events + state_map = {(e.type, e.state_key): e.event_id for e in state_events} + # Actually strip down and only use the necessary auth events auth_event_ids = self._event_auth_handler.compute_auth_events( event=temp_event, - current_state_ids=auth_event_state_map, + current_state_ids=state_map, for_verification=False, ) @@ -989,12 +1016,16 @@ class EventCreationHandler: context = EventContext.for_outlier() elif ( event.type == EventTypes.MSC2716_INSERTION - and full_state_ids_at_event + and state_event_ids and builder.internal_metadata.is_historical() ): + # Add explicit state to the insertion event so it has state to derive + # from even though it's floating with no `prev_events`. The rest of + # the batch can derive from this state and state_group. + # # TODO(faster_joins): figure out how this works, and make sure that the # old state is complete. - old_state = await self.store.get_events_as_list(full_state_ids_at_event) + old_state = await self.store.get_events_as_list(state_event_ids) context = await self.state.compute_event_context(event, old_state=old_state) else: context = await self.state.compute_event_context(event) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index abbf7b7b27..a0255bd143 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -121,12 +121,11 @@ class RoomBatchHandler: return create_requester(user_id, app_service=app_service) - async def get_most_recent_auth_event_ids_from_event_id_list( + async def get_most_recent_full_state_ids_from_event_id_list( self, event_ids: List[str] ) -> List[str]: - """Find the most recent auth event ids (derived from state events) that - allowed that message to be sent. We will use this as a base - to auth our historical messages against. + """Find the most recent event_id and grab the full state at that event. + We will use this as a base to auth our historical messages against. Args: event_ids: List of event ID's to look at @@ -136,38 +135,37 @@ class RoomBatchHandler: """ ( - most_recent_prev_event_id, + most_recent_event_id, _, ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id + most_recent_event_id ) # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids + full_state_ids = list(prev_state_map.values()) - return auth_event_ids + return full_state_ids async def persist_state_events_at_start( self, state_events_at_start: List[JsonDict], room_id: str, - initial_auth_event_ids: List[str], + initial_state_event_ids: List[str], app_service_requester: Requester, ) -> List[str]: """Takes all `state_events_at_start` event dictionaries and creates/persists - them as floating state events which don't resolve into the current room state. - They are floating because they reference a fake prev_event which doesn't connect - to the normal DAG at all. + them in a floating state event chain which don't resolve into the current room + state. They are floating because they reference no prev_events and are marked + as outliers which disconnects them from the normal DAG. Args: state_events_at_start: room_id: Room where you want the events persisted in. - initial_auth_event_ids: These will be the auth_events for the first - state event created. Each event created afterwards will be - added to the list of auth events for the next state event - created. + initial_state_event_ids: + The base set of state for the historical batch which the floating + state chain will derive from. This should probably be the state + from the `prev_event` defined by `/batch_send?prev_event_id=$abc`. app_service_requester: The requester of an application service. Returns: @@ -176,7 +174,7 @@ class RoomBatchHandler: assert app_service_requester.app_service state_event_ids_at_start = [] - auth_event_ids = initial_auth_event_ids.copy() + state_event_ids = initial_state_event_ids.copy() # Make the state events float off on their own by specifying no # prev_events for the first one in the chain so we don't have a bunch of @@ -189,9 +187,7 @@ class RoomBatchHandler: ) logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, + "RoomBatchSendEventRestServlet inserting state_event=%s", state_event ) event_dict = { @@ -217,16 +213,26 @@ class RoomBatchHandler: room_id=room_id, action=membership, content=event_dict["content"], + # Mark as an outlier to disconnect it from the normal DAG + # and not show up between batches of history. outlier=True, historical=True, - # Only the first event in the chain should be floating. + # Only the first event in the state chain should be floating. # The rest should hang off each other in a chain. allow_no_prev_events=index == 0, prev_event_ids=prev_event_ids_for_state_chain, + # Since each state event is marked as an outlier, the + # `EventContext.for_outlier()` won't have any `state_ids` + # set and therefore can't derive any state even though the + # prev_events are set. Also since the first event in the + # state chain is floating with no `prev_events`, it can't + # derive state from anywhere automatically. So we need to + # set some state explicitly. + # # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), + state_event_ids=state_event_ids.copy(), ) else: # TODO: Add some complement tests that adds state that is not member joins @@ -240,21 +246,31 @@ class RoomBatchHandler: state_event["sender"], app_service_requester.app_service ), event_dict, + # Mark as an outlier to disconnect it from the normal DAG + # and not show up between batches of history. outlier=True, historical=True, - # Only the first event in the chain should be floating. + # Only the first event in the state chain should be floating. # The rest should hang off each other in a chain. allow_no_prev_events=index == 0, prev_event_ids=prev_event_ids_for_state_chain, + # Since each state event is marked as an outlier, the + # `EventContext.for_outlier()` won't have any `state_ids` + # set and therefore can't derive any state even though the + # prev_events are set. Also since the first event in the + # state chain is floating with no `prev_events`, it can't + # derive state from anywhere automatically. So we need to + # set some state explicitly. + # # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), + state_event_ids=state_event_ids.copy(), ) event_id = event.event_id state_event_ids_at_start.append(event_id) - auth_event_ids.append(event_id) + state_event_ids.append(event_id) # Connect all the state in a floating chain prev_event_ids_for_state_chain = [event_id] @@ -265,7 +281,7 @@ class RoomBatchHandler: events_to_create: List[JsonDict], room_id: str, inherited_depth: int, - auth_event_ids: List[str], + initial_state_event_ids: List[str], app_service_requester: Requester, ) -> List[str]: """Create and persists all events provided sequentially. Handles the @@ -281,8 +297,10 @@ class RoomBatchHandler: room_id: Room where you want the events persisted in. inherited_depth: The depth to create the events at (you will probably by calling inherit_depth_from_prev_ids(...)). - auth_event_ids: Define which events allow you to create the given - event in the room. + initial_state_event_ids: + This is used to set explicit state for the insertion event at + the start of the historical batch since it's floating with no + prev_events to derive state from automatically. app_service_requester: The requester of an application service. Returns: @@ -290,6 +308,11 @@ class RoomBatchHandler: """ assert app_service_requester.app_service + # We expect the first event in a historical batch to be an insertion event + assert events_to_create[0]["type"] == EventTypes.MSC2716_INSERTION + # We expect the last event in a historical batch to be an batch event + assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH + # Make the historical event chain float off on its own by specifying no # prev_events for the first event in the chain which causes the HS to # ask for the state at the start of the batch later. @@ -321,11 +344,16 @@ class RoomBatchHandler: ev["sender"], app_service_requester.app_service ), event_dict, - # Only the first event in the chain should be floating. - # The rest should hang off each other in a chain. + # Only the first event (which is the insertion event) in the + # chain should be floating. The rest should hang off each other + # in a chain. allow_no_prev_events=index == 0, prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, + # Since the first event (which is the insertion event) in the + # chain is floating with no `prev_events`, it can't derive state + # from anywhere automatically. So we need to set some state + # explicitly. + state_event_ids=initial_state_event_ids if index == 0 else None, historical=True, depth=inherited_depth, ) @@ -343,10 +371,9 @@ class RoomBatchHandler: ) logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s", event, prev_event_ids, - auth_event_ids, ) events_to_persist.append((event, context)) @@ -376,12 +403,12 @@ class RoomBatchHandler: room_id: str, batch_id_to_connect_to: str, inherited_depth: int, - auth_event_ids: List[str], + initial_state_event_ids: List[str], app_service_requester: Requester, ) -> Tuple[List[str], str]: """ - Handles creating and persisting all of the historical events as well - as insertion and batch meta events to make the batch navigable in the DAG. + Handles creating and persisting all of the historical events as well as + insertion and batch meta events to make the batch navigable in the DAG. Args: events_to_create: List of historical events to create in JSON @@ -391,8 +418,13 @@ class RoomBatchHandler: want this batch to connect to. inherited_depth: The depth to create the events at (you will probably by calling inherit_depth_from_prev_ids(...)). - auth_event_ids: Define which events allow you to create the given - event in the room. + initial_state_event_ids: + This is used to set explicit state for the insertion event at + the start of the historical batch since it's floating with no + prev_events to derive state from automatically. This should + probably be the state from the `prev_event` defined by + `/batch_send?prev_event_id=$abc` plus the outcome of + `persist_state_events_at_start` app_service_requester: The requester of an application service. Returns: @@ -438,7 +470,7 @@ class RoomBatchHandler: events_to_create=events_to_create, room_id=room_id, inherited_depth=inherited_depth, - auth_event_ids=auth_event_ids, + initial_state_event_ids=initial_state_event_ids, app_service_requester=app_service_requester, ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7cbc484b06..a33fa34aa8 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -272,6 +272,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, @@ -298,6 +299,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): The event ids to use as the auth_events for the new event. Should normally be left as None, which will cause them to be calculated based on the room state at the prev_events. + state_event_ids: + The full state at a given event. This is used particularly by the MSC2716 + /batch_send endpoint. One use case is the historical `state_events_at_start`; + since each is marked as an `outlier`, the `EventContext.for_outlier()` won't + have any `state_ids` set and therefore can't derive any state even though the + prev_events are set so we need to set them ourself via this argument. + This should normally be left as None, which will cause the auth_event_ids + to be calculated based on the room state at the prev_events. txn_id: ratelimit: @@ -353,6 +362,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + state_event_ids=state_event_ids, require_consent=require_consent, outlier=outlier, historical=historical, @@ -456,6 +466,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -487,6 +498,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): The event ids to use as the auth_events for the new event. Should normally be left as None, which will cause them to be calculated based on the room state at the prev_events. + state_event_ids: + The full state at a given event. This is used particularly by the MSC2716 + /batch_send endpoint. One use case is the historical `state_events_at_start`; + since each is marked as an `outlier`, the `EventContext.for_outlier()` won't + have any `state_ids` set and therefore can't derive any state even though the + prev_events are set so we need to set them ourself via this argument. + This should normally be left as None, which will cause the auth_event_ids + to be calculated based on the room state at the prev_events. Returns: A tuple of the new event ID and stream ID. @@ -526,6 +545,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + state_event_ids=state_event_ids, ) return result @@ -548,6 +568,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: """Helper for update_membership. @@ -581,6 +602,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): The event ids to use as the auth_events for the new event. Should normally be left as None, which will cause them to be calculated based on the room state at the prev_events. + state_event_ids: + The full state at a given event. This is used particularly by the MSC2716 + /batch_send endpoint. One use case is the historical `state_events_at_start`; + since each is marked as an `outlier`, the `EventContext.for_outlier()` won't + have any `state_ids` set and therefore can't derive any state even though the + prev_events are set so we need to set them ourself via this argument. + This should normally be left as None, which will cause the auth_event_ids + to be calculated based on the room state at the prev_events. Returns: A tuple of the new event ID and stream ID. @@ -708,6 +737,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + state_event_ids=state_event_ids, content=content, require_consent=require_consent, outlier=outlier, @@ -932,6 +962,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit=ratelimit, prev_event_ids=latest_event_ids, auth_event_ids=auth_event_ids, + state_event_ids=state_event_ids, content=content, require_consent=require_consent, outlier=outlier, diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 0048973e59..0780485322 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -124,14 +124,14 @@ class RoomBatchSendEventRestServlet(RestServlet): ) # For the event we are inserting next to (`prev_event_ids_from_query`), - # find the most recent auth events (derived from state events) that - # allowed that message to be sent. We will use that as a base - # to auth our historical messages against. - auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list( + # find the most recent state events that allowed that message to be + # sent. We will use that as a base to auth our historical messages + # against. + state_event_ids = await self.room_batch_handler.get_most_recent_full_state_ids_from_event_id_list( prev_event_ids_from_query ) - if not auth_event_ids: + if not state_event_ids: raise SynapseError( HTTPStatus.BAD_REQUEST, "No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist." @@ -148,13 +148,13 @@ class RoomBatchSendEventRestServlet(RestServlet): await self.room_batch_handler.persist_state_events_at_start( state_events_at_start=body["state_events_at_start"], room_id=room_id, - initial_auth_event_ids=auth_event_ids, + initial_state_event_ids=state_event_ids, app_service_requester=requester, ) ) # Update our ongoing auth event ID list with all of the new state we # just created - auth_event_ids.extend(state_event_ids_at_start) + state_event_ids.extend(state_event_ids_at_start) inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids( prev_event_ids_from_query @@ -196,7 +196,12 @@ class RoomBatchSendEventRestServlet(RestServlet): ), base_insertion_event_dict, prev_event_ids=base_insertion_event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, + # Also set the explicit state here because we want to resolve + # any `state_events_at_start` here too. It's not strictly + # necessary to accomplish anything but if someone asks for the + # state at this point, we probably want to show them the + # historical state that was part of this batch. + state_event_ids=state_event_ids, historical=True, depth=inherited_depth, ) @@ -212,7 +217,7 @@ class RoomBatchSendEventRestServlet(RestServlet): room_id=room_id, batch_id_to_connect_to=batch_id_to_connect_to, inherited_depth=inherited_depth, - auth_event_ids=auth_event_ids, + initial_state_event_ids=state_event_ids, app_service_requester=requester, ) -- cgit 1.5.1