diff options
41 files changed, 593 insertions, 374 deletions
diff --git a/changelog.d/11530.bugfix b/changelog.d/11530.bugfix new file mode 100644 index 0000000000..7ea9ba4e49 --- /dev/null +++ b/changelog.d/11530.bugfix @@ -0,0 +1,2 @@ +Fix a long-standing issue which could cause Synapse to incorrectly accept data in the unsigned field of events +received over federation. \ No newline at end of file diff --git a/changelog.d/11587.bugfix b/changelog.d/11587.bugfix new file mode 100644 index 0000000000..ad2b83edf7 --- /dev/null +++ b/changelog.d/11587.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices. \ No newline at end of file diff --git a/changelog.d/11593.bugfix b/changelog.d/11593.bugfix new file mode 100644 index 0000000000..963fd0e58e --- /dev/null +++ b/changelog.d/11593.bugfix @@ -0,0 +1 @@ +Fix an error in to get federation status of a destination server even if no error has occurred. This admin API was new introduced in Synapse 1.49.0. diff --git a/changelog.d/11612.misc b/changelog.d/11612.misc new file mode 100644 index 0000000000..2d886169c5 --- /dev/null +++ b/changelog.d/11612.misc @@ -0,0 +1 @@ +Avoid database access in the JSON serialization process. diff --git a/changelog.d/11667.bugfix b/changelog.d/11667.bugfix new file mode 100644 index 0000000000..bf65fd4c8b --- /dev/null +++ b/changelog.d/11667.bugfix @@ -0,0 +1 @@ +Fix `/_matrix/client/v1/room/{roomId}/hierarchy` endpoint returning incorrect fields which have been present since Synapse 1.49.0. diff --git a/changelog.d/11672.feature b/changelog.d/11672.feature new file mode 100644 index 0000000000..ce8b3e9547 --- /dev/null +++ b/changelog.d/11672.feature @@ -0,0 +1 @@ +Return an `M_FORBIDDEN` error code instead of `M_UNKNOWN` when a spam checker module prevents a user from creating a room. diff --git a/changelog.d/11682.removal b/changelog.d/11682.removal new file mode 100644 index 0000000000..50bdf35b20 --- /dev/null +++ b/changelog.d/11682.removal @@ -0,0 +1 @@ +Remove the unstable `/send_relation` endpoint. diff --git a/changelog.d/11685.misc b/changelog.d/11685.misc new file mode 100644 index 0000000000..c4566b2012 --- /dev/null +++ b/changelog.d/11685.misc @@ -0,0 +1 @@ +Run `pyupgrade --py37-plus --keep-percent-format` on Synapse. diff --git a/changelog.d/11693.misc b/changelog.d/11693.misc new file mode 100644 index 0000000000..521a1796b8 --- /dev/null +++ b/changelog.d/11693.misc @@ -0,0 +1 @@ +Remove debug logging for #4422, which has been closed since Synapse 0.99. \ No newline at end of file diff --git a/changelog.d/11699.misc b/changelog.d/11699.misc new file mode 100644 index 0000000000..ffae5f2960 --- /dev/null +++ b/changelog.d/11699.misc @@ -0,0 +1 @@ +Remove fallback code for Python 2. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 2038e72924..de0e0c1731 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,17 +14,7 @@ # limitations under the License. import collections.abc import re -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from frozendict import frozendict @@ -32,14 +22,10 @@ from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict -from synapse.util.async_helpers import yieldable_gather_results from synapse.util.frozenutils import unfreeze from . import EventBase -if TYPE_CHECKING: - from synapse.server import HomeServer - # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (?<!stuff) matches if the current position in the string is not preceded # by a match for 'stuff'. @@ -385,17 +371,12 @@ class EventClientSerializer: clients. """ - def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() - self._msc1849_enabled = hs.config.experimental.msc1849_enabled - self._msc3440_enabled = hs.config.experimental.msc3440_enabled - - async def serialize_event( + def serialize_event( self, event: Union[JsonDict, EventBase], time_now: int, *, - bundle_aggregations: bool = False, + bundle_aggregations: Optional[Dict[str, JsonDict]] = None, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -418,66 +399,41 @@ class EventClientSerializer: serialized_event = serialize_event(event, time_now, **kwargs) # Check if there are any bundled aggregations to include with the event. - # - # Do not bundle aggregations if any of the following at true: - # - # * Support is disabled via the configuration or the caller. - # * The event is a state event. - # * The event has been redacted. - if ( - self._msc1849_enabled - and bundle_aggregations - and not event.is_state() - and not event.internal_metadata.is_redacted() - ): - await self._injected_bundled_aggregations(event, time_now, serialized_event) + if bundle_aggregations: + event_aggregations = bundle_aggregations.get(event.event_id) + if event_aggregations: + self._injected_bundled_aggregations( + event, + time_now, + bundle_aggregations[event.event_id], + serialized_event, + ) return serialized_event - async def _injected_bundled_aggregations( - self, event: EventBase, time_now: int, serialized_event: JsonDict + def _injected_bundled_aggregations( + self, + event: EventBase, + time_now: int, + aggregations: JsonDict, + serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. Args: event: The event being serialized. time_now: The current time in milliseconds + aggregations: The bundled aggregation to serialize. serialized_event: The serialized event which may be modified. """ - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include. - aggregations = {} - - annotations = await self.store.get_aggregation_groups_for_event( - event_id, room_id - ) - if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + # Make a copy in-case the object is cached. + aggregations = aggregations.copy() - references = await self.store.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id, room_id) - - if edit: + if RelationTypes.REPLACE in aggregations: # If there is an edit replace the content, preserving existing # relations. + edit = aggregations[RelationTypes.REPLACE] # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -502,27 +458,19 @@ class EventClientSerializer: } # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.store.get_thread_summary(event_id, room_id) - if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False - ), - "count": thread_count, - } - - # If any bundled aggregations were found, include them. - if aggregations: - serialized_event["unsigned"].setdefault("m.relations", {}).update( - aggregations + if RelationTypes.THREAD in aggregations: + # Serialize the latest thread event. + latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] + + # Don't bundle aggregations as this could recurse forever. + aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( + latest_thread_event, time_now, bundle_aggregations=None ) - async def serialize_events( + # Include the bundled aggregations in the event. + serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + + def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any ) -> List[JsonDict]: """Serializes multiple events. @@ -535,9 +483,9 @@ class EventClientSerializer: Returns: The list of serialized events """ - return await yieldable_gather_results( - self.serialize_event, events, time_now=time_now, **kwargs - ) + return [ + self.serialize_event(event, time_now=time_now, **kwargs) for event in events + ] def copy_power_levels_contents( diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index addc0bf000..896168c05c 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -230,6 +230,10 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB # origin, etc etc) assert_params_in_dict(pdu_json, ("type", "depth")) + # Strip any unauthorized values from "unsigned" if they exist + if "unsigned" in pdu_json: + _strip_unsigned_values(pdu_json) + depth = pdu_json["depth"] if not isinstance(depth, int): raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) @@ -245,3 +249,24 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB event = make_event_from_dict(pdu_json, room_version) return event + + +def _strip_unsigned_values(pdu_dict: JsonDict) -> None: + """ + Strip any unsigned values unless specifically allowed, as defined by the whitelist. + + pdu: the json dict to strip values from. Note that the dict is mutated by this + function + """ + unsigned = pdu_dict["unsigned"] + + if not isinstance(unsigned, dict): + pdu_dict["unsigned"] = {} + + if pdu_dict["type"] == "m.room.member": + whitelist = ["knock_room_state", "invite_room_state", "age"] + else: + whitelist = ["age"] + + filtered_unsigned = {k: v for k, v in unsigned.items() if k in whitelist} + pdu_dict["unsigned"] = filtered_unsigned diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7665425232..b184a48cb1 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -948,8 +948,16 @@ class DeviceListUpdater: devices = [] ignore_devices = True else: + prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote( + user_id + ) cached_devices = await self.store.get_cached_devices_for_user(user_id) - if cached_devices == {d["device_id"]: d for d in devices}: + + # To ensure that a user with no devices is cached, we skip the resync only + # if we have a stream_id from previously writing a cache entry. + if prev_stream_id is not None and cached_devices == { + d["device_id"]: d for d in devices + }: logging.info( "Skipping device list resync for %s, as our cache matches already", user_id, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1b996c420d..a3add8a586 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -119,7 +119,7 @@ class EventStreamHandler: events.extend(to_add) - chunks = await self._event_serializer.serialize_events( + chunks = self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 601bab67f9..346a06ff49 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -170,7 +170,7 @@ class InitialSyncHandler: d["inviter"] = event.sender invite_event = await self.store.get_event(event.event_id) - d["invite"] = await self._event_serializer.serialize_event( + d["invite"] = self._event_serializer.serialize_event( invite_event, time_now, as_client_event=as_client_event, @@ -222,7 +222,7 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( - await self._event_serializer.serialize_events( + self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event, @@ -232,7 +232,7 @@ class InitialSyncHandler: "end": await end_token.to_string(self.store), } - d["state"] = await self._event_serializer.serialize_events( + d["state"] = self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -376,16 +376,14 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events( - room_state.values(), time_now - ) + self._event_serializer.serialize_events(room_state.values(), time_now) ), "presence": [], "receipts": [], @@ -404,7 +402,7 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() # Don't bundle aggregations as this is a deprecated API. - state = await self._event_serializer.serialize_events( + state = self._event_serializer.serialize_events( current_state.values(), time_now ) @@ -480,7 +478,7 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5e3d3886eb..b37250aa38 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -246,7 +246,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(room_state.values(), now) + events = self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7469cc55a2..472688f045 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -537,14 +537,16 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() + aggregations = await self.store.get_bundled_aggregations(events) + time_now = self.clock.time_msec() chunk = { "chunk": ( - await self._event_serializer.serialize_events( + self._event_serializer.serialize_events( events, time_now, - bundle_aggregations=True, + bundle_aggregations=aggregations, as_client_event=as_client_event, ) ), @@ -553,7 +555,7 @@ class PaginationHandler: } if state: - chunk["state"] = await self._event_serializer.serialize_events( + chunk["state"] = self._event_serializer.serialize_events( state, time_now, as_client_event=as_client_event ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9c1cbffa5..3d47163f25 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -393,7 +393,9 @@ class RoomCreationHandler: user_id = requester.user.to_string() if not await self.spam_checker.user_may_create_room(user_id): - raise SynapseError(403, "You are not permitted to create rooms") + raise SynapseError( + 403, "You are not permitted to create rooms", Codes.FORBIDDEN + ) creation_content: JsonDict = { "room_version": new_room_version.identifier, @@ -685,7 +687,9 @@ class RoomCreationHandler: invite_3pid_list, ) ): - raise SynapseError(403, "You are not permitted to create rooms") + raise SynapseError( + 403, "You are not permitted to create rooms", Codes.FORBIDDEN + ) if ratelimit: await self.request_ratelimiter.ratelimit(requester) @@ -1177,6 +1181,16 @@ class RoomContextHandler: # `filtered` rather than the event we retrieved from the datastore. results["event"] = filtered[0] + # Fetch the aggregations. + aggregations = await self.store.get_bundled_aggregations([results["event"]]) + aggregations.update( + await self.store.get_bundled_aggregations(results["events_before"]) + ) + aggregations.update( + await self.store.get_bundled_aggregations(results["events_after"]) + ) + results["aggregations"] = aggregations + if results["events_after"]: last_event_id = results["events_after"][-1].event_id else: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index b2cfe537df..9ef88feb8a 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -209,7 +209,7 @@ class RoomSummaryHandler: # Before returning to the client, remove the allowed_room_ids # and allowed_spaces keys. room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) + room.pop("allowed_spaces", None) # historical rooms_result.append(room) events.extend(room_entry.children_state_events) @@ -988,12 +988,14 @@ class RoomSummaryHandler: "canonical_alias": stats["canonical_alias"], "num_joined_members": stats["joined_members"], "avatar_url": stats["avatar"], + # plural join_rules is a documentation error but kept for historical + # purposes. Should match /publicRooms. "join_rules": stats["join_rules"], + "join_rule": stats["join_rules"], "world_readable": ( stats["history_visibility"] == HistoryVisibility.WORLD_READABLE ), "guest_can_join": stats["guest_access"] == "can_join", - "creation_ts": create_event.origin_server_ts, "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), } diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index ab7eaab2fb..0b153a6822 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -420,10 +420,10 @@ class SearchHandler: time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = await self._event_serializer.serialize_events( + context["events_before"] = self._event_serializer.serialize_events( context["events_before"], time_now ) - context["events_after"] = await self._event_serializer.serialize_events( + context["events_after"] = self._event_serializer.serialize_events( context["events_after"], time_now ) @@ -441,9 +441,7 @@ class SearchHandler: results.append( { "rank": rank_map[e.event_id], - "result": ( - await self._event_serializer.serialize_event(e, time_now) - ), + "result": self._event_serializer.serialize_event(e, time_now), "context": contexts.get(e.event_id, {}), } ) @@ -457,7 +455,7 @@ class SearchHandler: if state_results: s = {} for room_id, state_events in state_results.items(): - s[room_id] = await self._event_serializer.serialize_events( + s[room_id] = self._event_serializer.serialize_events( state_events, time_now ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7baf3f199c..4b3f1ea059 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -60,10 +60,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Debug logger for https://github.com/matrix-org/synapse/issues/4422 -issue4422_logger = logging.getLogger("synapse.handler.sync.4422_debug") - - # Counts the number of times we returned a non-empty sync. `type` is one of # "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is # "true" or "false" depending on if the request asked for lazy loaded members or @@ -1161,13 +1157,8 @@ class SyncHandler: num_events = 0 - # debug for https://github.com/matrix-org/synapse/issues/4422 + # debug for https://github.com/matrix-org/synapse/issues/9424 for joined_room in sync_result_builder.joined: - room_id = joined_room.room_id - if room_id in newly_joined_rooms: - issue4422_logger.debug( - "Sync result for newly joined room %s: %r", room_id, joined_room - ) num_events += len(joined_room.timeline.events) log_kv( @@ -1740,18 +1731,6 @@ class SyncHandler: old_mem_ev_id, allow_none=True ) - # debug for #4422 - if has_join: - prev_membership = None - if old_mem_ev: - prev_membership = old_mem_ev.membership - issue4422_logger.debug( - "Previous membership for room %s with join: %s (event %s)", - room_id, - prev_membership, - old_mem_ev_id, - ) - if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) @@ -1893,13 +1872,6 @@ class SyncHandler: upto_token=since_token, ) - if newly_joined: - # debugging for https://github.com/matrix-org/synapse/issues/4422 - issue4422_logger.debug( - "RoomSyncResultBuilder events for newly joined room %s: %r", - room_id, - entry.events, - ) room_entries.append(entry) return _RoomChanges( @@ -2077,14 +2049,6 @@ class SyncHandler: # `_load_filtered_recents` can't find any events the user should see # (e.g. due to having ignored the sender of the last 50 events). - if newly_joined: - # debug for https://github.com/matrix-org/synapse/issues/4422 - issue4422_logger.debug( - "Timeline events after filtering in newly-joined room %s: %r", - room_id, - batch, - ) - # When we join the room (or the client requests full_state), we should # send down any existing tags. Usually the user won't have tags in a # newly joined room, unless either a) they've joined before or b) the diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 50d88c9109..8cd3fa189e 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -111,25 +111,37 @@ class DestinationsRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) + if not await self._store.is_destination_known(destination): + raise NotFoundError("Unknown destination") + destination_retry_timings = await self._store.get_destination_retry_timings( destination ) - if not destination_retry_timings: - raise NotFoundError("Unknown destination") - last_successful_stream_ordering = ( await self._store.get_destination_last_successful_stream_ordering( destination ) ) - response = { + response: JsonDict = { "destination": destination, - "failure_ts": destination_retry_timings.failure_ts, - "retry_last_ts": destination_retry_timings.retry_last_ts, - "retry_interval": destination_retry_timings.retry_interval, "last_successful_stream_ordering": last_successful_stream_ordering, } + if destination_retry_timings: + response = { + **response, + "failure_ts": destination_retry_timings.failure_ts, + "retry_last_ts": destination_retry_timings.retry_last_ts, + "retry_interval": destination_retry_timings.retry_interval, + } + else: + response = { + **response, + "failure_ts": None, + "retry_last_ts": 0, + "retry_interval": 0, + } + return HTTPStatus.OK, response diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 7236e4027f..299f5c9eb0 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -466,7 +466,7 @@ class UserMediaRestServlet(RestServlet): ) deleted_media, total = await self.media_repository.delete_local_media_ids( - ([row["media_id"] for row in media]) + [row["media_id"] for row in media] ) return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6030373ebc..2e714ac87b 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -424,7 +424,7 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events(events.values(), now) + room_state = self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -744,22 +744,22 @@ class RoomEventContextServlet(RestServlet): ) time_now = self.clock.time_msec() - results["events_before"] = await self._event_serializer.serialize_events( + results["events_before"] = self._event_serializer.serialize_events( results["events_before"], time_now, - bundle_aggregations=True, + bundle_aggregations=results["aggregations"], ) - results["event"] = await self._event_serializer.serialize_event( + results["event"] = self._event_serializer.serialize_event( results["event"], time_now, - bundle_aggregations=True, + bundle_aggregations=results["aggregations"], ) - results["events_after"] = await self._event_serializer.serialize_events( + results["events_after"] = self._event_serializer.serialize_events( results["events_after"], time_now, - bundle_aggregations=True, + bundle_aggregations=results["aggregations"], ) - results["state"] = await self._event_serializer.serialize_events( + results["state"] = self._event_serializer.serialize_events( results["state"], time_now ) diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 13b72a045a..672c821061 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -91,7 +91,7 @@ class EventRestServlet(RestServlet): time_now = self.clock.time_msec() if event: - result = await self._event_serializer.serialize_event(event, time_now) + result = self._event_serializer.serialize_event(event, time_now) return 200, result else: return 404, "Event not found." diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index acd0c9e135..8e427a96a3 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -72,7 +72,7 @@ class NotificationsServlet(RestServlet): "actions": pa.actions, "ts": pa.received_ts, "event": ( - await self._event_serializer.serialize_event( + self._event_serializer.serialize_event( notif_events[pa.event_id], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 5815650ee6..37d949a71e 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,28 +19,20 @@ any time to reflect changes in the MSC. """ import logging -from typing import TYPE_CHECKING, Awaitable, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple -from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import ShadowBanError, SynapseError +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import ( - RestServlet, - parse_integer, - parse_json_object_from_request, - parse_string, -) +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.client.transactions import HttpTransactionCache +from synapse.rest.client._base import client_patterns from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) from synapse.types import JsonDict -from synapse.util.stringutils import random_string - -from ._base import client_patterns if TYPE_CHECKING: from synapse.server import HomeServer @@ -48,112 +40,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RelationSendServlet(RestServlet): - """Helper API for sending events that have relation data. - - Example API shape to send a 👍 reaction to a room: - - POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D - {} - - { - "event_id": "$foobar" - } - """ - - PATTERN = ( - "/rooms/(?P<room_id>[^/]*)/send_relation" - "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.event_creation_handler = hs.get_event_creation_handler() - self.txns = HttpTransactionCache(hs) - - def register(self, http_server: HttpServer) -> None: - http_server.register_paths( - "POST", - client_patterns(self.PATTERN + "$", releases=()), - self.on_PUT_or_POST, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), - self.on_PUT, - self.__class__.__name__, - ) - - def on_PUT( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, - self.on_PUT_or_POST, - request, - room_id, - parent_id, - relation_type, - event_type, - txn_id, - ) - - async def on_PUT_or_POST( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if event_type == EventTypes.Member: - # Add relations to a membership is meaningless, so we just deny it - # at the CS API rather than trying to handle it correctly. - raise SynapseError(400, "Cannot send member events with relations") - - content = parse_json_object_from_request(request) - - aggregation_key = parse_string(request, "key", encoding="utf-8") - - content["m.relates_to"] = { - "event_id": parent_id, - "rel_type": relation_type, - } - if aggregation_key is not None: - content["m.relates_to"]["key"] = aggregation_key - - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - return 200, {"event_id": event_id} - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -227,13 +113,14 @@ class RelationPaginationServlet(RestServlet): now = self.clock.time_msec() # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. - original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + original_event = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None ) # The relations returned for the requested event do include their # bundled aggregations. - serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=True + aggregations = await self.store.get_bundled_aggregations(events) + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations ) return_value = pagination_chunk.to_dict() @@ -422,7 +309,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) now = self.clock.time_msec() - serialized_events = await self._event_serializer.serialize_events(events, now) + serialized_events = self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = serialized_events @@ -431,7 +318,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 40330749e5..da6014900a 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -642,6 +642,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() + self._store = hs.get_datastore() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -660,10 +661,13 @@ class RoomEventServlet(RestServlet): # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) - time_now = self.clock.time_msec() if event: - event_dict = await self._event_serializer.serialize_event( - event, time_now, bundle_aggregations=True + # Ensure there are bundled aggregations available. + aggregations = await self._store.get_bundled_aggregations([event]) + + time_now = self.clock.time_msec() + event_dict = self._event_serializer.serialize_event( + event, time_now, bundle_aggregations=aggregations ) return 200, event_dict @@ -708,16 +712,20 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = await self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=True + results["events_before"] = self._event_serializer.serialize_events( + results["events_before"], + time_now, + bundle_aggregations=results["aggregations"], ) - results["event"] = await self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=True + results["event"] = self._event_serializer.serialize_event( + results["event"], time_now, bundle_aggregations=results["aggregations"] ) - results["events_after"] = await self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=True + results["events_after"] = self._event_serializer.serialize_events( + results["events_after"], + time_now, + bundle_aggregations=results["aggregations"], ) - results["state"] = await self._event_serializer.serialize_events( + results["state"] = self._event_serializer.serialize_events( results["state"], time_now ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index e99a943d0d..a3e57e4b20 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -17,7 +17,6 @@ from collections import defaultdict from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, Dict, Iterable, @@ -395,7 +394,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = await self._event_serializer.serialize_event( + invite = self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, @@ -432,7 +431,7 @@ class SyncRestServlet(RestServlet): """ knocked = {} for room in rooms: - knock = await self._event_serializer.serialize_event( + knock = self._event_serializer.serialize_event( room.knock, time_now, token_id=token_id, @@ -525,21 +524,14 @@ class SyncRestServlet(RestServlet): The room, encoded in our response format """ - def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]: + def serialize( + events: Iterable[EventBase], + aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, time_now=time_now, - # Don't bother to bundle aggregations if the timeline is unlimited, - # as clients will have all the necessary information. - # bundle_aggregations=room.timeline.limited, - # - # richvdh 2021-12-15: disable this temporarily as it has too high an - # overhead for initialsyncs. We need to figure out a way that the - # bundling can be done *before* the events are stored in the - # SyncResponseCache so that this part can be synchronous. - # - # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations. - bundle_aggregations=False, + bundle_aggregations=aggregations, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, @@ -561,8 +553,21 @@ class SyncRestServlet(RestServlet): event.room_id, ) - serialized_state = await serialize(state_events) - serialized_timeline = await serialize(timeline_events) + serialized_state = serialize(state_events) + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. + # bundle_aggregations=room.timeline.limited, + # + # richvdh 2021-12-15: disable this temporarily as it has too high an + # overhead for initialsyncs. We need to figure out a way that the + # bundling can be done *before* the events are stored in the + # SyncResponseCache so that this part can be synchronous. + # + # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations. + # if room.timeline.limited: + # aggregations = await self.store.get_bundled_aggregations(timeline_events) + aggregations = None + serialized_timeline = serialize(timeline_events, aggregations) account_data = room.account_data diff --git a/synapse/server.py b/synapse/server.py index 185e40e4da..3032f0b738 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -759,7 +759,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer(self) + return EventClientSerializer() @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 273adb61fd..52fbf50db6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -713,7 +713,7 @@ class DeviceWorkerStore(SQLBaseStore): @cached(max_entries=10000) async def get_device_list_last_stream_id_for_remote( self, user_id: str - ) -> Optional[Any]: + ) -> Optional[str]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ @@ -729,7 +729,9 @@ class DeviceWorkerStore(SQLBaseStore): cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", ) - async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]): + async def get_device_list_last_stream_id_for_remotes( + self, user_ids: Iterable[str] + ) -> Dict[str, Optional[str]]: rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -1316,6 +1318,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): content: JsonDict, stream_id: str, ) -> None: + """Delete, update or insert a cache entry for this (user, device) pair.""" if content.get("deleted"): self.db_pool.simple_delete_txn( txn, @@ -1375,6 +1378,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def _update_remote_device_list_cache_txn( self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int ) -> None: + """Replace the list of cached devices for this user with the given list.""" self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 4ff6aed253..c6c4bd18da 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,14 +13,30 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) import attr +from frozendict import frozendict -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -29,10 +45,24 @@ from synapse.storage.relations import ( ) from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._msc1849_enabled = hs.config.experimental.msc1849_enabled + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + @cached(tree=True) async def get_relations_for_event( self, @@ -515,6 +545,98 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + async def _get_bundled_aggregation_for_event( + self, event: EventBase + ) -> Optional[Dict[str, Any]]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + # State events and redacted events do not get bundled aggregations. + if event.is_state() or event.internal_metadata.is_redacted(): + return None + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations: Dict[str, Any] = {} + + annotations = await self.get_aggregation_groups_for_event(event_id, room_id) + if annotations.chunk: + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.get_relations_for_event( + event_id, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.get_applicable_edit(event_id, room_id) + + if edit: + aggregations[RelationTypes.REPLACE] = edit + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.get_thread_summary(event_id, room_id) + if latest_thread_event: + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": latest_thread_event, + "count": thread_count, + } + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase] + ) -> Dict[str, Dict[str, Any]]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # If bundled aggregations are disabled, nothing to do. + if not self._msc1849_enabled: + return {} + + # TODO Parallelize. + results = {} + for event in events: + event_result = await self._get_bundled_aggregation_for_event(event) + if event_result is not None: + results[event.event_id] = event_result + + return results + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py index 5a97120437..e8c776b97a 100644 --- a/synapse/storage/databases/main/session.py +++ b/synapse/storage/databases/main/session.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 6c299cafa5..4b78b4d098 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -560,3 +560,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction( "get_destinations_paginate_txn", get_destinations_paginate_txn ) + + async def is_destination_known(self, destination: str) -> bool: + """Check if a destination is known to the server.""" + result = await self.db_pool.simple_select_one_onecol( + table="destinations", + keyvalues={"destination": destination}, + retcol="1", + allow_none=True, + desc="is_destination_known", + ) + return bool(result) diff --git a/synapse/types.py b/synapse/types.py index 42aeaf6270..74a2c51857 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -21,6 +21,7 @@ from typing import ( ClassVar, Dict, Mapping, + Match, MutableMapping, Optional, Tuple, @@ -380,7 +381,7 @@ def map_username_to_mxid_localpart( onto different mxids Returns: - unicode: string suitable for a mxid localpart + string suitable for a mxid localpart """ if not isinstance(username, bytes): username = username.encode("utf-8") @@ -388,29 +389,23 @@ def map_username_to_mxid_localpart( # first we sort out upper-case characters if case_sensitive: - def f1(m): + def f1(m: Match[bytes]) -> bytes: return b"_" + m.group().lower() username = UPPER_CASE_PATTERN.sub(f1, username) else: username = username.lower() - # then we sort out non-ascii characters - def f2(m): - g = m.group()[0] - if isinstance(g, str): - # on python 2, we need to do a ord(). On python 3, the - # byte itself will do. - g = ord(g) - return b"=%02x" % (g,) + # then we sort out non-ascii characters by converting to the hex equivalent. + def f2(m: Match[bytes]) -> bytes: + return b"=%02x" % (m.group()[0],) username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) # we also do the =-escaping to mxids starting with an underscore. username = re.sub(b"^_", b"=5f", username) - # we should now only have ascii bytes left, so can decode back to a - # unicode. + # we should now only have ascii bytes left, so can decode back to a string. return username.decode("ascii") diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index ddcf3ee348..734ed84d78 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterable from unittest import mock +from parameterized import parameterized from signedjson import key as key, sign as sign from twisted.internet import defer @@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError from tests import unittest +from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): @@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_user_id = "@test:other" local_user_id = "@test:test" + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. self.store.get_rooms_for_user = mock.Mock( return_value=defer.succeed({"some_room_id"}) ) @@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): } }, ) + + @parameterized.expand( + [ + # The remote homeserver's response indicates that this user has 0/1/2 devices. + ([],), + (["device_1"],), + (["device_1", "device_2"],), + ] + ) + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + """Test that requests for all of a remote user's devices are cached. + + We do this by asserting that only one call over federation was made, and that + the two queries to the local homeserver produce the same response. + """ + local_user_id = "@test:test" + remote_user_id = "@test:other" + request_body = {"device_keys": {remote_user_id: []}} + + response_devices = [ + { + "device_id": device_id, + "keys": { + "algorithms": ["dummy"], + "device_id": device_id, + "keys": {f"dummy:{device_id}": "dummy"}, + "signatures": {device_id: {f"dummy:{device_id}": "dummy"}}, + "unsigned": {}, + "user_id": "@test:other", + }, + } + for device_id in device_ids + ] + + response_body = { + "devices": response_devices, + "user_id": remote_user_id, + "stream_id": 12345, # an integer, according to the spec + } + + e2e_handler = self.hs.get_e2e_keys_handler() + + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. + mock_get_rooms = mock.patch.object( + self.store, + "get_rooms_for_user", + new_callable=mock.MagicMock, + return_value=make_awaitable(["some_room_id"]), + ) + mock_request = mock.patch.object( + self.hs.get_federation_client(), + "query_user_devices", + new_callable=mock.MagicMock, + return_value=make_awaitable(response_body), + ) + + with mock_get_rooms, mock_request as mocked_federation_request: + # Make the first query and sanity check it succeeds. + response_1 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_1["failures"], {}) + + # We should have made a federation request to do so. + mocked_federation_request.assert_called_once() + + # Reset the mock so we can prove we don't make a second federation request. + mocked_federation_request.reset_mock() + + # Repeat the query. + response_2 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_2["failures"], {}) + + # We should not have made a second federation request. + mocked_federation_request.assert_not_called() + + # The two requests to the local homeserver should be identical. + self.assertEqual(response_1, response_2) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 742f194257..b70350b6f1 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -314,15 +314,12 @@ class FederationTestCase(unittest.HomeserverTestCase): retry_interval, last_successful_stream_ordering, ) in dest: - self.get_success( - self.store.set_destination_retry_timings( - destination, failure_ts, retry_last_ts, retry_interval - ) - ) - self.get_success( - self.store.set_destination_last_successful_stream_ordering( - destination, last_successful_stream_ordering - ) + self._create_destination( + destination, + failure_ts, + retry_last_ts, + retry_interval, + last_successful_stream_ordering, ) # order by default (destination) @@ -413,11 +410,9 @@ class FederationTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - def test_get_single_destination(self) -> None: - """ - Get one specific destinations. - """ - self._create_destinations(5) + def test_get_single_destination_with_retry_timings(self) -> None: + """Get one specific destination which has retry timings.""" + self._create_destinations(1) channel = self.make_request( "GET", @@ -432,6 +427,53 @@ class FederationTestCase(unittest.HomeserverTestCase): # convert channel.json_body into a List self._check_fields([channel.json_body]) + def test_get_single_destination_no_retry_timings(self) -> None: + """Get one specific destination which has no retry timings.""" + self._create_destination("sub0.example.com") + + channel = self.make_request( + "GET", + self.url + "/sub0.example.com", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual("sub0.example.com", channel.json_body["destination"]) + self.assertEqual(0, channel.json_body["retry_last_ts"]) + self.assertEqual(0, channel.json_body["retry_interval"]) + self.assertIsNone(channel.json_body["failure_ts"]) + self.assertIsNone(channel.json_body["last_successful_stream_ordering"]) + + def _create_destination( + self, + destination: str, + failure_ts: Optional[int] = None, + retry_last_ts: int = 0, + retry_interval: int = 0, + last_successful_stream_ordering: Optional[int] = None, + ) -> None: + """Create one specific destination + + Args: + destination: the destination we have successfully sent to + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms + last_successful_stream_ordering: the stream_ordering of the most + recent successfully-sent PDU + """ + self.get_success( + self.store.set_destination_retry_timings( + destination, failure_ts, retry_last_ts, retry_interval + ) + ) + if last_successful_stream_ordering is not None: + self.get_success( + self.store.set_destination_last_successful_stream_ordering( + destination, last_successful_stream_ordering + ) + ) + def _create_destinations(self, number_destinations: int) -> None: """Create a number of destinations @@ -440,10 +482,7 @@ class FederationTestCase(unittest.HomeserverTestCase): """ for i in range(0, number_destinations): dest = f"sub{i}.example.com" - self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50)) - self.get_success( - self.store.set_destination_last_successful_stream_ordering(dest, 100) - ) + self._create_destination(dest, 50, 50, 50, 100) def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected destination attributes are present in content diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c026d526ef..ff4e81d069 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -93,11 +93,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body, ) - def test_deny_membership(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) - self.assertEquals(400, channel.code, channel.json_body) - def test_deny_invalid_event(self): """Test that we deny relations on non-existant events""" channel = self._send_relation( @@ -1119,7 +1114,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): relation_type: One of `RelationTypes` event_type: The type of the event to create key: The aggregation key used for m.annotation relation type. - content: The content of the created event. + content: The content of the created event. Will be modified to configure + the m.relates_to key based on the other provided parameters. access_token: The access token used to send the relation, defaults to `self.user_token` parent_id: The event_id this relation relates to. If None, then self.parent_id @@ -1130,17 +1126,21 @@ class RelationsTestCase(unittest.HomeserverTestCase): if not access_token: access_token = self.user_token - query = "" - if key: - query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) - original_id = parent_id if parent_id else self.parent_id + if content is None: + content = {} + content["m.relates_to"] = { + "event_id": original_id, + "rel_type": relation_type, + } + if key is not None: + content["m.relates_to"]["key"] = key + channel = self.make_request( "POST", - "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" - % (self.room, original_id, relation_type, event_type, query), - content or {}, + f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}", + content, access_token=access_token, ) return channel diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index b58452195a..fe5b536d97 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -228,7 +228,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(event) time_now = self.clock.time_msec() - serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + serialized = self.serializer.serialize_event(event, time_now) return serialized diff --git a/tests/test_federation.py b/tests/test_federation.py index 3eef1c4c05..2b9804aba0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -17,7 +17,9 @@ from unittest.mock import Mock from twisted.internet.defer import succeed from synapse.api.errors import FederationError +from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict +from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext from synapse.types import UserID, create_requester from synapse.util import Clock @@ -276,3 +278,73 @@ class MessageAcceptTests(unittest.HomeserverTestCase): "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), ) self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values()) + + +class StripUnsignedFromEventsTestCase(unittest.TestCase): + def test_strip_unauthorized_unsigned_values(self): + event1 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event1:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"}, + } + filtered_event = event_from_pdu_json(event1, RoomVersions.V1) + # Make sure unauthorized fields are stripped from unsigned + self.assertNotIn("more warez", filtered_event.unsigned) + + def test_strip_event_maintains_allowed_fields(self): + event2 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event2:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "auth_events": [], + "content": {"membership": "join"}, + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + "invite_room_state": [], + }, + } + + filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1) + self.assertIn("age", filtered_event2.unsigned) + self.assertEqual(14, filtered_event2.unsigned["age"]) + self.assertNotIn("more warez", filtered_event2.unsigned) + # Invite_room_state is allowed in events of type m.room.member + self.assertIn("invite_room_state", filtered_event2.unsigned) + self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) + + def test_strip_event_removes_fields_based_on_event_type(self): + event3 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event3:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.power_levels", + "origin": "test.servx", + "content": {}, + "auth_events": [], + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + "invite_room_state": [], + }, + } + filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1) + self.assertIn("age", filtered_event3.unsigned) + # Invite_room_state field is only permitted in event type m.room.member + self.assertNotIn("invite_room_state", filtered_event3.unsigned) + self.assertNotIn("more warez", filtered_event3.unsigned) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 15ac2bfeba..f05a373aa0 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -19,7 +19,7 @@ import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Any, Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, TypeVar from unittest.mock import Mock import attr @@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -def make_awaitable(result: Any) -> Awaitable[Any]: +def make_awaitable(result: TV) -> Awaitable[TV]: """ Makes an awaitable, suitable for mocking an `async` function. This uses Futures as they can be awaited multiple times so can be returned |