From 4b3e30c276ee7fab1896b3b0386847d3000fea32 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 9 Nov 2021 12:11:50 +0000 Subject: Fix typo in `RelationAggregationPaginationServlet` error response (#11278) --- synapse/rest/client/relations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/rest') diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 58f6699073..184cfbe196 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -298,7 +298,9 @@ class RelationAggregationPaginationServlet(RestServlet): raise SynapseError(404, "Unknown parent event.") if relation_type not in (RelationTypes.ANNOTATION, None): - raise SynapseError(400, "Relation type must be 'annotation'") + raise SynapseError( + 400, f"Relation type must be '{RelationTypes.ANNOTATION}'" + ) limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") -- cgit 1.5.1 From a19d01c3d95f5dbd3a4bb181cb70dacd44135a8b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Nov 2021 08:10:58 -0500 Subject: Support filtering by relations per MSC3440 (#11236) Adds experimental support for `relation_types` and `relation_senders` fields for filters. --- changelog.d/11236.feature | 1 + synapse/api/filtering.py | 115 +++++++++++----- synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 8 +- synapse/handlers/search.py | 8 +- synapse/handlers/sync.py | 18 +-- synapse/rest/admin/rooms.py | 5 +- synapse/rest/client/room.py | 10 +- synapse/rest/client/sync.py | 6 +- synapse/storage/databases/main/relations.py | 58 +++++++- synapse/storage/databases/main/stream.py | 86 +++++++++--- tests/api/test_filtering.py | 107 +++++++++----- tests/handlers/test_sync.py | 5 +- tests/rest/client/test_rooms.py | 154 ++++++++++++++++++++- tests/storage/test_stream.py | 207 ++++++++++++++++++++++++++++ 15 files changed, 680 insertions(+), 110 deletions(-) create mode 100644 changelog.d/11236.feature create mode 100644 tests/storage/test_stream.py (limited to 'synapse/rest') diff --git a/changelog.d/11236.feature b/changelog.d/11236.feature new file mode 100644 index 0000000000..e7aeee2aa6 --- /dev/null +++ b/changelog.d/11236.feature @@ -0,0 +1 @@ +Support filtering by relation senders & types per [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 4b0a9b2974..13dd6ce248 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,7 +1,7 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -86,6 +86,9 @@ ROOM_EVENT_FILTER_SCHEMA = { # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, + # MSC3440, filtering by event relations. + "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID: class Filtering: def __init__(self, hs: "HomeServer"): - super().__init__() + self._hs = hs self.store = hs.get_datastore() + self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) + async def get_user_filter( self, user_localpart: str, filter_id: Union[int, str] ) -> "FilterCollection": result = await self.store.get_user_filter(user_localpart, filter_id) - return FilterCollection(result) + return FilterCollection(self._hs, result) def add_user_filter( self, user_localpart: str, user_filter: JsonDict @@ -191,21 +196,22 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict) class FilterCollection: - def __init__(self, filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonDict): self._filter_json = filter_json room_filter_json = self._filter_json.get("room", {}) self._room_filter = Filter( - {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")} + hs, + {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}, ) - self._room_timeline_filter = Filter(room_filter_json.get("timeline", {})) - self._room_state_filter = Filter(room_filter_json.get("state", {})) - self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) - self._room_account_data = Filter(room_filter_json.get("account_data", {})) - self._presence_filter = Filter(filter_json.get("presence", {})) - self._account_data = Filter(filter_json.get("account_data", {})) + self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {})) + self._room_state_filter = Filter(hs, room_filter_json.get("state", {})) + self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {})) + self._room_account_data = Filter(hs, room_filter_json.get("account_data", {})) + self._presence_filter = Filter(hs, filter_json.get("presence", {})) + self._account_data = Filter(hs, filter_json.get("account_data", {})) self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.event_fields = filter_json.get("event_fields", []) @@ -232,25 +238,37 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members - def filter_presence( + async def filter_presence( self, events: Iterable[UserPresenceState] ) -> List[UserPresenceState]: - return self._presence_filter.filter(events) + return await self._presence_filter.filter(events) - def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return self._account_data.filter(events) + async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: + return await self._account_data.filter(events) - def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: - return self._room_state_filter.filter(self._room_filter.filter(events)) + async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: + return await self._room_state_filter.filter( + await self._room_filter.filter(events) + ) - def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]: - return self._room_timeline_filter.filter(self._room_filter.filter(events)) + async def filter_room_timeline( + self, events: Iterable[EventBase] + ) -> List[EventBase]: + return await self._room_timeline_filter.filter( + await self._room_filter.filter(events) + ) - def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) + async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]: + return await self._room_ephemeral_filter.filter( + await self._room_filter.filter(events) + ) - def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return self._room_account_data.filter(self._room_filter.filter(events)) + async def filter_room_account_data( + self, events: Iterable[JsonDict] + ) -> List[JsonDict]: + return await self._room_account_data.filter( + await self._room_filter.filter(events) + ) def blocks_all_presence(self) -> bool: return ( @@ -274,7 +292,9 @@ class FilterCollection: class Filter: - def __init__(self, filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonDict): + self._hs = hs + self._store = hs.get_datastore() self.filter_json = filter_json self.limit = filter_json.get("limit", 10) @@ -297,6 +317,20 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) + # Ideally these would be rejected at the endpoint if they were provided + # and not supported, but that would involve modifying the JSON schema + # based on the homeserver configuration. + if hs.config.experimental.msc3440_enabled: + self.relation_senders = self.filter_json.get( + "io.element.relation_senders", None + ) + self.relation_types = self.filter_json.get( + "io.element.relation_types", None + ) + else: + self.relation_senders = None + self.relation_types = None + def filters_all_types(self) -> bool: return "*" in self.not_types @@ -306,7 +340,7 @@ class Filter: def filters_all_rooms(self) -> bool: return "*" in self.not_rooms - def check(self, event: FilterEvent) -> bool: + def _check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. Args: @@ -420,8 +454,30 @@ class Filter: return room_ids - def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: - return list(filter(self.check, events)) + async def _check_event_relations( + self, events: Iterable[FilterEvent] + ) -> List[FilterEvent]: + # The event IDs to check, mypy doesn't understand the ifinstance check. + event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] + event_ids_to_keep = set( + await self._store.events_have_relations( + event_ids, self.relation_senders, self.relation_types + ) + ) + + return [ + event + for event in events + if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep + ] + + async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: + result = [event for event in events if self._check(event)] + + if self.relation_senders or self.relation_types: + return await self._check_event_relations(result) + + return result def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": """Returns a new filter with the given room IDs appended. @@ -433,7 +489,7 @@ class Filter: filter: A new filter including the given rooms and the old filter's rooms. """ - newFilter = Filter(self.filter_json) + newFilter = Filter(self._hs, self.filter_json) newFilter.rooms += room_ids return newFilter @@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool: return actual_value.startswith(type_prefix) else: return actual_value == filter_value - - -DEFAULT_FILTER_COLLECTION = FilterCollection({}) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index abfe7be0e3..aa26911aed 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -424,7 +424,7 @@ class PaginationHandler: if events: if event_filter: - events = event_filter.filter(events) + events = await event_filter.filter(events) events = await filter_events_for_client( self.storage, user_id, events, is_peeking=(member_event_id is None) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 969eb3b9b0..7a95e31cbe 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1158,8 +1158,10 @@ class RoomContextHandler: ) if event_filter: - results["events_before"] = event_filter.filter(results["events_before"]) - results["events_after"] = event_filter.filter(results["events_after"]) + results["events_before"] = await event_filter.filter( + results["events_before"] + ) + results["events_after"] = await event_filter.filter(results["events_after"]) results["events_before"] = await filter_evts(results["events_before"]) results["events_after"] = await filter_evts(results["events_after"]) @@ -1195,7 +1197,7 @@ class RoomContextHandler: state_events = list(state[last_event_id].values()) if event_filter: - state_events = event_filter.filter(state_events) + state_events = await event_filter.filter(state_events) results["state"] = await filter_evts(state_events) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6e4dff8056..ab7eaab2fb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -180,7 +180,7 @@ class SearchHandler: % (set(group_keys) - {"room_id", "sender"},), ) - search_filter = Filter(filter_dict) + search_filter = Filter(self.hs, filter_dict) # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( @@ -242,7 +242,7 @@ class SearchHandler: rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([r["event"] for r in results]) + filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( self.storage, user.to_string(), filtered_events @@ -292,7 +292,9 @@ class SearchHandler: rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([r["event"] for r in results]) + filtered_events = await search_filter.filter( + [r["event"] for r in results] + ) events = await filter_events_for_client( self.storage, user.to_string(), filtered_events diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2c7c6d63a9..891435c14d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -510,7 +510,7 @@ class SyncHandler: log_kv({"limited": limited}) if potential_recents: - recents = sync_config.filter_collection.filter_room_timeline( + recents = await sync_config.filter_collection.filter_room_timeline( potential_recents ) log_kv({"recents_after_sync_filtering": len(recents)}) @@ -575,8 +575,8 @@ class SyncHandler: log_kv({"loaded_recents": len(events)}) - loaded_recents = sync_config.filter_collection.filter_room_timeline( - events + loaded_recents = ( + await sync_config.filter_collection.filter_room_timeline(events) ) log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)}) @@ -1015,7 +1015,7 @@ class SyncHandler: return { (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state( + for e in await sync_config.filter_collection.filter_room_state( list(state.values()) ) if e.type != EventTypes.Aliases # until MSC2261 or alternative solution @@ -1383,7 +1383,7 @@ class SyncHandler: sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data( + account_data_for_user = await sync_config.filter_collection.filter_account_data( [ {"type": account_data_type, "content": content} for account_data_type, content in account_data.items() @@ -1448,7 +1448,7 @@ class SyncHandler: # Deduplicate the presence entries so that there's at most one per user presence = list({p.user_id: p for p in presence}.values()) - presence = sync_config.filter_collection.filter_presence(presence) + presence = await sync_config.filter_collection.filter_presence(presence) sync_result_builder.presence = presence @@ -2021,12 +2021,14 @@ class SyncHandler: ) account_data_events = ( - sync_config.filter_collection.filter_room_account_data( + await sync_config.filter_collection.filter_room_account_data( account_data_events ) ) - ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) + ephemeral = await sync_config.filter_collection.filter_room_ephemeral( + ephemeral + ) if not ( always_include diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 05c5b4bf0c..05c823d9ce 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -583,6 +583,7 @@ class RoomEventContextServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -600,7 +601,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) else: event_filter = None diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 6a876cfa2f..03a353d53c 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -567,7 +568,9 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -688,7 +692,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) else: event_filter = None diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 913216a7c4..8c0fdb1940 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -29,7 +29,7 @@ from typing import ( from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.events.utils import ( @@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet): request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id is None: - filter_collection = DEFAULT_FILTER_COLLECTION + filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: filter_object = json_decoder.decode(filter_id) @@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet): except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) - filter_collection = FilterCollection(filter_object) + filter_collection = FilterCollection(self.hs, filter_object) else: try: filter_collection = await self.filtering.get_user_filter( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 53576ad52f..907af10995 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -20,7 +20,7 @@ import attr from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore): return count, latest_event + async def events_have_relations( + self, + parent_ids: List[str], + relation_senders: Optional[List[str]], + relation_types: Optional[List[str]], + ) -> List[str]: + """Check which events have a relationship from the given senders of the + given types. + + Args: + parent_ids: The events being annotated + relation_senders: The relation senders to check. + relation_types: The relation types to check. + + Returns: + True if the event has at least one relationship from one of the given senders of the given type. + """ + # If no restrictions are given then the event has the required relations. + if not relation_senders and not relation_types: + return parent_ids + + sql = """ + SELECT relates_to_id FROM event_relations + INNER JOIN events USING (event_id) + WHERE + %s; + """ + + def _get_if_event_has_relations(txn) -> List[str]: + clauses: List[str] = [] + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", parent_ids + ) + clauses.append(clause) + + if relation_senders: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "sender", relation_senders + ) + clauses.append(clause) + args.extend(temp_args) + if relation_types: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "relation_type", relation_types + ) + clauses.append(clause) + args.extend(temp_args) + + txn.execute(sql % " AND ".join(clauses), args) + + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_if_event_has_relations", _get_if_event_has_relations + ) + async def has_user_annotated_event( self, parent_id: str, event_type: str, aggregation_key: str, sender: str ) -> bool: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index dc7884b1c0..42dc807d17 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args = [] if event_filter.types: - clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types)) + clauses.append( + "(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types) + ) args.extend(event_filter.types) for typ in event_filter.not_types: - clauses.append("type != ?") + clauses.append("event.type != ?") args.append(typ) if event_filter.senders: - clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)) + clauses.append( + "(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders) + ) args.extend(event_filter.senders) for sender in event_filter.not_senders: - clauses.append("sender != ?") + clauses.append("event.sender != ?") args.append(sender) if event_filter.rooms: - clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)) + clauses.append( + "(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms) + ) args.extend(event_filter.rooms) for room_id in event_filter.not_rooms: - clauses.append("room_id != ?") + clauses.append("event.room_id != ?") args.append(room_id) if event_filter.contains_url: - clauses.append("contains_url = ?") + clauses.append("event.contains_url = ?") args.append(event_filter.contains_url) # We're only applying the "labels" filter on the database query, because applying the @@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels)) args.extend(event_filter.labels) + # Filter on relation_senders / relation types from the joined tables. + if event_filter.relation_senders: + clauses.append( + "(%s)" + % " OR ".join( + "related_event.sender = ?" for _ in event_filter.relation_senders + ) + ) + args.extend(event_filter.relation_senders) + + if event_filter.relation_types: + clauses.append( + "(%s)" + % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + ) + args.extend(event_filter.relation_types) + return " AND ".join(clauses), args @@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds = generate_pagination_where_clause( direction=direction, - column_names=("topological_ordering", "stream_ordering"), + column_names=("event.topological_ordering", "event.stream_ordering"), from_token=from_bound, to_token=to_bound, engine=self.database_engine, @@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): select_keywords = "SELECT" join_clause = "" + # Using DISTINCT in this SELECT query is quite expensive, because it + # requires the engine to sort on the entire (not limited) result set, + # i.e. the entire events table. Only use it in scenarios that could result + # in the same event ID occurring multiple times in the results. + needs_distinct = False if event_filter and event_filter.labels: # If we're not filtering on a label, then joining on event_labels will # return as many row for a single event as the number of labels it has. To # avoid this, only join if we're filtering on at least one label. - join_clause = """ + join_clause += """ LEFT JOIN event_labels USING (event_id, room_id, topological_ordering) """ if len(event_filter.labels) > 1: - # Using DISTINCT in this SELECT query is quite expensive, because it - # requires the engine to sort on the entire (not limited) result set, - # i.e. the entire events table. We only need to use it when we're - # filtering on more than two labels, because that's the only scenario - # in which we can possibly to get multiple times the same event ID in - # the results. - select_keywords += "DISTINCT" + # Multiple labels could cause the same event to appear multiple times. + needs_distinct = True + + # If there is a filter on relation_senders and relation_types join to the + # relations table. + if event_filter and ( + event_filter.relation_senders or event_filter.relation_types + ): + # Filtering by relations could cause the same event to appear multiple + # times (since there's no limit on the number of relations to an event). + needs_distinct = True + join_clause += """ + LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) + """ + if event_filter.relation_senders: + join_clause += """ + LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) + """ + + if needs_distinct: + select_keywords += " DISTINCT" sql = """ %(select_keywords)s - event_id, instance_name, - topological_ordering, stream_ordering - FROM events + event.event_id, event.instance_name, + event.topological_ordering, event.stream_ordering + FROM events AS event %(join_clause)s - WHERE outlier = ? AND room_id = ? AND %(bounds)s - ORDER BY topological_ordering %(order)s, - stream_ordering %(order)s LIMIT ? + WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s + ORDER BY event.topological_ordering %(order)s, + event.stream_ordering %(order)s LIMIT ? """ % { "select_keywords": select_keywords, "join_clause": join_clause, diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index f44c91a373..b7fc33dc94 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + import jsonschema from synapse.api.constants import EventContentFields @@ -51,9 +53,8 @@ class FilteringTestCase(unittest.HomeserverTestCase): {"presence": {"senders": ["@bar;pik.test.com"]}}, ] for filter in invalid_filters: - with self.assertRaises(SynapseError) as check_filter_error: + with self.assertRaises(SynapseError): self.filtering.check_valid_filter(filter) - self.assertIsInstance(check_filter_error.exception, SynapseError) def test_valid_filters(self): valid_filters = [ @@ -119,12 +120,12 @@ class FilteringTestCase(unittest.HomeserverTestCase): definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_types_works_with_wildcards(self): definition = {"types": ["m.*", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_types_works_with_unknowns(self): definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} @@ -133,24 +134,24 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="now.for.something.completely.different", room_id="!foo:bar", ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_types_works_with_literals(self): definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_types_works_with_wildcards(self): definition = {"not_types": ["m.room.message", "org.matrix.*"]} event = MockEvent( sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_types_works_with_unknowns(self): definition = {"not_types": ["m.*", "org.*"]} event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar") - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_not_types_takes_priority_over_types(self): definition = { @@ -158,35 +159,35 @@ class FilteringTestCase(unittest.HomeserverTestCase): "types": ["m.room.message", "m.room.topic"], } event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_senders_works_with_literals(self): definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_senders_works_with_unknowns(self): definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_senders_works_with_literals(self): definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_senders_works_with_unknowns(self): definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_not_senders_takes_priority_over_senders(self): definition = { @@ -196,14 +197,14 @@ class FilteringTestCase(unittest.HomeserverTestCase): event = MockEvent( sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_rooms_works_with_literals(self): definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_rooms_works_with_unknowns(self): definition = {"rooms": ["!secretbase:unknown"]} @@ -212,7 +213,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", room_id="!anothersecretbase:unknown", ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_rooms_works_with_literals(self): definition = {"not_rooms": ["!anothersecretbase:unknown"]} @@ -221,7 +222,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", room_id="!anothersecretbase:unknown", ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_rooms_works_with_unknowns(self): definition = {"not_rooms": ["!secretbase:unknown"]} @@ -230,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", room_id="!anothersecretbase:unknown", ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_not_rooms_takes_priority_over_rooms(self): definition = { @@ -240,7 +241,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): event = MockEvent( sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_combined_event(self): definition = { @@ -256,7 +257,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", # yup room_id="!stage:unknown", # yup ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_combined_event_bad_sender(self): definition = { @@ -272,7 +273,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", # yup room_id="!stage:unknown", # yup ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_combined_event_bad_room(self): definition = { @@ -288,7 +289,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", # yup room_id="!piggyshouse:muppets", # nope ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_combined_event_bad_type(self): definition = { @@ -304,7 +305,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="muppets.misspiggy.kisses", # nope room_id="!stage:unknown", # yup ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_filter_labels(self): definition = {"org.matrix.labels": ["#fun"]} @@ -315,7 +316,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#fun"]}, ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) event = MockEvent( sender="@foo:bar", @@ -324,7 +325,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#notfun"]}, ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_filter_not_labels(self): definition = {"org.matrix.not_labels": ["#fun"]} @@ -335,7 +336,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#fun"]}, ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) event = MockEvent( sender="@foo:bar", @@ -344,7 +345,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#notfun"]}, ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} @@ -362,7 +363,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_presence(events=events) + results = self.get_success(user_filter.filter_presence(events=events)) self.assertEquals(events, results) def test_filter_presence_no_match(self): @@ -386,7 +387,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_presence(events=events) + results = self.get_success(user_filter.filter_presence(events=events)) self.assertEquals([], results) def test_filter_room_state_match(self): @@ -405,7 +406,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_room_state(events=events) + results = self.get_success(user_filter.filter_room_state(events=events)) self.assertEquals(events, results) def test_filter_room_state_no_match(self): @@ -426,7 +427,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_room_state(events) + results = self.get_success(user_filter.filter_room_state(events)) self.assertEquals([], results) def test_filter_rooms(self): @@ -441,10 +442,52 @@ class FilteringTestCase(unittest.HomeserverTestCase): "!not_included:example.com", # Disallowed because not in rooms. ] - filtered_room_ids = list(Filter(definition).filter_rooms(room_ids)) + filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids)) self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_filter_relations(self): + events = [ + # An event without a relation. + MockEvent( + event_id="$no_relation", + sender="@foo:bar", + type="org.matrix.custom.event", + room_id="!foo:bar", + ), + # An event with a relation. + MockEvent( + event_id="$with_relation", + sender="@foo:bar", + type="org.matrix.custom.event", + room_id="!foo:bar", + ), + # Non-EventBase objects get passed through. + {}, + ] + + # For the following tests we patch the datastore method (intead of injecting + # events). This is a bit cheeky, but tests the logic of _check_event_relations. + + # Filter for a particular sender. + definition = { + "io.element.relation_senders": ["@foo:bar"], + } + + async def events_have_relations(*args, **kwargs): + return ["$with_relation"] + + with patch.object( + self.datastore, "events_have_relations", new=events_have_relations + ): + filtered_events = list( + self.get_success( + Filter(self.hs, definition)._check_event_relations(events) + ) + ) + self.assertEquals(filtered_events, events[1:]) + def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 339c039914..638186f173 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Optional +from unittest.mock import Mock from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError -from synapse.api.filtering import DEFAULT_FILTER_COLLECTION +from synapse.api.filtering import Filtering from synapse.api.room_versions import RoomVersions from synapse.handlers.sync import SyncConfig from synapse.rest import admin @@ -197,7 +198,7 @@ def generate_sync_config( _request_key += 1 return SyncConfig( user=UserID.from_string(user_id), - filter_collection=DEFAULT_FILTER_COLLECTION, + filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION, is_guest=False, request_key=("request_key", _request_key), device_id=device_id, diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 376853fd65..10a4a4dc5e 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -25,7 +25,12 @@ from urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin @@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id +class RelationsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["experimental_features"] = {"msc3440_enabled": True} + return config + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + self.helper.join( + room=self.room_id, user=self.second_user_id, tok=self.second_tok + ) + + self.third_user_id = self.register_user("third", "test") + self.third_tok = self.login("third", "test") + self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + + # An initial event with a relation from second user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 1"}, + tok=self.tok, + ) + self.event_id_1 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": self.event_id_1, + "key": "👍", + } + }, + tok=self.second_tok, + ) + + # Another event with a relation from third user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 2"}, + tok=self.tok, + ) + self.event_id_2 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": self.event_id_2, + } + }, + tok=self.third_tok, + ) + + # An event with no relations. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "No relations"}, + tok=self.tok, + ) + + def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: + """Make a request to /messages with a filter, returns the chunk of events.""" + channel = self.make_request( + "GET", + "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["chunk"] + + def test_filter_relation_senders(self): + # Messages which second user reacted to. + filter = {"io.element.relation_senders": [self.second_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + # Messages which third user reacted to. + filter = {"io.element.relation_senders": [self.third_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_2) + + # Messages which either user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id, self.third_user_id] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_type(self): + # Messages which have annotations. + filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + # Messages which have references. + filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_2) + + # Messages which have either annotations or references. + filter = { + "io.element.relation_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_senders_and_type(self): + # Messages which second user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id], + "io.element.relation_types": [RelationTypes.ANNOTATION], + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + class ContextTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py new file mode 100644 index 0000000000..ce782c7e1d --- /dev/null +++ b/tests/storage/test_stream.py @@ -0,0 +1,207 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.filtering import Filter +from synapse.events import EventBase +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.types import JsonDict + +from tests.unittest import HomeserverTestCase + + +class PaginationTestCase(HomeserverTestCase): + """ + Test the pre-filtering done in the pagination code. + + This is similar to some of the tests in tests.rest.client.test_rooms but here + we ensure that the filtering done in the database is applied successfully. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["experimental_features"] = {"msc3440_enabled": True} + return config + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + self.helper.join( + room=self.room_id, user=self.second_user_id, tok=self.second_tok + ) + + self.third_user_id = self.register_user("third", "test") + self.third_tok = self.login("third", "test") + self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + + # An initial event with a relation from second user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 1"}, + tok=self.tok, + ) + self.event_id_1 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": self.event_id_1, + "key": "👍", + } + }, + tok=self.second_tok, + ) + + # Another event with a relation from third user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 2"}, + tok=self.tok, + ) + self.event_id_2 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": self.event_id_2, + } + }, + tok=self.third_tok, + ) + + # An event with no relations. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "No relations"}, + tok=self.tok, + ) + + def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + """Make a request to /messages with a filter, returns the chunk of events.""" + + from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination() + ) + + events, next_key = self.get_success( + self.hs.get_datastore().paginate_room_events( + room_id=self.room_id, + from_key=from_token.room_key, + to_key=None, + direction="b", + limit=10, + event_filter=Filter(self.hs, filter), + ) + ) + + return events + + def test_filter_relation_senders(self): + # Messages which second user reacted to. + filter = {"io.element.relation_senders": [self.second_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) + + # Messages which third user reacted to. + filter = {"io.element.relation_senders": [self.third_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_2) + + # Messages which either user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id, self.third_user_id] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_type(self): + # Messages which have annotations. + filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) + + # Messages which have references. + filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_2) + + # Messages which have either annotations or references. + filter = { + "io.element.relation_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_senders_and_type(self): + # Messages which second user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id], + "io.element.relation_types": [RelationTypes.ANNOTATION], + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) + + def test_duplicate_relation(self): + """An event should only be returned once if there are multiple relations to it.""" + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": self.event_id_1, + "key": "A", + } + }, + tok=self.second_tok, + ) + + filter = {"io.element.relation_senders": [self.second_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) -- cgit 1.5.1 From b6f4d122efb86e3fc44e358cf573dc2caa6ff634 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 9 Nov 2021 13:11:47 +0000 Subject: Allow admins to proactively block rooms (#11228) Co-authored-by: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/11228.feature | 1 + docs/admin_api/rooms.md | 16 +++++++---- synapse/handlers/room.py | 51 ++++++++++++++++++++++++++-------- synapse/rest/admin/rooms.py | 21 +++++++++++--- synapse/storage/databases/main/room.py | 7 ++++- tests/rest/admin/test_room.py | 28 +++++++++++++++++++ 6 files changed, 103 insertions(+), 21 deletions(-) create mode 100644 changelog.d/11228.feature (limited to 'synapse/rest') diff --git a/changelog.d/11228.feature b/changelog.d/11228.feature new file mode 100644 index 0000000000..33c1756b50 --- /dev/null +++ b/changelog.d/11228.feature @@ -0,0 +1 @@ +Allow the admin [Delete Room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api) to block a room without the need to join it. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index ab6b82a082..41a4961d00 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -396,13 +396,17 @@ The new room will be created with the user specified by the `new_room_user_id` p as room administrator and will contain a message explaining what happened. Users invited to the new room will have power level `-10` by default, and thus be unable to speak. -If `block` is `True` it prevents new joins to the old room. +If `block` is `true`, users will be prevented from joining the old room. +This option can also be used to pre-emptively block a room, even if it's unknown +to this homeserver. In this case, the room will be blocked, and no further action +will be taken. If `block` is `false`, attempting to delete an unknown room is +invalid and will be rejected as a bad request. This API will remove all trace of the old room from your database after removing all local users. If `purge` is `true` (the default), all traces of the old room will be removed from your database after removing all local users. If you do not want this to happen, set `purge` to `false`. -Depending on the amount of history being purged a call to the API may take +Depending on the amount of history being purged, a call to the API may take several minutes or longer. The local server will only have the power to move local user and room aliases to @@ -464,8 +468,9 @@ The following JSON body parameters are available: `new_room_user_id` in the new room. Ideally this will clearly convey why the original room was shut down. Defaults to `Sharing illegal content on this server is not permitted and rooms in violation will be blocked.` -* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing - future attempts to join the room. Defaults to `false`. +* `block` - Optional. If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Rooms can be blocked + even if they're not yet known to the homeserver. Defaults to `false`. * `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. Defaults to `true`. * `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it @@ -483,7 +488,8 @@ The following fields are returned in the JSON response body: * `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. * `local_aliases` - An array of strings representing the local aliases that were migrated from the old room to the new. -* `new_room_id` - A string representing the room ID of the new room. +* `new_room_id` - A string representing the room ID of the new room, or `null` if + no such room was created. ## Undoing room deletions diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7a95e31cbe..11af30eee7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains functions for performing events on rooms.""" - +"""Contains functions for performing actions on rooms.""" import itertools import logging import math @@ -31,6 +30,8 @@ from typing import ( Tuple, ) +from typing_extensions import TypedDict + from synapse.api.constants import ( EventContentFields, EventTypes, @@ -1277,6 +1278,13 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): return self.store.get_room_events_max_id(room_id) +class ShutdownRoomResponse(TypedDict): + kicked_users: List[str] + failed_to_kick_users: List[str] + local_aliases: List[str] + new_room_id: Optional[str] + + class RoomShutdownHandler: DEFAULT_MESSAGE = ( @@ -1302,7 +1310,7 @@ class RoomShutdownHandler: new_room_name: Optional[str] = None, message: Optional[str] = None, block: bool = False, - ) -> dict: + ) -> ShutdownRoomResponse: """ Shuts down a room. Moves all local users and room aliases automatically to a new room if `new_room_user_id` is set. Otherwise local users only @@ -1336,8 +1344,13 @@ class RoomShutdownHandler: Defaults to `Sharing illegal content on this server is not permitted and rooms in violation will be blocked.` block: - If set to `true`, this room will be added to a blocking list, - preventing future attempts to join the room. Defaults to `false`. + If set to `True`, users will be prevented from joining the old + room. This option can also be used to pre-emptively block a room, + even if it's unknown to this homeserver. In this case, the room + will be blocked, and no further action will be taken. If `False`, + attempting to delete an unknown room is invalid. + + Defaults to `False`. Returns: a dict containing the following keys: kicked_users: An array of users (`user_id`) that were kicked. @@ -1346,7 +1359,9 @@ class RoomShutdownHandler: local_aliases: An array of strings representing the local aliases that were migrated from the old room to the new. - new_room_id: A string representing the room ID of the new room. + new_room_id: + A string representing the room ID of the new room, or None if + no such room was created. """ if not new_room_name: @@ -1357,14 +1372,28 @@ class RoomShutdownHandler: if not RoomID.is_valid(room_id): raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) - if not await self.store.get_room(room_id): - raise NotFoundError("Unknown room id %s" % (room_id,)) - - # This will work even if the room is already blocked, but that is - # desirable in case the first attempt at blocking the room failed below. + # Action the block first (even if the room doesn't exist yet) if block: + # This will work even if the room is already blocked, but that is + # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + if not await self.store.get_room(room_id): + if block: + # We allow you to block an unknown room. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + else: + # But if you don't want to preventatively block another room, + # this function can't do anything useful. + raise NotFoundError( + "Cannot shut down room: unknown room id %s" % (room_id,) + ) + if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): raise SynapseError( diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 05c823d9ce..a2f4edebb8 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from urllib import parse as urlparse from synapse.api.constants import EventTypes, JoinRules, Membership @@ -239,9 +239,22 @@ class RoomRestServlet(RestServlet): # Purge room if purge: - await pagination_handler.purge_room(room_id, force=force_purge) - - return 200, ret + try: + await pagination_handler.purge_room(room_id, force=force_purge) + except NotFoundError: + if block: + # We can block unknown rooms with this endpoint, in which case + # a failed purge is expected. + pass + else: + # But otherwise, we expect this purge to have succeeded. + raise + + # Cast safety: cast away the knowledge that this is a TypedDict. + # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 + # for some discussion on why this is necessary. Either way, + # `ret` is an opaque dictionary blob as far as the rest of the app cares. + return 200, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cefc77fa0f..17b398bb69 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1751,7 +1751,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) async def block_room(self, room_id: str, user_id: str) -> None: - """Marks the room as blocked. Can be called multiple times. + """Marks the room as blocked. + + Can be called multiple times (though we'll only track the last user to + block this room). + + Can be called on a room unknown to this homeserver. Args: room_id: Room to block diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 46116644ce..11ec54c82e 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -14,9 +14,12 @@ import json import urllib.parse +from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock +from parameterized import parameterized + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes @@ -281,6 +284,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_blocked(self.room_id, expect=True) self._has_no_members(self.room_id) + @parameterized.expand([(True,), (False,)]) + def test_block_unknown_room(self, purge: bool) -> None: + """ + We can block an unknown room. In this case, the `purge` argument + should be ignored. + """ + room_id = "!unknown:test" + + # The room isn't already in the blocked rooms table + self._is_blocked(room_id, expect=False) + + # Request the room be blocked. + channel = self.make_request( + "DELETE", + f"/_synapse/admin/v1/rooms/{room_id}", + {"block": True, "purge": purge}, + access_token=self.admin_user_tok, + ) + + # The room is now blocked. + self.assertEqual( + HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] + ) + self._is_blocked(room_id) + def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not yet accepted the privacy policy. This used to fail when we tried to -- cgit 1.5.1 From 6ce19b94e84eb6ad83ef303f88d8bd59a3d414e6 Mon Sep 17 00:00:00 2001 From: Neeeflix <35348173+Neeeflix@users.noreply.github.com> Date: Wed, 10 Nov 2021 21:49:43 +0100 Subject: Fix error in thumbnail generation (#11288) Signed-off-by: Jonas Zeunert --- changelog.d/11288.bugfix | 1 + synapse/rest/media/v1/thumbnailer.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11288.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/11288.bugfix b/changelog.d/11288.bugfix new file mode 100644 index 0000000000..d85b1779ba --- /dev/null +++ b/changelog.d/11288.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where uploading extremely thin images (e.g. 1000x1) would fail. Contributed by @Neeeflix. diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 46701a8b83..5e17664b5b 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -101,8 +101,8 @@ class Thumbnailer: fits within the given rectangle:: (w_in / h_in) = (w_out / h_out) - w_out = min(w_max, h_max * (w_in / h_in)) - h_out = min(h_max, w_max * (h_in / w_in)) + w_out = max(min(w_max, h_max * (w_in / h_in)), 1) + h_out = max(min(h_max, w_max * (h_in / w_in)), 1) Args: max_width: The largest possible width. @@ -110,9 +110,9 @@ class Thumbnailer: """ if max_width * self.height < max_height * self.width: - return max_width, (max_width * self.height) // self.width + return max_width, max((max_width * self.height) // self.width, 1) else: - return (max_height * self.width) // self.height, max_height + return max((max_height * self.width) // self.height, 1), max_height def _resize(self, width: int, height: int) -> Image.Image: # 1-bit or 8-bit color palette images need converting to RGB -- cgit 1.5.1 From 8840a7b7f1074073c49135d13918d9e4d4a04577 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 12 Nov 2021 13:35:31 +0100 Subject: Convert delete room admin API to async endpoint (#11223) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/11223.feature | 1 + docs/admin_api/purge_history_api.md | 2 + docs/admin_api/rooms.md | 189 +++++++++- synapse/handlers/pagination.py | 289 +++++++++++++- synapse/handlers/room.py | 13 +- synapse/rest/admin/__init__.py | 6 + synapse/rest/admin/rooms.py | 134 ++++++- tests/rest/admin/test_admin.py | 48 +++ tests/rest/admin/test_room.py | 726 ++++++++++++++++++++++++++++++++---- 9 files changed, 1317 insertions(+), 91 deletions(-) create mode 100644 changelog.d/11223.feature (limited to 'synapse/rest') diff --git a/changelog.d/11223.feature b/changelog.d/11223.feature new file mode 100644 index 0000000000..55ea693dcd --- /dev/null +++ b/changelog.d/11223.feature @@ -0,0 +1 @@ +Add a new version of delete room admin API `DELETE /_synapse/admin/v2/rooms/` to run it in background. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/purge_history_api.md b/docs/admin_api/purge_history_api.md index bd29e29ab8..277e28d9cb 100644 --- a/docs/admin_api/purge_history_api.md +++ b/docs/admin_api/purge_history_api.md @@ -70,6 +70,8 @@ This API returns a JSON body like the following: The status will be one of `active`, `complete`, or `failed`. +If `status` is `failed` there will be a string `error` with the error message. + ## Reclaim disk space (Postgres) To reclaim the disk space and return it to the operating system, you need to run diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 41a4961d00..6a6ae92d66 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -4,6 +4,9 @@ - [Room Members API](#room-members-api) - [Room State API](#room-state-api) - [Delete Room API](#delete-room-api) + * [Version 1 (old version)](#version-1-old-version) + * [Version 2 (new version)](#version-2-new-version) + * [Status of deleting rooms](#status-of-deleting-rooms) * [Undoing room shutdowns](#undoing-room-shutdowns) - [Make Room Admin API](#make-room-admin-api) - [Forward Extremities Admin API](#forward-extremities-admin-api) @@ -397,10 +400,10 @@ as room administrator and will contain a message explaining what happened. Users to the new room will have power level `-10` by default, and thus be unable to speak. If `block` is `true`, users will be prevented from joining the old room. -This option can also be used to pre-emptively block a room, even if it's unknown -to this homeserver. In this case, the room will be blocked, and no further action -will be taken. If `block` is `false`, attempting to delete an unknown room is -invalid and will be rejected as a bad request. +This option can in [Version 1](#version-1-old-version) also be used to pre-emptively +block a room, even if it's unknown to this homeserver. In this case, the room will be +blocked, and no further action will be taken. If `block` is `false`, attempting to +delete an unknown room is invalid and will be rejected as a bad request. This API will remove all trace of the old room from your database after removing all local users. If `purge` is `true` (the default), all traces of the old room will @@ -412,6 +415,17 @@ several minutes or longer. The local server will only have the power to move local user and room aliases to the new room. Users on other servers will be unaffected. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see [Admin API](../usage/administration/admin_api). + +## Version 1 (old version) + +This version works synchronously. That means you only get the response once the server has +finished the action, which may take a long time. If you request the same action +a second time, and the server has not finished the first one, the second request will block. +This is fixed in version 2 of this API. The parameters are the same in both APIs. +This API will become deprecated in the future. + The API is: ``` @@ -430,9 +444,6 @@ with a body of: } ``` -To use it, you will need to authenticate by providing an ``access_token`` for a -server admin: see [Admin API](../usage/administration/admin_api). - A response body like the following is returned: ```json @@ -449,6 +460,44 @@ A response body like the following is returned: } ``` +The parameters and response values have the same format as +[version 2](#version-2-new-version) of the API. + +## Version 2 (new version) + +**Note**: This API is new, experimental and "subject to change". + +This version works asynchronously, meaning you get the response from server immediately +while the server works on that task in background. You can then request the status of the action +to check if it has completed. + +The API is: + +``` +DELETE /_synapse/admin/v2/rooms/ +``` + +with a body of: + +```json +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", + "block": true, + "purge": true +} +``` + +The API starts the shut down and purge running, and returns immediately with a JSON body with +a purge id: + +```json +{ + "delete_id": "" +} +``` + **Parameters** The following parameters should be set in the URL: @@ -470,7 +519,8 @@ The following JSON body parameters are available: is not permitted and rooms in violation will be blocked.` * `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to join the room. Rooms can be blocked - even if they're not yet known to the homeserver. Defaults to `false`. + even if they're not yet known to the homeserver (only with + [Version 1](#version-1-old-version) of the API). Defaults to `false`. * `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. Defaults to `true`. * `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it @@ -480,17 +530,124 @@ The following JSON body parameters are available: The JSON body must not be empty. The body must be at least `{}`. -**Response** +## Status of deleting rooms -The following fields are returned in the JSON response body: +**Note**: This API is new, experimental and "subject to change". + +It is possible to query the status of the background task for deleting rooms. +The status can be queried up to 24 hours after completion of the task, +or until Synapse is restarted (whichever happens first). + +### Query by `room_id` -* `kicked_users` - An array of users (`user_id`) that were kicked. -* `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. -* `local_aliases` - An array of strings representing the local aliases that were migrated from - the old room to the new. -* `new_room_id` - A string representing the room ID of the new room, or `null` if - no such room was created. +With this API you can get the status of all active deletion tasks, and all those completed in the last 24h, +for the given `room_id`. + +The API is: + +``` +GET /_synapse/admin/v2/rooms//delete_status +``` + +A response body like the following is returned: + +```json +{ + "results": [ + { + "delete_id": "delete_id1", + "status": "failed", + "error": "error message", + "shutdown_room": { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": null + } + }, { + "delete_id": "delete_id2", + "status": "purging", + "shutdown_room": { + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" + } + } + ] +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +* `room_id` - The ID of the room. + +### Query by `delete_id` + +With this API you can get the status of one specific task by `delete_id`. + +The API is: + +``` +GET /_synapse/admin/v2/rooms/delete_status/ +``` + +A response body like the following is returned: + +```json +{ + "status": "purging", + "shutdown_room": { + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" + } +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +* `delete_id` - The ID for this delete. + +### Response + +The following fields are returned in the JSON response body: +- `results` - An array of objects, each containing information about one task. + This field is omitted from the result when you query by `delete_id`. + Task objects contain the following fields: + - `delete_id` - The ID for this purge if you query by `room_id`. + - `status` - The status will be one of: + - `shutting_down` - The process is removing users from the room. + - `purging` - The process is purging the room and event data from database. + - `complete` - The process has completed successfully. + - `failed` - The process is aborted, an error has occurred. + - `error` - A string that shows an error message if `status` is `failed`. + Otherwise this field is hidden. + - `shutdown_room` - An object containing information about the result of shutting down the room. + *Note:* The result is shown after removing the room members. + The delete process can still be running. Please pay attention to the `status`. + - `kicked_users` - An array of users (`user_id`) that were kicked. + - `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. + - `local_aliases` - An array of strings representing the local aliases that were + migrated from the old room to the new. + - `new_room_id` - A string representing the room ID of the new room, or `null` if + no such room was created. ## Undoing room deletions diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index aa26911aed..cd64142735 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set import attr @@ -22,7 +22,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.logging.context import run_in_background +from synapse.handlers.room import ShutdownRoomResponse from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig @@ -56,11 +56,62 @@ class PurgeStatus: STATUS_FAILED: "failed", } + # Save the error message if an error occurs + error: str = "" + # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}. status: int = STATUS_ACTIVE def asdict(self) -> JsonDict: - return {"status": PurgeStatus.STATUS_TEXT[self.status]} + ret = {"status": PurgeStatus.STATUS_TEXT[self.status]} + if self.error: + ret["error"] = self.error + return ret + + +@attr.s(slots=True, auto_attribs=True) +class DeleteStatus: + """Object tracking the status of a delete room request + + This class contains information on the progress of a delete room request, for + return by get_delete_status. + """ + + STATUS_PURGING = 0 + STATUS_COMPLETE = 1 + STATUS_FAILED = 2 + STATUS_SHUTTING_DOWN = 3 + + STATUS_TEXT = { + STATUS_PURGING: "purging", + STATUS_COMPLETE: "complete", + STATUS_FAILED: "failed", + STATUS_SHUTTING_DOWN: "shutting_down", + } + + # Tracks whether this request has completed. + # One of STATUS_{PURGING,COMPLETE,FAILED,SHUTTING_DOWN}. + status: int = STATUS_PURGING + + # Save the error message if an error occurs + error: str = "" + + # Saves the result of an action to give it back to REST API + shutdown_room: ShutdownRoomResponse = { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + + def asdict(self) -> JsonDict: + ret = { + "status": DeleteStatus.STATUS_TEXT[self.status], + "shutdown_room": self.shutdown_room, + } + if self.error: + ret["error"] = self.error + return ret class PaginationHandler: @@ -70,6 +121,9 @@ class PaginationHandler: paginating during a purge. """ + # when to remove a completed deletion/purge from the results map + CLEAR_PURGE_AFTER_MS = 1000 * 3600 * 24 # 24 hours + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() @@ -78,11 +132,18 @@ class PaginationHandler: self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname + self._room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_lock = ReadWriteLock() + # IDs of rooms in which there currently an active purge *or delete* operation. self._purges_in_progress_by_room: Set[str] = set() # map from purge id to PurgeStatus self._purges_by_id: Dict[str, PurgeStatus] = {} + # map from purge id to DeleteStatus + self._delete_by_id: Dict[str, DeleteStatus] = {} + # map from room id to delete ids + # Dict[`room_id`, List[`delete_id`]] + self._delete_by_room: Dict[str, List[str]] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = ( @@ -265,8 +326,13 @@ class PaginationHandler: logger.info("[purge] starting purge_id %s", purge_id) self._purges_by_id[purge_id] = PurgeStatus() - run_in_background( - self._purge_history, purge_id, room_id, token, delete_local_events + run_as_background_process( + "purge_history", + self._purge_history, + purge_id, + room_id, + token, + delete_local_events, ) return purge_id @@ -276,7 +342,7 @@ class PaginationHandler: """Carry out a history purge on a room. Args: - purge_id: The id for this purge + purge_id: The ID for this purge. room_id: The room to purge from token: topological token to delete events before delete_local_events: True to delete local events as well as remote ones @@ -295,6 +361,7 @@ class PaginationHandler: "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED + self._purges_by_id[purge_id].error = f.getErrorMessage() finally: self._purges_in_progress_by_room.discard(room_id) @@ -302,7 +369,9 @@ class PaginationHandler: def clear_purge() -> None: del self._purges_by_id[purge_id] - self.hs.get_reactor().callLater(24 * 3600, clear_purge) + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_purge + ) def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: """Get the current status of an active purge @@ -312,8 +381,25 @@ class PaginationHandler: """ return self._purges_by_id.get(purge_id) + def get_delete_status(self, delete_id: str) -> Optional[DeleteStatus]: + """Get the current status of an active deleting + + Args: + delete_id: delete_id returned by start_shutdown_and_purge_room + """ + return self._delete_by_id.get(delete_id) + + def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + """Get all active delete ids by room + + Args: + room_id: room_id that is deleted + """ + return self._delete_by_room.get(room_id) + async def purge_room(self, room_id: str, force: bool = False) -> None: """Purge the given room from the database. + This function is part the delete room v1 API. Args: room_id: room to be purged @@ -472,3 +558,192 @@ class PaginationHandler: ) return chunk + + async def _shutdown_and_purge_room( + self, + delete_id: str, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> None: + """ + Shuts down and purges a room. + + See `RoomShutdownHandler.shutdown_room` for details of creation of the new room + + Args: + delete_id: The ID for this delete. + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action. Will be recorded as putting the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from database + also if it fails to remove some users from room. + + Saves a `RoomShutdownHandler.ShutdownRoomResponse` in `DeleteStatus`: + """ + + self._purges_in_progress_by_room.add(room_id) + try: + with await self.pagination_lock.write(room_id): + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN + self._delete_by_id[ + delete_id + ].shutdown_room = await self._room_shutdown_handler.shutdown_room( + room_id=room_id, + requester_user_id=requester_user_id, + new_room_user_id=new_room_user_id, + new_room_name=new_room_name, + message=message, + block=block, + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_PURGING + + if purge: + logger.info("starting purge room_id %s", room_id) + + # first check that we have no users in this room + if not force_purge: + joined = await self.store.is_host_joined( + room_id, self._server_name + ) + if joined: + raise SynapseError( + 400, "Users are still joined to this room" + ) + + await self.storage.purge_events.purge_room(room_id) + + logger.info("complete") + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE + except Exception: + f = Failure() + logger.error( + "failed", + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_FAILED + self._delete_by_id[delete_id].error = f.getErrorMessage() + finally: + self._purges_in_progress_by_room.discard(room_id) + + # remove the delete from the list 24 hours after it completes + def clear_delete() -> None: + del self._delete_by_id[delete_id] + self._delete_by_room[room_id].remove(delete_id) + if not self._delete_by_room[room_id]: + del self._delete_by_room[room_id] + + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_delete + ) + + def start_shutdown_and_purge_room( + self, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> str: + """Start off shut down and purge on a room. + + Args: + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action and put the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from database + also if it fails to remove some users from room. + + Returns: + unique ID for this delete transaction. + """ + if room_id in self._purges_in_progress_by_room: + raise SynapseError( + 400, "History purge already in progress for %s" % (room_id,) + ) + + # This check is double to `RoomShutdownHandler.shutdown_room` + # But here the requester get a direct response / error with HTTP request + # and do not have to check the purge status + if new_room_user_id is not None: + if not self.hs.is_mine_id(new_room_user_id): + raise SynapseError( + 400, "User must be our own: %s" % (new_room_user_id,) + ) + + delete_id = random_string(16) + + # we log the delete_id here so that it can be tied back to the + # request id in the log lines. + logger.info( + "starting shutdown room_id %s with delete_id %s", + room_id, + delete_id, + ) + + self._delete_by_id[delete_id] = DeleteStatus() + self._delete_by_room.setdefault(room_id, []).append(delete_id) + run_as_background_process( + "shutdown_and_purge_room", + self._shutdown_and_purge_room, + delete_id, + room_id, + requester_user_id, + new_room_user_id, + new_room_name, + message, + block, + purge, + force_purge, + ) + return delete_id diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 11af30eee7..f9a099c4f3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1279,6 +1279,17 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): class ShutdownRoomResponse(TypedDict): + """ + Attributes: + kicked_users: An array of users (`user_id`) that were kicked. + failed_to_kick_users: + An array of users (`user_id`) that that were not kicked. + local_aliases: + An array of strings representing the local aliases that were + migrated from the old room to the new. + new_room_id: A string representing the room ID of the new room. + """ + kicked_users: List[str] failed_to_kick_users: List[str] local_aliases: List[str] @@ -1286,7 +1297,6 @@ class ShutdownRoomResponse(TypedDict): class RoomShutdownHandler: - DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" " violation will be blocked." @@ -1299,7 +1309,6 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.state = hs.get_state_handler() self.store = hs.get_datastore() async def shutdown_room( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 81e98f81d6..d78fe406c4 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -46,6 +46,8 @@ from synapse.rest.admin.registration_tokens import ( RegistrationTokenRestServlet, ) from synapse.rest.admin.rooms import ( + DeleteRoomStatusByDeleteIdRestServlet, + DeleteRoomStatusByRoomIdRestServlet, ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, @@ -53,6 +55,7 @@ from synapse.rest.admin.rooms import ( RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, + RoomRestV2Servlet, RoomStateRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -223,7 +226,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRoomRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomRestV2Servlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) + DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server) + DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a2f4edebb8..37cb4d0796 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,7 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -46,6 +46,138 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class RoomRestV2Servlet(RestServlet): + """Delete a room from server asynchronously with a background task. + + It is a combination and improvement of shutdown and purge room. + + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto- + joined to the new room. + + If 'purge' is true, it will remove all traces of a room from the database. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._pagination_handler = hs.get_pagination_handler() + + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) + + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + if not await self._store.get_room(room_id): + raise NotFoundError("Unknown room id %s" % (room_id,)) + + delete_id = self._pagination_handler.start_shutdown_and_purge_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, + purge=purge, + force_purge=force_purge, + ) + + return 200, {"delete_id": delete_id} + + +class DeleteRoomStatusByRoomIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete_status$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) + if delete_ids is None: + raise NotFoundError("No delete task for room_id '%s' found" % room_id) + + response = [] + for delete_id in delete_ids: + delete = self._pagination_handler.get_delete_status(delete_id) + if delete: + response += [ + { + "delete_id": delete_id, + **delete.asdict(), + } + ] + return 200, {"results": cast(JsonDict, response)} + + +class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]+)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, delete_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + delete_status = self._pagination_handler.get_delete_status(delete_id) + if delete_status is None: + raise NotFoundError("delete id '%s' not found" % delete_id) + + return 200, cast(JsonDict, delete_status.asdict()) + + class ListRoomRestServlet(RestServlet): """ List all rooms that are known to the homeserver. Results are returned diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 192073c520..af849bd471 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) + + +class PurgeHistoryTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}" + self.url_status = "/_synapse/admin/v1/purge_history_status/" + + def test_purge_history(self): + """ + Simple test of purge history API. + Test only that is is possible to call, get status 200 and purge_id. + """ + + channel = self.make_request( + "POST", + self.url, + content={"delete_local_events": True, "purge_up_to_ts": 0}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("purge_id", channel.json_body) + purge_id = channel.json_body["purge_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status + purge_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 11ec54c82e..b48fc12e5f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -23,6 +23,7 @@ from parameterized import parameterized import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes +from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room from tests import unittest @@ -71,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): @@ -87,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_room_is_not_valid(self): @@ -103,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -122,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -141,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -160,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): @@ -176,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -202,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -235,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -269,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -344,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -373,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -390,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -446,18 +447,617 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + url = "events?timeout=0&room_id=" + room_id + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + +class DeleteRoomV2TestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.consent.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v2/rooms/{self.room_id}" + self.url_status_by_room_id = ( + f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status" + ) + self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/" + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + method, + url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_room_does_not_exist(self, method: str, url: str): + """ + Check that unknown rooms/server return error 404. + """ + + channel = self.make_request( + method, + url % "!unknown:test", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ] + ) + def test_room_is_not_valid(self, method: str, url: str): + """ + Check that invalid room names, return an error 400. + """ + + channel = self.make_request( + method, + url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@unknown:test"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@not:exist.bla"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + "User must be our own: @not:exist.bla", + channel.json_body["error"], ) + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"purge": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_delete_expired_status(self): + """Test that the task status is removed after expiration.""" + + # first task, do not purge, that we can create a second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id1 = channel.json_body["delete_id"] + + # go ahead + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + # second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id2 = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(2, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual("complete", channel.json_body["results"][1]["status"]) + self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"]) + self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"]) + + # get status after more than clearing time for first task + # second task is not cleared + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) + + # get status after more than clearing time for all tasks + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_delete_same_room_twice(self): + """Test that the call for delete a room at second time gives an exception.""" + + body = {"new_room_user_id": self.admin_user} + + # first call to delete room + # and do not wait for finish the task + first_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + await_result=False, + ) + + # second call to delete room + second_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + ) + + self.assertEqual( + HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body + ) + self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) + self.assertEqual( + f"History purge already in progress for {self.room_id}", + second_channel.json_body["error"], + ) + + # get result of first call + first_channel.await_result() + self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) + self.assertIn("delete_id", first_channel.json_body) + + # check status after finish the task + self._test_result( + first_channel.json_body["delete_id"], + self.other_user, + expect_new_room=True, + ) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": False, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + channel = self.make_request( + "PUT", + url.encode("ascii"), + content={"history_visibility": "world_readable"}, + access_token=self.other_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id: str) -> None: + """Assert there is now no longer anyone in the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id: str, user_id: str) -> None: + """Test that user is member of the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id: str) -> None: + """Test that the following tables have been purged of all rows related to the room.""" + for table in PURGE_TABLES: + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") + + def _assert_peek(self, room_id: str, expect_code: int) -> None: + """Assert that the admin user can (or cannot) peek into the room.""" + + url = f"rooms/{room_id}/initialSync" + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + url = "events?timeout=0&room_id=" + room_id channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + def _test_result( + self, + delete_id: str, + kicked_user: str, + expect_new_room: bool = False, + ) -> None: + """ + Test that the result is the expected. + Uses both APIs (status by room_id and delete_id) + + Args: + delete_id: id of this purge + kicked_user: a user_id which is kicked from the room + expect_new_room: if we expect that a new room was created + """ + + # get information by room_id + channel_room_id = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body + ) + self.assertEqual(1, len(channel_room_id.json_body["results"])) + self.assertEqual( + delete_id, channel_room_id.json_body["results"][0]["delete_id"] ) + # get information by delete_id + channel_delete_id = self.make_request( + "GET", + self.url_status_by_delete_id + delete_id, + access_token=self.admin_user_tok, + ) + self.assertEqual( + HTTPStatus.OK, + channel_delete_id.code, + msg=channel_delete_id.json_body, + ) + + # test values that are the same in both responses + for content in [ + channel_room_id.json_body["results"][0], + channel_delete_id.json_body, + ]: + self.assertEqual("complete", content["status"]) + self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0]) + self.assertIn("failed_to_kick_users", content["shutdown_room"]) + self.assertIn("local_aliases", content["shutdown_room"]) + self.assertNotIn("error", content) + + if expect_new_room: + self.assertIsNotNone(content["shutdown_room"]["new_room_id"]) + else: + self.assertIsNone(content["shutdown_room"]["new_room_id"]) + class RoomTestCase(unittest.HomeserverTestCase): """Test /room admin API.""" @@ -494,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -578,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -620,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_correct_room_attributes(self): """Test the correct attributes for a room are returned""" @@ -643,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -675,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1135,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1185,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.second_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1201,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_local_user_does_not_exist(self): @@ -1217,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_remote_user(self): @@ -1233,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], @@ -1253,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) def test_room_is_not_valid(self): @@ -1270,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], @@ -1289,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1303,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self): @@ -1320,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_join_private_room_if_member(self): @@ -1352,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1363,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1376,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self): @@ -1393,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1407,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self): @@ -1441,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals( - 403, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEquals(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self): @@ -1473,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEquals( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -1532,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -1559,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -1585,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -1623,7 +2219,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", -- cgit 1.5.1 From 4c96ce396e900a94af66ec070af925881b6e1e24 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 12 Nov 2021 15:50:54 +0000 Subject: Misc typing fixes for `tests`, part 1 of N (#11323) * Annotate HomeserverTestCase.servlets * Correct annotation of federation_auth_origin * Use AnyStr custom_headers instead of a Union This allows (str, str) and (bytes, bytes). This disallows (str, bytes) and (bytes, str) * DomainSpecificString.SIGIL is a ClassVar --- changelog.d/11323.misc | 1 + synapse/rest/__init__.py | 4 +++- synapse/types.py | 3 ++- tests/replication/_base.py | 5 +---- tests/rest/client/utils.py | 22 ++++++++++++++-------- tests/server.py | 15 +++++++++++---- tests/unittest.py | 32 +++++++++++++++++++++----------- 7 files changed, 53 insertions(+), 29 deletions(-) create mode 100644 changelog.d/11323.misc (limited to 'synapse/rest') diff --git a/changelog.d/11323.misc b/changelog.d/11323.misc new file mode 100644 index 0000000000..54f39e1844 --- /dev/null +++ b/changelog.d/11323.misc @@ -0,0 +1 @@ +Improve type annotations in Synapse's test suite. \ No newline at end of file diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index e04af705eb..cebdeecb81 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from synapse.http.server import HttpServer, JsonResource from synapse.rest import admin @@ -62,6 +62,8 @@ from synapse.rest.client import ( if TYPE_CHECKING: from synapse.server import HomeServer +RegisterServletsFunc = Callable[["HomeServer", HttpServer], None] + class ClientRestResource(JsonResource): """Matrix Client API REST resource. diff --git a/synapse/types.py b/synapse/types.py index 9d7a675662..fb72f19343 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -19,6 +19,7 @@ from collections import namedtuple from typing import ( TYPE_CHECKING, Any, + ClassVar, Dict, Mapping, MutableMapping, @@ -219,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): 'domain' : The domain part of the name """ - SIGIL: str = abc.abstractproperty() # type: ignore + SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore localpart = attr.ib(type=str) domain = attr.ib(type=str) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index eac4664b41..cb02eddf07 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet.protocol import Protocol from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer -from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.client import ReplicationDataHandler @@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): unlike `BaseStreamTestCase`. """ - servlets: List[Callable[[HomeServer, JsonResource], None]] = [] - def setUp(self): super().setUp() diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index ec0979850b..7cf782e2d6 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,7 +19,17 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union +from typing import ( + Any, + AnyStr, + Dict, + Iterable, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) from unittest.mock import patch import attr @@ -53,9 +63,7 @@ class RestHelper: tok: Optional[str] = None, expect_code: int = 200, extra_content: Optional[Dict] = None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ) -> str: """ Create a room. @@ -227,9 +235,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if body is None: body = "body_text_here" @@ -418,7 +424,7 @@ class RestHelper: path, content=image_data, access_token=tok, - custom_headers=[(b"Content-Length", str(image_length))], + custom_headers=[("Content-Length", str(image_length))], ) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( diff --git a/tests/server.py b/tests/server.py index 103351b487..a7cc5cd325 100644 --- a/tests/server.py +++ b/tests/server.py @@ -16,7 +16,16 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union +from typing import ( + AnyStr, + Callable, + Dict, + Iterable, + MutableMapping, + Optional, + Tuple, + Union, +) import attr from typing_extensions import Deque @@ -222,9 +231,7 @@ def make_request( federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ diff --git a/tests/unittest.py b/tests/unittest.py index a9b60b7eeb..ba830618c2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,20 @@ import inspect import logging import secrets import time -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union +from typing import ( + Any, + AnyStr, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from unittest.mock import Mock, patch from canonicaljson import json @@ -45,6 +58,7 @@ from synapse.logging.context import ( current_context, set_current_context, ) +from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock @@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase): config dict. Attributes: - servlets (list[function]): List of servlet registration function. + servlets: List of servlet registration function. user_id (str): The user ID to assume if auth is hijacked. hijack_auth (bool): Whether to hijack auth to return the user specified in user_id. """ - servlets = [] hijack_auth = True needs_threadpool = False + servlets: ClassVar[List[RegisterServletsFunc]] = [] def __init__(self, methodName, *args, **kwargs): super().__init__(methodName, *args, **kwargs) @@ -405,12 +419,10 @@ class HomeserverTestCase(TestCase): access_token: Optional[str] = None, request: Type[T] = SynapseRequest, shorthand: bool = True, - federation_auth_origin: str = None, + federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ @@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase): a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. - federation_auth_origin (bytes|None): if set to not-None, we will add a fake + federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. @@ -639,9 +651,7 @@ class HomeserverTestCase(TestCase): username, password, device_id=None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): """ Log in a user, and get an access token. Requires the Login API be -- cgit 1.5.1 From 9b90b9454b8855e0575785560662d8e47378094d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 12 Nov 2021 11:05:26 -0500 Subject: Add type hints to media repository storage module (#11311) --- changelog.d/11311.misc | 1 + mypy.ini | 1 - synapse/rest/media/v1/preview_url_resource.py | 8 +- synapse/storage/databases/main/media_repository.py | 141 ++++++++++++--------- 4 files changed, 89 insertions(+), 62 deletions(-) create mode 100644 changelog.d/11311.misc (limited to 'synapse/rest') diff --git a/changelog.d/11311.misc b/changelog.d/11311.misc new file mode 100644 index 0000000000..86594a332d --- /dev/null +++ b/changelog.d/11311.misc @@ -0,0 +1 @@ +Add type hints to storage classes. diff --git a/mypy.ini b/mypy.ini index d81e964dc7..56a62bb9b7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -36,7 +36,6 @@ exclude = (?x) |synapse/storage/databases/main/events_bg_updates.py |synapse/storage/databases/main/events_worker.py |synapse/storage/databases/main/group_server.py - |synapse/storage/databases/main/media_repository.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/presence.py diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8ca97b5b18..054f3c296d 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -231,7 +231,7 @@ class PreviewUrlResource(DirectServeJsonResource): og = await make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url: str, user: str, ts: int) -> bytes: + async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: @@ -360,7 +360,7 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") - async def _download_url(self, url: str, user: str) -> MediaInfo: + async def _download_url(self, url: str, user: UserID) -> MediaInfo: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource): ) async def _precache_image_url( - self, user: str, media_info: MediaInfo, og: JsonDict + self, user: UserID, media_info: MediaInfo, og: JsonDict ) -> None: """ Pre-cache the image (if one exists) for posterity diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 717487be28..1b076683f7 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,10 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -46,7 +61,12 @@ class MediaSortOrder(Enum): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): self._drop_media_index_without_method, ) - async def _drop_media_index_without_method(self, progress, batch_size): + async def _drop_media_index_without_method( + self, progress: JsonDict, batch_size: int + ) -> int: """background update handler which removes the old constraints. Note that this is only run on postgres. """ - def f(txn): + def f(txn: LoggingTransaction) -> None: txn.execute( "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key" ) @@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): plus the total count of all the user's media """ - def get_local_media_by_user_paginate_txn(txn): + def get_local_media_by_user_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: # Set ordering order_by_column = MediaSortOrder(order_by).value @@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): else: order = "ASC" - args = [user_id] + args: List[Union[str, int]] = [user_id] sql = """ SELECT COUNT(*) as total_media FROM local_media_repository WHERE user_id = ? """ txn.execute(sql, args) - count = txn.fetchone()[0] + count = txn.fetchone()[0] # type: ignore[index] sql = """ SELECT @@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) sql += sql_keep - def _get_local_media_before_txn(txn): + def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (before_ts, before_ts, size_gt)) return [row[0] for row in txn] @@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_local_media( self, - media_id, - media_type, - time_now_ms, - upload_name, - media_length, - user_id, - url_cache=None, + media_id: str, + media_type: str, + time_now_ms: int, + upload_name: Optional[str], + media_length: int, + user_id: UserID, + url_cache: Optional[str] = None, ) -> None: await self.db_pool.simple_insert( "local_media_repository", @@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): None if the URL isn't cached. """ - def get_url_cache_txn(txn): + def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: # get the most recently cached result (relative to the given ts) sql = ( "SELECT response_code, etag, expires_ts, og, media_id, download_ts" @@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts - ): + ) -> None: await self.db_pool.simple_insert( "local_media_repository_url_cache", { @@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_local_thumbnail( self, - media_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): + media_id: str, + thumbnail_width: int, + thumbnail_height: int, + thumbnail_type: str, + thumbnail_method: str, + thumbnail_length: int, + ) -> None: await self.db_pool.simple_upsert( table="local_media_repository_thumbnails", keyvalues={ @@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_cached_remote_media( self, - origin, - media_id, - media_type, - media_length, - time_now_ms, - upload_name, - filesystem_id, - ): + origin: str, + media_id: str, + media_type: str, + media_length: int, + time_now_ms: int, + upload_name: Optional[str], + filesystem_id: str, + ) -> None: await self.db_pool.simple_insert( "remote_media_cache", { @@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): local_media: Iterable[str], remote_media: Iterable[Tuple[str, str]], time_ms: int, - ): + ) -> None: """Updates the last access time of the given media Args: @@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): time_ms: Current time in milliseconds """ - def update_cache_txn(txn): + def update_cache_txn(txn: LoggingTransaction) -> None: sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" " WHERE media_origin = ? AND media_id = ?" @@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn ) @@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_remote_media_thumbnail( self, - origin, - media_id, - filesystem_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): + origin: str, + media_id: str, + filesystem_id: str, + thumbnail_width: int, + thumbnail_height: int, + thumbnail_type: str, + thumbnail_method: str, + thumbnail_length: int, + ) -> None: await self.db_pool.simple_upsert( table="remote_media_cache_thumbnails", keyvalues={ @@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_remote_media_thumbnail", ) - async def get_remote_media_before(self, before_ts): + async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]: sql = ( "SELECT media_origin, media_id, filesystem_id" " FROM remote_media_cache" @@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " LIMIT 500" ) - def _get_expired_url_cache_txn(txn): + def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (now_ts,)) return [row[0] for row in txn] @@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_expired_url_cache", _get_expired_url_cache_txn ) - async def delete_url_cache(self, media_ids): + async def delete_url_cache(self, media_ids: Collection[str]) -> None: if len(media_ids) == 0: return sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" - def _delete_url_cache_txn(txn): + def _delete_url_cache_txn(txn: LoggingTransaction) -> None: txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) - return await self.db_pool.runInteraction( - "delete_url_cache", _delete_url_cache_txn - ) + await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn) async def get_url_cache_media_before(self, before_ts: int) -> List[str]: sql = ( @@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " LIMIT 500" ) - def _get_url_cache_media_before_txn(txn): + def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (before_ts,)) return [row[0] for row in txn] @@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_url_cache_media_before", _get_url_cache_media_before_txn ) - async def delete_url_cache_media(self, media_ids): + async def delete_url_cache_media(self, media_ids: Collection[str]) -> None: if len(media_ids) == 0: return - def _delete_url_cache_media_txn(txn): + def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None: sql = "DELETE FROM local_media_repository WHERE media_id = ?" txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) @@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) -- cgit 1.5.1 From 6f862c5c28f03d712502635799dfd1b01bf79712 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 15 Nov 2021 12:31:22 +0200 Subject: Add support for the stable version of MSC2778 (#11335) * Add support for the stable version of MSC2778 Signed-off-by: Tulir Asokan * Expect m.login.application_service in login and password provider tests Signed-off-by: Tulir Asokan --- changelog.d/11335.feature | 1 + synapse/rest/client/login.py | 9 +++++++-- tests/handlers/test_password_providers.py | 5 ++++- tests/rest/client/test_login.py | 5 ++++- 4 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11335.feature (limited to 'synapse/rest') diff --git a/changelog.d/11335.feature b/changelog.d/11335.feature new file mode 100644 index 0000000000..9b6c1b9c23 --- /dev/null +++ b/changelog.d/11335.feature @@ -0,0 +1 @@ +Support the stable version of [MSC2778](https://github.com/matrix-org/matrix-doc/pull/2778): the `m.login.application_service` login type. Contributed by @tulir. diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d49a647b03..467444a041 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -61,7 +61,8 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" - APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + APPSERVICE_TYPE = "m.login.application_service" + APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service" REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" def __init__(self, hs: "HomeServer"): @@ -143,6 +144,7 @@ class LoginRestServlet(RestServlet): flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE_UNSTABLE}) return 200, {"flows": flows} @@ -159,7 +161,10 @@ class LoginRestServlet(RestServlet): should_issue_refresh_token = False try: - if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + if login_submission["type"] in ( + LoginRestServlet.APPSERVICE_TYPE, + LoginRestServlet.APPSERVICE_TYPE_UNSTABLE, + ): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 7dd4a5a367..08e9730d4d 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -31,7 +31,10 @@ from tests.unittest import override_config # (possibly experimental) login flows we expect to appear in the list after the normal # ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service"}, + {"type": "uk.half-shot.msc2778.login.application_service"}, +] # a mock instance which the dummy auth providers delegate to, so we can see what's going # on diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index a63f04bd41..0b90e3f803 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] # (possibly experimental) login flows we expect to appear in the list after the normal # ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service"}, + {"type": "uk.half-shot.msc2778.login.application_service"}, +] class LoginRestServletTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 24b61f379ac1fc740e1b569b85363e2a0411883a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 16 Nov 2021 07:43:53 -0500 Subject: Add ability to un-shadow-ban via the admin API. (#11347) --- changelog.d/11347.feature | 1 + docs/admin_api/user_admin_api.md | 12 +++++++++--- synapse/rest/admin/users.py | 24 ++++++++++++++++++++++-- synapse/storage/databases/main/registration.py | 2 +- tests/rest/admin/test_user.py | 26 ++++++++++++++++++++------ 5 files changed, 53 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11347.feature (limited to 'synapse/rest') diff --git a/changelog.d/11347.feature b/changelog.d/11347.feature new file mode 100644 index 0000000000..b0cb5345a0 --- /dev/null +++ b/changelog.d/11347.feature @@ -0,0 +1 @@ +Add admin API to un-shadow-ban a user. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 16ec33b3c1..ba574d795f 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -948,7 +948,7 @@ The following fields are returned in the JSON response body: See also the [Client-Server API Spec on pushers](https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers). -## Shadow-banning users +## Controlling whether a user is shadow-banned Shadow-banning is a useful tool for moderating malicious or egregiously abusive users. A shadow-banned users receives successful responses to their client-server API requests, @@ -961,16 +961,22 @@ or broken behaviour for the client. A shadow-banned user will not receive any notification and it is generally more appropriate to ban or kick abusive users. A shadow-banned user will be unable to contact anyone on the server. -The API is: +To shadow-ban a user the API is: ``` POST /_synapse/admin/v1/users//shadow_ban ``` +To un-shadow-ban a user the API is: + +``` +DELETE /_synapse/admin/v1/users//shadow_ban +``` + To use it, you will need to authenticate by providing an `access_token` for a server admin: [Admin API](../usage/administration/admin_api) -An empty JSON dict is returned. +An empty JSON dict is returned in both cases. **Parameters** diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index d14fafbbc9..23a8bf1fdb 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -909,7 +909,7 @@ class UserTokenRestServlet(RestServlet): class ShadowBanRestServlet(RestServlet): - """An admin API for shadow-banning a user. + """An admin API for controlling whether a user is shadow-banned. A shadow-banned users receives successful responses to their client-server API requests, but the events are not propagated into rooms. @@ -917,11 +917,19 @@ class ShadowBanRestServlet(RestServlet): Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. - Example: + Example of shadow-banning a user: POST /_synapse/admin/v1/users/@test:example.com/shadow_ban {} + 200 OK + {} + + Example of removing a user from being shadow-banned: + + DELETE /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + 200 OK {} """ @@ -945,6 +953,18 @@ class ShadowBanRestServlet(RestServlet): return 200, {} + async def on_DELETE( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), False) + + return 200, {} + class RateLimitRestServlet(RestServlet): """An admin API to override ratelimiting for an user. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6c7d6ba508..5e55440570 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -476,7 +476,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): shadow_banned: true iff the user is to be shadow-banned, false otherwise. """ - def set_shadow_banned_txn(txn): + def set_shadow_banned_txn(txn: LoggingTransaction) -> None: user_id = user.to_string() self.db_pool.simple_update_one_txn( txn, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 25e8d6cf27..c9fe0f06c2 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + @parameterized.expand(["POST", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get information of an user without authentication. """ - channel = self.make_request("POST", self.url) + channel = self.make_request(method, self.url) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + @parameterized.expand(["POST", "DELETE"]) + def test_requester_is_not_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ other_user_token = self.login("user", "pass") - channel = self.make_request("POST", self.url, access_token=other_user_token) + channel = self.make_request(method, self.url, access_token=other_user_token) self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_is_not_local(self): + @parameterized.expand(["POST", "DELETE"]) + def test_user_is_not_local(self, method: str): """ Tests that shadow-banning for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" - channel = self.make_request("POST", url, access_token=self.admin_user_tok) + channel = self.make_request(method, url, access_token=self.admin_user_tok) self.assertEqual(400, channel.code, msg=channel.json_body) def test_success(self): @@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): result = self.get_success(self.store.get_user_by_access_token(other_user_token)) self.assertTrue(result.shadow_banned) + # Un-shadow-ban the user. + channel = self.make_request( + "DELETE", self.url, access_token=self.admin_user_tok + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is no longer shadow-banned (and the cache was cleared). + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + class RateLimitTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From dfa536490ee2295d664160362c6339d5e39b6b85 Mon Sep 17 00:00:00 2001 From: Aaron R Date: Tue, 16 Nov 2021 07:47:58 -0600 Subject: Add support for `/_matrix/client/v3` APIs (#11318) This is one of the changes required to support Matrix 1.1 Signed-off-by: Aaron Raimist --- changelog.d/11318.feature | 1 + synapse/app/homeserver.py | 1 + synapse/rest/client/_base.py | 4 ++-- synapse/rest/client/keys.py | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11318.feature (limited to 'synapse/rest') diff --git a/changelog.d/11318.feature b/changelog.d/11318.feature new file mode 100644 index 0000000000..ce28fc1eef --- /dev/null +++ b/changelog.d/11318.feature @@ -0,0 +1 @@ +Add support for the `/_matrix/client/v3` APIs from Matrix v1.1. \ No newline at end of file diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7bb3744f04..4efadde57e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -193,6 +193,7 @@ class SynapseHomeServer(HomeServer): { "/_matrix/client/api/v1": client_resource, "/_matrix/client/r0": client_resource, + "/_matrix/client/v3": client_resource, "/_matrix/client/unstable": client_resource, "/_matrix/client/v2_alpha": client_resource, "/_matrix/client/versions": client_resource, diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index a0971ce994..b4cb90cb76 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) def client_patterns( path_regex: str, - releases: Iterable[int] = (0,), + releases: Iterable[str] = ("r0", "v3"), unstable: bool = True, v1: bool = False, ) -> Iterable[Pattern]: @@ -52,7 +52,7 @@ def client_patterns( v1_prefix = CLIENT_API_PREFIX + "/api/v1" patterns.append(re.compile("^" + v1_prefix + path_regex)) for release in releases: - new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) + new_prefix = CLIENT_API_PREFIX + f"/{release}" patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 7281b2ee29..730c18f08f 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet): } """ - PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",)) def __init__(self, hs: "HomeServer"): super().__init__() -- cgit 1.5.1 From 0d86f6334ae5dc6be76b5bf2eea9ac5677fb89e8 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 17 Nov 2021 14:10:57 +0000 Subject: Rename `get_access_token_for_user_id` method to `create_access_token_for_user_id` (#11369) --- changelog.d/11369.misc | 1 + synapse/handlers/auth.py | 4 ++-- synapse/handlers/register.py | 2 +- synapse/rest/admin/users.py | 2 +- tests/handlers/test_auth.py | 10 +++++----- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/test_capabilities.py | 8 ++++---- 7 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11369.misc (limited to 'synapse/rest') diff --git a/changelog.d/11369.misc b/changelog.d/11369.misc new file mode 100644 index 0000000000..3c1dad544b --- /dev/null +++ b/changelog.d/11369.misc @@ -0,0 +1 @@ +Rename `get_access_token_for_user_id` to `create_access_token_for_user_id` to better reflect what it does. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b62e13b725..b2c84a0fce 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -793,7 +793,7 @@ class AuthHandler: ) = await self.get_refresh_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id ) - access_token = await self.get_access_token_for_user_id( + access_token = await self.create_access_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id, valid_until_ms=valid_until_ms, @@ -855,7 +855,7 @@ class AuthHandler: ) return refresh_token, refresh_token_id - async def get_access_token_for_user_id( + async def create_access_token_for_user_id( self, user_id: str, device_id: Optional[str], diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a0e6a01775..6b5c3d1974 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -819,7 +819,7 @@ class RegistrationHandler: ) valid_until_ms = self.clock.time_msec() + self.access_token_lifetime - access_token = await self._auth_handler.get_access_token_for_user_id( + access_token = await self._auth_handler.create_access_token_for_user_id( user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 23a8bf1fdb..ccd9a2a175 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -898,7 +898,7 @@ class UserTokenRestServlet(RestServlet): if auth_user.to_string() == user_id: raise SynapseError(400, "Cannot use admin API to login as self") - token = await self.auth_handler.get_access_token_for_user_id( + token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), device_id=None, valid_until_ms=valid_until_ms, diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 12857053e7..72e176da75 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -116,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) @@ -134,7 +134,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.get_failure( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ), ResourceLimitError, @@ -162,7 +162,7 @@ class AuthTestCase(unittest.HomeserverTestCase): # If not in monthly active cohort self.get_failure( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ), ResourceLimitError, @@ -179,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase): return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) @@ -197,7 +197,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) # Ensure does not raise exception self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index c9fe0f06c2..5011e54563 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.admin_user, device_id=None, valid_until_ms=None ) ) self.other_user = self.register_user("user", "pass", displayname="User") self.other_user_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.other_user, device_id=None, valid_until_ms=None ) ) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index b9e3602552..249808b031 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"localdb_enabled": False}}) def test_get_change_password_capabilities_localdb_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"enabled": False}}) def test_get_change_password_capabilities_password_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"experimental_features": {"msc3244_enabled": False}}) def test_get_does_not_include_msc3244_fields_when_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def test_get_does_include_msc3244_fields_when_enabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) -- cgit 1.5.1 From 81b18fe5c060a0532ab64b9575d54b84ddbad278 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 18 Nov 2021 18:43:49 +0100 Subject: Add dedicated admin API for blocking a room (#11324) --- changelog.d/11324.feature | 1 + docs/admin_api/rooms.md | 78 +++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/rooms.py | 63 +++++++++ synapse/storage/databases/main/room.py | 32 +++++ tests/rest/admin/test_room.py | 228 +++++++++++++++++++++++++++++++++ 6 files changed, 404 insertions(+) create mode 100644 changelog.d/11324.feature (limited to 'synapse/rest') diff --git a/changelog.d/11324.feature b/changelog.d/11324.feature new file mode 100644 index 0000000000..55494358bb --- /dev/null +++ b/changelog.d/11324.feature @@ -0,0 +1 @@ +Add dedicated admin API for blocking a room. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 6a6ae92d66..0f1a74134f 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -3,6 +3,7 @@ - [Room Details API](#room-details-api) - [Room Members API](#room-members-api) - [Room State API](#room-state-api) +- [Block Room API](#block-room-api) - [Delete Room API](#delete-room-api) * [Version 1 (old version)](#version-1-old-version) * [Version 2 (new version)](#version-2-new-version) @@ -386,6 +387,83 @@ A response body like the following is returned: } ``` +# Block Room API +The Block Room admin API allows server admins to block and unblock rooms, +and query to see if a given room is blocked. +This API can be used to pre-emptively block a room, even if it's unknown to this +homeserver. Users will be prevented from joining a blocked room. + +## Block or unblock a room + +The API is: + +``` +PUT /_synapse/admin/v1/rooms//block +``` + +with a body of: + +```json +{ + "block": true +} +``` + +A response body like the following is returned: + +```json +{ + "block": true +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `room_id` - The ID of the room. + +The following JSON body parameters are available: + +- `block` - If `true` the room will be blocked and if `false` the room will be unblocked. + +**Response** + +The following fields are possible in the JSON response body: + +- `block` - A boolean. `true` if the room is blocked, otherwise `false` + +## Get block status + +The API is: + +``` +GET /_synapse/admin/v1/rooms//block +``` + +A response body like the following is returned: + +```json +{ + "block": true, + "user_id": "" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `room_id` - The ID of the room. + +**Response** + +The following fields are possible in the JSON response body: + +- `block` - A boolean. `true` if the room is blocked, otherwise `false` +- `user_id` - An optional string. If the room is blocked (`block` is `true`) shows + the user who has add the room to blocking list. Otherwise it is not displayed. + # Delete Room API The Delete Room admin API allows server admins to remove rooms from the server diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index d78fe406c4..65b76fa10c 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.admin.registration_tokens import ( RegistrationTokenRestServlet, ) from synapse.rest.admin.rooms import ( + BlockRoomRestServlet, DeleteRoomStatusByDeleteIdRestServlet, DeleteRoomStatusByRoomIdRestServlet, ForwardExtremitiesRestServlet, @@ -223,6 +224,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: Register all the admin servlets. """ register_servlets_for_client_rest_resource(hs, http_server) + BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 37cb4d0796..5b8ec1e5ca 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -782,3 +782,66 @@ class RoomEventContextServlet(RestServlet): ) return 200, results + + +class BlockRoomRestServlet(RestServlet): + """ + Manage blocking of rooms. + On PUT: Add or remove a room from blocking list. + On GET: Get blocking status of room and user who has blocked this room. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/block$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + blocked_by = await self._store.room_is_blocked_by(room_id) + # Test `not None` if `user_id` is an empty string + # if someone add manually an entry in database + if blocked_by is not None: + response = {"block": True, "user_id": blocked_by} + else: + response = {"block": False} + + return HTTPStatus.OK, response + + async def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + assert_params_in_dict(content, ["block"]) + block = content.get("block") + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean.", + Codes.BAD_JSON, + ) + + if block: + await self._store.block_room(room_id, requester.user.to_string()) + else: + await self._store.unblock_room(room_id) + + return HTTPStatus.OK, {"block": block} diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 17b398bb69..7d694d852d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def room_is_blocked_by(self, room_id: str) -> Optional[str]: + """ + Function to retrieve user who has blocked the room. + user_id is non-nullable + It returns None if the room is not blocked. + """ + return await self.db_pool.simple_select_one_onecol( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + allow_none=True, + desc="room_is_blocked_by", + ) + async def get_rooms_paginate( self, start: int, @@ -1775,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.is_room_blocked, (room_id,), ) + + async def unblock_room(self, room_id: str) -> None: + """Remove the room from blocking list. + + Args: + room_id: Room to unblock + """ + await self.db_pool.simple_delete( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + desc="unblock_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index b48fc12e5f..07077aff78 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2226,6 +2226,234 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): ) +class BlockRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self._store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = "/_synapse/admin/v1/rooms/%s/block" + + @parameterized.expand([("PUT",), ("GET",)]) + def test_requester_is_no_admin(self, method: str): + """If the user is not a server admin, an error 403 is returned.""" + + channel = self.make_request( + method, + self.url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @parameterized.expand([("PUT",), ("GET",)]) + def test_room_is_not_valid(self, method: str): + """Check that invalid room names, return an error 400.""" + + channel = self.make_request( + method, + self.url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_block_is_not_valid(self): + """If parameter `block` is not valid, return an error.""" + + # `block` is not valid + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + # `block` is not set + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # no content is send + channel = self.make_request( + "PUT", + self.url % self.room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) + + def test_block_room(self): + """Test that block a room is successful.""" + + def _request_and_test_block_room(room_id: str) -> None: + self._is_blocked(room_id, expect=False) + channel = self.make_request( + "PUT", + self.url % room_id, + content={"block": True}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self._is_blocked(room_id, expect=True) + + # known internal room + _request_and_test_block_room(self.room_id) + + # unknown internal room + _request_and_test_block_room("!unknown:test") + + # unknown remote room + _request_and_test_block_room("!unknown:remote") + + def test_block_room_twice(self): + """Test that block a room that is already blocked is successful.""" + + self._is_blocked(self.room_id, expect=False) + for _ in range(2): + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": True}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self._is_blocked(self.room_id, expect=True) + + def test_unblock_room(self): + """Test that unblock a room is successful.""" + + def _request_and_test_unblock_room(room_id: str) -> None: + self._block_room(room_id) + + channel = self.make_request( + "PUT", + self.url % room_id, + content={"block": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self._is_blocked(room_id, expect=False) + + # known internal room + _request_and_test_unblock_room(self.room_id) + + # unknown internal room + _request_and_test_unblock_room("!unknown:test") + + # unknown remote room + _request_and_test_unblock_room("!unknown:remote") + + def test_unblock_room_twice(self): + """Test that unblock a room that is not blocked is successful.""" + + self._block_room(self.room_id) + for _ in range(2): + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self._is_blocked(self.room_id, expect=False) + + def test_get_blocked_room(self): + """Test get status of a blocked room""" + + def _request_blocked_room(room_id: str) -> None: + self._block_room(room_id) + + channel = self.make_request( + "GET", + self.url % room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + + # known internal room + _request_blocked_room(self.room_id) + + # unknown internal room + _request_blocked_room("!unknown:test") + + # unknown remote room + _request_blocked_room("!unknown:remote") + + def test_get_unblocked_room(self): + """Test get status of a unblocked room""" + + def _request_unblocked_room(room_id: str) -> None: + self._is_blocked(room_id, expect=False) + + channel = self.make_request( + "GET", + self.url % room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self.assertNotIn("user_id", channel.json_body) + + # known internal room + _request_unblocked_room(self.room_id) + + # unknown internal room + _request_unblocked_room("!unknown:test") + + # unknown remote room + _request_unblocked_room("!unknown:remote") + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self._store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _block_room(self, room_id: str) -> None: + """Block a room in database""" + self.get_success(self._store.block_room(room_id, self.other_user)) + self._is_blocked(room_id, expect=True) + + PURGE_TABLES = [ "current_state_events", "event_backward_extremities", -- cgit 1.5.1 From 91f2bd0907f1d05af67166846988e49644eb650c Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 19 Nov 2021 13:39:15 +0000 Subject: Prevent the media store from writing outside of the configured directory Also tighten validation of server names by forbidding invalid characters in IPv6 addresses and empty domain labels. --- synapse/rest/media/v1/_base.py | 18 ++- synapse/rest/media/v1/filepath.py | 241 +++++++++++++++++++++++++++------ synapse/util/stringutils.py | 21 ++- tests/http/test_endpoint.py | 3 + tests/rest/media/v1/test_filepath.py | 250 +++++++++++++++++++++++++++++++++++ 5 files changed, 483 insertions(+), 50 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 014fa893d6..9b40fd8a6c 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name logger = logging.getLogger(__name__) @@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ try: # The type on postpath seems incorrect in Twisted 21.2.0. postpath: List[bytes] = request.postpath # type: ignore @@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: server_name = server_name_bytes.decode("utf-8") media_id = media_id_bytes.decode("utf8") + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + file_name = None if len(postpath) > 2: try: diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index bec77088ee..c0e15c6513 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -16,7 +16,8 @@ import functools import os import re -from typing import Any, Callable, List, TypeVar, cast +import string +from typing import Any, Callable, List, TypeVar, Union, cast NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") @@ -37,6 +38,85 @@ def _wrap_in_base_path(func: F) -> F: return cast(F, _wrapped) +GetPathMethod = TypeVar( + "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] +) + + +def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod: + """Wraps a path-returning method to check that the returned path(s) do not escape + the media store directory. + + The check is not expected to ever fail, unless `func` is missing a call to + `_validate_path_component`, or `_validate_path_component` is buggy. + + Args: + func: The `MediaFilePaths` method to wrap. The method may return either a single + path, or a list of paths. Returned paths may be either absolute or relative. + + Returns: + The method, wrapped with a check to ensure that the returned path(s) lie within + the media store directory. Raises a `ValueError` if the check fails. + """ + + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # path may be an absolute or relative path, depending on the method being + # wrapped. When "appending" an absolute path, `os.path.join` discards the + # previous path, which is desired here. + normalized_path = os.path.normpath(os.path.join(self.real_base_path, path)) + if ( + os.path.commonpath([normalized_path, self.real_base_path]) + != self.real_base_path + ): + raise ValueError(f"Invalid media store path: {path!r}") + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + +ALLOWED_CHARACTERS = set( + string.ascii_letters + + string.digits + + "_-" + + ".[]:" # Domain names, IPv6 addresses and ports in server names +) +FORBIDDEN_NAMES = { + "", + os.path.curdir, # "." for the current platform + os.path.pardir, # ".." for the current platform +} + + +def _validate_path_component(name: str) -> str: + """Checks that the given string can be safely used as a path component + + Args: + name: The path component to check. + + Returns: + The path component if valid. + + Raises: + ValueError: If `name` cannot be safely used as a path component. + """ + if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: + raise ValueError(f"Invalid path component: {name!r}") + + return name + + class MediaFilePaths: """Describes where files are stored on disk. @@ -48,22 +128,46 @@ class MediaFilePaths: def __init__(self, primary_base_path: str): self.base_path = primary_base_path + # The media store directory, with all symlinks resolved. + self.real_base_path = os.path.realpath(primary_base_path) + + # Refuse to initialize if paths cannot be validated correctly for the current + # platform. + assert os.path.sep not in ALLOWED_CHARACTERS + assert os.path.altsep not in ALLOWED_CHARACTERS + # On Windows, paths have all sorts of weirdness which `_validate_path_component` + # does not consider. In any case, the remote media store can't work correctly + # for certain homeservers there, since ":"s aren't allowed in paths. + assert os.name == "posix" + + @_wrap_with_jail_check def local_media_filepath_rel(self, media_id: str) -> str: - return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) + return os.path.join( + "local_content", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + @_wrap_with_jail_check def local_media_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), ) local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + @_wrap_with_jail_check def local_media_thumbnail_dir(self, media_id: str) -> str: """ Retrieve the local store path of thumbnails of a given media_id @@ -76,18 +180,24 @@ class MediaFilePaths: return os.path.join( self.base_path, "local_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ) + @_wrap_with_jail_check def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( - "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] + "remote_content", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), ) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + @_wrap_with_jail_check def remote_media_thumbnail_rel( self, server_name: str, @@ -101,11 +211,11 @@ class MediaFilePaths: file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], - file_name, + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), ) remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) @@ -113,6 +223,7 @@ class MediaFilePaths: # Legacy path that was used to store thumbnails previously. # Should be removed after some time, when most of the thumbnails are stored # using the new path. + @_wrap_with_jail_check def remote_media_thumbnail_rel_legacy( self, server_name: str, file_id: str, width: int, height: int, content_type: str ) -> str: @@ -120,43 +231,66 @@ class MediaFilePaths: file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], - file_name, + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), ) def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), ) + @_wrap_with_jail_check def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join("url_cache", media_id[:10], media_id[11:]) + return os.path.join( + "url_cache", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) else: - return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:]) + return os.path.join( + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + @_wrap_with_jail_check def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): - return [os.path.join(self.base_path, "url_cache", media_id[:10])] + return [ + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[:10]) + ) + ] else: return [ - os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]), - os.path.join(self.base_path, "url_cache", media_id[0:2]), + os.path.join( + self.base_path, + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[0:2]) + ), ] + @_wrap_with_jail_check def url_cache_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -168,37 +302,46 @@ class MediaFilePaths: if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - "url_cache_thumbnails", media_id[:10], media_id[11:], file_name + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + _validate_path_component(file_name), ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], - file_name, + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), ) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + @_wrap_with_jail_check def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:]) + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ) url_cache_thumbnail_directory = _wrap_in_base_path( url_cache_thumbnail_directory_rel ) + @_wrap_with_jail_check def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form @@ -206,21 +349,35 @@ class MediaFilePaths: if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), ), - os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]), ] else: return [ os.path.join( self.base_path, "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ), os.path.join( - self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4] + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), ), - os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]), ] diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index f029432191..ea1032b4fc 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,6 +19,8 @@ import string from collections.abc import Iterable from typing import Optional, Tuple +from netaddr import valid_ipv6 + from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" @@ -97,7 +99,10 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: raise ValueError("Invalid server name '%s'" % server_name) -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") +# An approximation of the domain name syntax in RFC 1035, section 2.3.1. +# NB: "\Z" is not equivalent to "$". +# The latter will match the position before a "\n" at the end of a string. +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z") def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: @@ -122,13 +127,15 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int] if host[0] == "[": if host[-1] != "]": raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) + # valid_ipv6 raises when given an empty string + ipv6_address = host[1:-1] + if not ipv6_address or not valid_ipv6(ipv6_address): + raise ValueError( + "Server name '%s' is not a valid IPv6 address" % (server_name,) + ) + elif not VALID_HOST_REGEX.match(host): + raise ValueError("Server name '%s' has an invalid format" % (server_name,)) return host, port diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index 1f9a2f9b1d..c8cc21cadd 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -36,8 +36,11 @@ class ServerNameTestCase(unittest.TestCase): "localhost:http", # non-numeric port "1234]", # smells like ipv6 literal but isn't "[1234", + "[1.2.3.4]", "underscore_.com", "percent%65.com", + "newline.com\n", + ".empty-label.com", "1234:5678:80", # too many colons ] for i in test_data: diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 09504a485f..8fe94f7d85 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect +from typing import Iterable + from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest @@ -236,3 +239,250 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge", ], ) + + def test_server_name_validation(self): + """Test validation of server names""" + self._test_path_validation( + [ + "remote_media_filepath_rel", + "remote_media_filepath", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "remote_media_thumbnail_rel_legacy", + "remote_media_thumbnail_dir", + ], + parameter="server_name", + valid_values=[ + "matrix.org", + "matrix.org:8448", + "matrix-federation.matrix.org", + "matrix-federation.matrix.org:8448", + "10.1.12.123", + "10.1.12.123:8448", + "[fd00:abcd::ffff]", + "[fd00:abcd::ffff]:8448", + ], + invalid_values=[ + "/matrix.org", + "matrix.org/..", + "matrix.org\x00", + "", + ".", + "..", + "/", + ], + ) + + def test_file_id_validation(self): + """Test validation of local, remote and legacy URL cache file / media IDs""" + # File / media IDs get split into three parts to form paths, consisting of the + # first two characters, next two characters and rest of the ID. + valid_file_ids = [ + "GerZNDnDZVjsOtardLuwfIBg", + # Unexpected, but produces an acceptable path: + "GerZN", # "N" becomes the last directory + ] + invalid_file_ids = [ + "/erZNDnDZVjsOtardLuwfIBg", + "Ge/ZNDnDZVjsOtardLuwfIBg", + "GerZ/DnDZVjsOtardLuwfIBg", + "GerZ/..", + "G\x00rZNDnDZVjsOtardLuwfIBg", + "Ger\x00NDnDZVjsOtardLuwfIBg", + "GerZNDnDZVjsOtardLuwfIBg\x00", + "", + "Ge", + "GerZ", + "GerZ.", + "..rZNDnDZVjsOtardLuwfIBg", + "Ge..NDnDZVjsOtardLuwfIBg", + "GerZ..", + "GerZ/", + ] + + self._test_path_validation( + [ + "local_media_filepath_rel", + "local_media_filepath", + "local_media_thumbnail_rel", + "local_media_thumbnail", + "local_media_thumbnail_dir", + # Legacy URL cache media IDs + "url_cache_filepath_rel", + "url_cache_filepath", + # `url_cache_filepath_dirs_to_delete` is tested below. + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + "url_cache_thumbnail_directory_rel", + "url_cache_thumbnail_directory", + "url_cache_thumbnail_dirs_to_delete", + ], + parameter="media_id", + valid_values=valid_file_ids, + invalid_values=invalid_file_ids, + ) + + # `url_cache_filepath_dirs_to_delete` ignores what would be the last path + # component, so only the first 4 characters matter. + self._test_path_validation( + [ + "url_cache_filepath_dirs_to_delete", + ], + parameter="media_id", + valid_values=valid_file_ids, + invalid_values=[ + "/erZNDnDZVjsOtardLuwfIBg", + "Ge/ZNDnDZVjsOtardLuwfIBg", + "G\x00rZNDnDZVjsOtardLuwfIBg", + "Ger\x00NDnDZVjsOtardLuwfIBg", + "", + "Ge", + "..rZNDnDZVjsOtardLuwfIBg", + "Ge..NDnDZVjsOtardLuwfIBg", + ], + ) + + self._test_path_validation( + [ + "remote_media_filepath_rel", + "remote_media_filepath", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "remote_media_thumbnail_rel_legacy", + "remote_media_thumbnail_dir", + ], + parameter="file_id", + valid_values=valid_file_ids, + invalid_values=invalid_file_ids, + ) + + def test_url_cache_media_id_validation(self): + """Test validation of URL cache media IDs""" + self._test_path_validation( + [ + "url_cache_filepath_rel", + "url_cache_filepath", + # `url_cache_filepath_dirs_to_delete` only cares about the date prefix + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + "url_cache_thumbnail_directory_rel", + "url_cache_thumbnail_directory", + "url_cache_thumbnail_dirs_to_delete", + ], + parameter="media_id", + valid_values=[ + "2020-01-02_GerZNDnDZVjsOtar", + "2020-01-02_G", # Unexpected, but produces an acceptable path + ], + invalid_values=[ + "2020-01-02", + "2020-01-02-", + "2020-01-02-.", + "2020-01-02-..", + "2020-01-02-/", + "2020-01-02-/GerZNDnDZVjsOtar", + "2020-01-02-GerZNDnDZVjsOtar/..", + "2020-01-02-GerZNDnDZVjsOtar\x00", + ], + ) + + def test_content_type_validation(self): + """Test validation of thumbnail content types""" + self._test_path_validation( + [ + "local_media_thumbnail_rel", + "local_media_thumbnail", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "remote_media_thumbnail_rel_legacy", + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + ], + parameter="content_type", + valid_values=[ + "image/jpeg", + ], + invalid_values=[ + "", # ValueError: not enough values to unpack + "image/jpeg/abc", # ValueError: too many values to unpack + "image/jpeg\x00", + ], + ) + + def test_thumbnail_method_validation(self): + """Test validation of thumbnail methods""" + self._test_path_validation( + [ + "local_media_thumbnail_rel", + "local_media_thumbnail", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + ], + parameter="method", + valid_values=[ + "crop", + "scale", + ], + invalid_values=[ + "/scale", + "scale/..", + "scale\x00", + "/", + ], + ) + + def _test_path_validation( + self, + methods: Iterable[str], + parameter: str, + valid_values: Iterable[str], + invalid_values: Iterable[str], + ): + """Test that the specified methods validate the named parameter as expected + + Args: + methods: The names of `MediaFilePaths` methods to test + parameter: The name of the parameter to test + valid_values: A list of parameter values that are expected to be accepted + invalid_values: A list of parameter values that are expected to be rejected + + Raises: + AssertionError: If a value was accepted when it should have failed + validation. + ValueError: If a value failed validation when it should have been accepted. + """ + for method in methods: + get_path = getattr(self.filepaths, method) + + parameters = inspect.signature(get_path).parameters + kwargs = { + "server_name": "matrix.org", + "media_id": "GerZNDnDZVjsOtardLuwfIBg", + "file_id": "GerZNDnDZVjsOtardLuwfIBg", + "width": 800, + "height": 600, + "content_type": "image/jpeg", + "method": "scale", + } + + if get_path.__name__.startswith("url_"): + kwargs["media_id"] = "2020-01-02_GerZNDnDZVjsOtar" + + kwargs = {k: v for k, v in kwargs.items() if k in parameters} + kwargs.pop(parameter) + + for value in valid_values: + kwargs[parameter] = value + get_path(**kwargs) + # No exception should be raised + + for value in invalid_values: + with self.assertRaises(ValueError): + kwargs[parameter] = value + path_or_list = get_path(**kwargs) + self.fail( + f"{value!r} unexpectedly passed validation: " + f"{method} returned {path_or_list!r}" + ) -- cgit 1.5.1 From ea20937084903864865f76e22f67d27729f2d6dc Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 19 Nov 2021 20:39:46 +0100 Subject: Add an admin API to run background jobs. (#11352) Instead of having admins poke into the database directly. Can currently run jobs to populate stats and to populate the user directory. --- changelog.d/11352.feature | 1 + docs/sample_config.yaml | 4 +- .../administration/admin_api/background_updates.md | 27 +++- docs/user_directory.md | 6 +- synapse/config/user_directory.py | 4 +- synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/background_updates.py | 123 ++++++++++++---- synapse/storage/background_updates.py | 2 + tests/rest/admin/test_background_updates.py | 154 +++++++++++++++++++-- 9 files changed, 280 insertions(+), 43 deletions(-) create mode 100644 changelog.d/11352.feature (limited to 'synapse/rest') diff --git a/changelog.d/11352.feature b/changelog.d/11352.feature new file mode 100644 index 0000000000..a4d01b3549 --- /dev/null +++ b/changelog.d/11352.feature @@ -0,0 +1 @@ +Add admin API to run background jobs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3c931468aa..aee300013f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2360,8 +2360,8 @@ user_directory: # indexes were (re)built was before Synapse 1.44, you'll have to # rebuild the indexes in order to search through all known users. # These indexes are built the first time Synapse starts; admins can - # manually trigger a rebuild following the instructions at - # https://matrix-org.github.io/synapse/latest/user_directory.html + # manually trigger a rebuild via API following the instructions at + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run # # Uncomment to return search results containing all known users, even if that # user does not share a room with the requester. diff --git a/docs/usage/administration/admin_api/background_updates.md b/docs/usage/administration/admin_api/background_updates.md index b36d7fe398..9f6ac7d567 100644 --- a/docs/usage/administration/admin_api/background_updates.md +++ b/docs/usage/administration/admin_api/background_updates.md @@ -42,7 +42,6 @@ For each update: `average_items_per_ms` how many items are processed per millisecond based on an exponential average. - ## Enabled This API allow pausing background updates. @@ -82,3 +81,29 @@ The API returns the `enabled` param. ``` There is also a `GET` version which returns the `enabled` state. + + +## Run + +This API schedules a specific background update to run. The job starts immediately after calling the API. + + +The API is: + +``` +POST /_synapse/admin/v1/background_updates/start_job +``` + +with the following body: + +```json +{ + "job_name": "populate_stats_process_rooms" +} +``` + +The following JSON body parameters are available: + +- `job_name` - A string which job to run. Valid values are: + - `populate_stats_process_rooms` - Recalculate the stats for all rooms. + - `regenerate_directory` - Recalculate the [user directory](../../../user_directory.md) if it is stale or out of sync. diff --git a/docs/user_directory.md b/docs/user_directory.md index 07fe954891..c4794b04cf 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -6,9 +6,9 @@ on this particular server - i.e. ones which your account shares a room with, or who are present in a publicly viewable room present on the server. The directory info is stored in various tables, which can (typically after -DB corruption) get stale or out of sync. If this happens, for now the -solution to fix it is to execute the SQL [here](https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/main/delta/53/user_dir_populate.sql) -and then restart synapse. This should then start a background task to +DB corruption) get stale or out of sync. If this happens, for now the +solution to fix it is to use the [admin API](usage/administration/admin_api/background_updates.md#run) +and execute the job `regenerate_directory`. This should then start a background task to flush the current tables and regenerate the directory. Data model diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index 2552f688d0..6d6678c7e4 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -53,8 +53,8 @@ class UserDirectoryConfig(Config): # indexes were (re)built was before Synapse 1.44, you'll have to # rebuild the indexes in order to search through all known users. # These indexes are built the first time Synapse starts; admins can - # manually trigger a rebuild following the instructions at - # https://matrix-org.github.io/synapse/latest/user_directory.html + # manually trigger a rebuild via API following the instructions at + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run # # Uncomment to return search results containing all known users, even if that # user does not share a room with the requester. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 65b76fa10c..ee4a5e481b 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -28,6 +28,7 @@ from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin.background_updates import ( BackgroundUpdateEnabledRestServlet, BackgroundUpdateRestServlet, + BackgroundUpdateStartJobRestServlet, ) from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, @@ -261,6 +262,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SendServerNoticeServlet(hs).register(http_server) BackgroundUpdateEnabledRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 0d0183bf20..479672d4d5 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_user_is_admin from synapse.types import JsonDict @@ -29,37 +34,36 @@ logger = logging.getLogger(__name__) class BackgroundUpdateEnabledRestServlet(RestServlet): """Allows temporarily disabling background updates""" - PATTERNS = admin_patterns("/background_updates/enabled") + PATTERNS = admin_patterns("/background_updates/enabled$") def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - self.data_stores = hs.get_datastores() + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) - enabled = all(db.updates.enabled for db in self.data_stores.databases) + enabled = all(db.updates.enabled for db in self._data_stores.databases) - return 200, {"enabled": enabled} + return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) body = parse_json_object_from_request(request) enabled = body.get("enabled", True) if not isinstance(enabled, bool): - raise SynapseError(400, "'enabled' parameter must be a boolean") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean" + ) - for db in self.data_stores.databases: + for db in self._data_stores.databases: db.updates.enabled = enabled # If we're re-enabling them ensure that we start the background @@ -67,32 +71,29 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): if enabled: db.updates.start_doing_background_updates() - return 200, {"enabled": enabled} + return HTTPStatus.OK, {"enabled": enabled} class BackgroundUpdateRestServlet(RestServlet): """Fetch information about background updates""" - PATTERNS = admin_patterns("/background_updates/status") + PATTERNS = admin_patterns("/background_updates/status$") def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - self.data_stores = hs.get_datastores() + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) - enabled = all(db.updates.enabled for db in self.data_stores.databases) + enabled = all(db.updates.enabled for db in self._data_stores.databases) current_updates = {} - for db in self.data_stores.databases: + for db in self._data_stores.databases: update = db.updates.get_current_update() if not update: continue @@ -104,4 +105,72 @@ class BackgroundUpdateRestServlet(RestServlet): "average_items_per_ms": update.average_items_per_ms(), } - return 200, {"enabled": enabled, "current_updates": current_updates} + return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates} + + +class BackgroundUpdateStartJobRestServlet(RestServlet): + """Allows to start specific background updates""" + + PATTERNS = admin_patterns("/background_updates/start_job") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["job_name"]) + + job_name = body["job_name"] + + if job_name == "populate_stats_process_rooms": + jobs = [ + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + }, + ] + elif job_name == "regenerate_directory": + jobs = [ + { + "update_name": "populate_user_directory_createtables", + "progress_json": "{}", + "depends_on": "", + }, + { + "update_name": "populate_user_directory_process_rooms", + "progress_json": "{}", + "depends_on": "populate_user_directory_createtables", + }, + { + "update_name": "populate_user_directory_process_users", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_rooms", + }, + { + "update_name": "populate_user_directory_cleanup", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_users", + }, + ] + else: + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name") + + try: + await self._store.db_pool.simple_insert_many( + table="background_updates", + values=jobs, + desc=f"admin_api_run_{job_name}", + ) + except self._store.db_pool.engine.module.IntegrityError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Job %s is already in queue of background updates." % (job_name,), + ) + + self._store.db_pool.updates.start_doing_background_updates() + + return HTTPStatus.OK, {} diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index b9a8ca997e..b104f9032c 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -122,6 +122,8 @@ class BackgroundUpdater: def start_doing_background_updates(self) -> None: if self.enabled: + # if we start a new background update, not all updates are done. + self._all_done = False run_as_background_process("background_updates", self.run_background_updates) async def run_background_updates(self, sleep: bool = True) -> None: diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 78c48db552..1786316763 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -11,8 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus +from typing import Collection + +from parameterized import parameterized import synapse.rest.admin +from synapse.api.errors import Codes from synapse.rest.client import login from synapse.server import HomeServer @@ -30,6 +35,60 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") + @parameterized.expand( + [ + ("GET", "/_synapse/admin/v1/background_updates/enabled"), + ("POST", "/_synapse/admin/v1/background_updates/enabled"), + ("GET", "/_synapse/admin/v1/background_updates/status"), + ("POST", "/_synapse/admin/v1/background_updates/start_job"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + method, + url, + content={}, + access_token=other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + url = "/_synapse/admin/v1/background_updates/start_job" + + # empty content + channel = self.make_request( + "POST", + url, + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # job_name invalid + channel = self.make_request( + "POST", + url, + content={"job_name": "unknown"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + def _register_bg_update(self): "Adds a bg update but doesn't start it" @@ -60,7 +119,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled, but none should be running. self.assertDictEqual( @@ -82,7 +141,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled, and one should be running. self.assertDictEqual( @@ -114,7 +173,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/enabled", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) # Disable the BG updates @@ -124,7 +183,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": False}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": False}) # Advance a bit and get the current status, note this will finish the in @@ -137,7 +196,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual( channel.json_body, { @@ -162,7 +221,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # There should be no change from the previous /status response. self.assertDictEqual( @@ -188,7 +247,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": True}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) @@ -199,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled and making progress. self.assertDictEqual( @@ -216,3 +275,82 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "enabled": True, }, ) + + @parameterized.expand( + [ + ("populate_stats_process_rooms", ["populate_stats_process_rooms"]), + ( + "regenerate_directory", + [ + "populate_user_directory_createtables", + "populate_user_directory_process_rooms", + "populate_user_directory_process_users", + "populate_user_directory_cleanup", + ], + ), + ] + ) + def test_start_backround_job(self, job_name: str, updates: Collection[str]): + """ + Test that background updates add to database and be processed. + + Args: + job_name: name of the job to call with API + updates: collection of background updates to be started + """ + + # no background update is waiting + self.assertTrue( + self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/background_updates/start_job", + content={"job_name": job_name}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # test that each background update is waiting now + for update in updates: + self.assertFalse( + self.get_success( + self.store.db_pool.updates.has_completed_background_update(update) + ) + ) + + self.wait_for_background_updates() + + # background updates are done + self.assertTrue( + self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ) + ) + + def test_start_backround_job_twice(self): + """Test that add a background update twice return an error.""" + + # add job to database + self.get_success( + self.store.db_pool.simple_insert( + table="background_updates", + values={ + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + }, + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/background_updates/start_job", + content={"job_name": "populate_stats_process_rooms"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) -- cgit 1.5.1 From 1035663833a76196c3e3ba425fd6500c5420bbe2 Mon Sep 17 00:00:00 2001 From: Kostas Date: Mon, 22 Nov 2021 19:01:03 +0100 Subject: Add config for customizing the claim used for JWT logins. (#11361) Allows specifying a different claim (from the default "sub") to use when calculating the localpart of the Matrix ID used during the JWT login. --- changelog.d/11361.feature | 1 + docs/jwt.md | 5 +-- docs/sample_config.yaml | 6 ++++ synapse/config/jwt.py | 9 ++++++ synapse/rest/client/login.py | 3 +- tests/rest/client/test_login.py | 68 ++++++++++++++++++++++------------------- 6 files changed, 57 insertions(+), 35 deletions(-) create mode 100644 changelog.d/11361.feature (limited to 'synapse/rest') diff --git a/changelog.d/11361.feature b/changelog.d/11361.feature new file mode 100644 index 0000000000..24c9244887 --- /dev/null +++ b/changelog.d/11361.feature @@ -0,0 +1 @@ +Update the JWT login type to support custom a `sub` claim. diff --git a/docs/jwt.md b/docs/jwt.md index 5be9fd26e3..32f58cc0cb 100644 --- a/docs/jwt.md +++ b/docs/jwt.md @@ -22,8 +22,9 @@ will be removed in a future version of Synapse. The `token` field should include the JSON web token with the following claims: -* The `sub` (subject) claim is required and should encode the local part of the - user ID. +* A claim that encodes the local part of the user ID is required. By default, + the `sub` (subject) claim is used, or a custom claim can be set in the + configuration file. * The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`) claims are optional, but validated if present. * The issuer (`iss`) claim is optional, but required and validated if configured. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index aee300013f..ae476d19ac 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2039,6 +2039,12 @@ sso: # #algorithm: "provided-by-your-issuer" + # Name of the claim containing a unique identifier for the user. + # + # Optional, defaults to `sub`. + # + #subject_claim: "sub" + # The issuer to validate the "iss" claim against. # # Optional, if provided the "iss" claim will be required and diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index 9d295f5856..24c3ef01fc 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -31,6 +31,8 @@ class JWTConfig(Config): self.jwt_secret = jwt_config["secret"] self.jwt_algorithm = jwt_config["algorithm"] + self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") + # The issuer and audiences are optional, if provided, it is asserted # that the claims exist on the JWT. self.jwt_issuer = jwt_config.get("issuer") @@ -46,6 +48,7 @@ class JWTConfig(Config): self.jwt_enabled = False self.jwt_secret = None self.jwt_algorithm = None + self.jwt_subject_claim = None self.jwt_issuer = None self.jwt_audiences = None @@ -88,6 +91,12 @@ class JWTConfig(Config): # #algorithm: "provided-by-your-issuer" + # Name of the claim containing a unique identifier for the user. + # + # Optional, defaults to `sub`. + # + #subject_claim: "sub" + # The issuer to validate the "iss" claim against. # # Optional, if provided the "iss" claim will be required and diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 467444a041..00e65c66ac 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -72,6 +72,7 @@ class LoginRestServlet(RestServlet): # JWT configuration variables. self.jwt_enabled = hs.config.jwt.jwt_enabled self.jwt_secret = hs.config.jwt.jwt_secret + self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences @@ -413,7 +414,7 @@ class LoginRestServlet(RestServlet): errcode=Codes.FORBIDDEN, ) - user = payload.get("sub", None) + user = payload.get(self.jwt_subject_claim, None) if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 0b90e3f803..19f5e46537 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase): jwt_secret = "secret" jwt_algorithm = "HS256" + base_config = { + "enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + } - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_secret - self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm - return self.hs + def default_config(self): + config = super().default_config() + + # If jwt_config has been defined (eg via @override_config), don't replace it. + if config.get("jwt_config") is None: + config["jwt_config"] = self.base_config + + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. @@ -879,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "issuer": "test-issuer", - } - } - ) + @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) def test_login_iss(self): """Test validating the issuer claim.""" # A valid issuer. @@ -919,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "audiences": ["test-audience"], - } - } - ) + @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) def test_login_aud(self): """Test validating the audience claim.""" # A valid audience. @@ -962,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) + def test_login_default_sub(self): + """Test reading user ID from the default subject claim.""" + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) + def test_login_custom_sub(self): + """Test reading user ID from a custom subject claim.""" + channel = self.jwt_login({"username": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + def test_login_no_token(self): params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) @@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_pubkey - self.hs.config.jwt.jwt_algorithm = "RS256" - return self.hs + def default_config(self): + config = super().default_config() + config["jwt_config"] = { + "enabled": True, + "secret": self.jwt_pubkey, + "algorithm": "RS256", + } + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. -- cgit 1.5.1 From 6a5dd485bd82b269e7e169c0385290d081eae801 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 Nov 2021 06:43:56 -0500 Subject: Refactor the code to inject bundled relations during serialization. (#11408) --- changelog.d/11408.misc | 1 + synapse/events/utils.py | 146 ++++++++++++++++++++++----------------- synapse/handlers/events.py | 2 +- synapse/handlers/message.py | 2 +- synapse/rest/admin/rooms.py | 4 +- synapse/rest/client/relations.py | 6 +- synapse/rest/client/room.py | 2 +- synapse/rest/client/sync.py | 2 +- 8 files changed, 92 insertions(+), 73 deletions(-) create mode 100644 changelog.d/11408.misc (limited to 'synapse/rest') diff --git a/changelog.d/11408.misc b/changelog.d/11408.misc new file mode 100644 index 0000000000..55ed064672 --- /dev/null +++ b/changelog.d/11408.misc @@ -0,0 +1 @@ +Refactor including the bundled relations when serializing an event. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 6fa631aa1d..e5967c995e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -392,15 +393,16 @@ class EventClientSerializer: self, event: Union[JsonDict, EventBase], time_now: int, - bundle_aggregations: bool = True, + bundle_relations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. Args: - event + event: The event being serialized. time_now: The current time in milliseconds - bundle_aggregations: Whether to bundle in related events + bundle_relations: Whether to include the bundled relations for this + event. **kwargs: Arguments to pass to `serialize_event` Returns: @@ -410,77 +412,93 @@ class EventClientSerializer: if not isinstance(event, EventBase): return event - event_id = event.event_id serialized_event = serialize_event(event, time_now, **kwargs) # If MSC1849 is enabled then we need to look if there are any relations # we need to bundle in with the event. # Do not bundle relations if the event has been redacted if not event.internal_metadata.is_redacted() and ( - self._msc1849_enabled and bundle_aggregations + self._msc1849_enabled and bundle_relations ): - annotations = await self.store.get_aggregation_groups_for_event(event_id) - references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" - ) - - if annotations.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.ANNOTATION] = annotations.to_dict() - - if references.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REFERENCE] = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) - - if edit: - # If there is an edit replace the content, preserving existing - # relations. - - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relations = event.content.get("m.relates_to") - if relations: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relations.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACE] = { - "event_id": edit.event_id, - "origin_server_ts": edit.origin_server_ts, - "sender": edit.sender, - } - - # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.store.get_thread_summary(event_id) - if latest_thread_event: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False - ), - "count": thread_count, - } + await self._injected_bundled_relations(event, time_now, serialized_event) return serialized_event + async def _injected_bundled_relations( + self, event: EventBase, time_now: int, serialized_event: JsonDict + ) -> None: + """Potentially injects bundled relations into the unsigned portion of the serialized event. + + Args: + event: The event being serialized. + time_now: The current time in milliseconds + serialized_event: The serialized event which may be modified. + + """ + event_id = event.event_id + + # The bundled relations to include. + relations = {} + + annotations = await self.store.get_aggregation_groups_for_event(event_id) + if annotations.chunk: + relations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.store.get_relations_for_event( + event_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + relations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.store.get_applicable_edit(event_id) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + + relations[RelationTypes.REPLACE] = { + "event_id": edit.event_id, + "origin_server_ts": edit.origin_server_ts, + "sender": edit.sender, + } + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.store.get_thread_summary(event_id) + if latest_thread_event: + relations[RelationTypes.THREAD] = { + # Don't bundle relations as this could recurse forever. + "latest_event": await self.serialize_event( + latest_thread_event, time_now, bundle_relations=False + ), + "count": thread_count, + } + + # If any bundled relations were found, include them. + if relations: + serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) + async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any ) -> List[JsonDict]: diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1f64534a8a..b4ff935546 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -124,7 +124,7 @@ class EventStreamHandler: as_client_event=as_client_event, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. - bundle_aggregations=False, + bundle_relations=False, ) chunk = { diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 22dd4cf5fd..95b4fad3c6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -252,7 +252,7 @@ class MessageHandler: now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. - bundle_aggregations=False, + bundle_relations=False, ) return events diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 5b8ec1e5ca..a89dda1ba5 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -448,7 +448,7 @@ class RoomStateRestServlet(RestServlet): now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. - bundle_aggregations=False, + bundle_relations=False, ) ret = {"state": room_state} @@ -778,7 +778,7 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now, # No need to bundle aggregations for state events - bundle_aggregations=False, + bundle_relations=False, ) return 200, results diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 184cfbe196..45e9f1dd90 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,17 +224,17 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_aggregations to False when retrieving the original + # We set bundle_relations to False when retrieving the original # event because we want the content before relations were applied to # it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + event, now, bundle_relations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=False + events, now, bundle_relations=False ) return_value = pagination_chunk.to_dict() diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 03a353d53c..955d4e8641 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -719,7 +719,7 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now, # No need to bundle aggregations for state events - bundle_aggregations=False, + bundle_relations=False, ) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8c0fdb1940..b6a2485732 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -522,7 +522,7 @@ class SyncRestServlet(RestServlet): time_now=time_now, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. - bundle_aggregations=False, + bundle_relations=False, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, -- cgit 1.5.1 From f25c75d376ff8517e3245e06abdacf813606bd1f Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 23 Nov 2021 17:01:34 +0000 Subject: Rename unstable `access_token_lifetime` configuration option to `refreshable_access_token_lifetime` to make it clear it only concerns refreshable access tokens. (#11388) --- changelog.d/11388.misc | 1 + synapse/config/registration.py | 23 +++++++++++++++-------- synapse/handlers/register.py | 8 ++++++-- synapse/rest/client/login.py | 14 ++++++++++---- synapse/rest/client/register.py | 4 +++- tests/rest/client/test_auth.py | 2 +- 6 files changed, 36 insertions(+), 16 deletions(-) create mode 100644 changelog.d/11388.misc (limited to 'synapse/rest') diff --git a/changelog.d/11388.misc b/changelog.d/11388.misc new file mode 100644 index 0000000000..7ce7ad0498 --- /dev/null +++ b/changelog.d/11388.misc @@ -0,0 +1 @@ +Rename unstable `access_token_lifetime` configuration option to `refreshable_access_token_lifetime` to make it clear it only concerns refreshable access tokens. \ No newline at end of file diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 66382a479e..61e569d412 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -112,25 +112,32 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime - # The `access_token_lifetime` applies for tokens that can be renewed + # The `refreshable_access_token_lifetime` applies for tokens that can be renewed # using a refresh token, as per MSC2918. If it is `None`, the refresh # token mechanism is disabled. # # Since it is incompatible with the `session_lifetime` mechanism, it is set to # `None` by default if a `session_lifetime` is set. - access_token_lifetime = config.get( - "access_token_lifetime", "5m" if session_lifetime is None else None + refreshable_access_token_lifetime = config.get( + "refreshable_access_token_lifetime", + "5m" if session_lifetime is None else None, ) - if access_token_lifetime is not None: - access_token_lifetime = self.parse_duration(access_token_lifetime) - self.access_token_lifetime = access_token_lifetime + if refreshable_access_token_lifetime is not None: + refreshable_access_token_lifetime = self.parse_duration( + refreshable_access_token_lifetime + ) + self.refreshable_access_token_lifetime = refreshable_access_token_lifetime - if session_lifetime is not None and access_token_lifetime is not None: + if ( + session_lifetime is not None + and refreshable_access_token_lifetime is not None + ): raise ConfigError( "The refresh token mechanism is incompatible with the " "`session_lifetime` option. Consider disabling the " "`session_lifetime` option or disabling the refresh token " - "mechanism by removing the `access_token_lifetime` option." + "mechanism by removing the `refreshable_access_token_lifetime` " + "option." ) # The fallback template used for authenticating using a registration token diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b5664e9fde..448a36108e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,7 +116,9 @@ class RegistrationHandler: self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) init_counters_for_auth_provider("") @@ -817,7 +819,9 @@ class RegistrationHandler: user_id, device_id=registered_device_id, ) - valid_until_ms = self.clock.time_msec() + self.access_token_lifetime + valid_until_ms = ( + self.clock.time_msec() + self.refreshable_access_token_lifetime + ) access_token = await self._auth_handler.create_access_token_for_user_id( user_id, diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 00e65c66ac..67e03dca04 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -81,7 +81,9 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._msc2918_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self.auth = hs.get_auth() @@ -453,7 +455,9 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -463,7 +467,9 @@ class RefreshTokenServlet(RestServlet): if not isinstance(token, str): raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + valid_until_ms = ( + self._clock.time_msec() + self.refreshable_access_token_lifetime + ) access_token, refresh_token = await self._auth_handler.refresh_token( token, valid_until_ms ) @@ -562,7 +568,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.registration.access_token_lifetime is not None: + if hs.config.registration.refreshable_access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index bf3cb34146..d2b11e39d9 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -420,7 +420,9 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.registration.enable_registration - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._msc2918_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index e2fcbdc63a..8552671431 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -598,7 +598,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response.json_body["refresh_token"], ) - @override_config({"access_token_lifetime": "1m"}) + @override_config({"refreshable_access_token_lifetime": "1m"}) def test_refresh_token_expiration(self): """ The access token should have some time as specified in the config. -- cgit 1.5.1