From c2000ab35b76288a625f598d2382d4e3f29f65f6 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Wed, 4 Aug 2021 13:40:25 +0300 Subject: Add `get_userinfo_by_id` method to `ModuleApi` (#9581) Makes it easier to fetch user details in for example spam checker modules, without needing to use api._store or figure out database interactions. Signed-off-by: Jason Robinson --- tests/module_api/test_api.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'tests') diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 81d9e2f484..0b817cc701 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent -- cgit 1.5.1 From c37dad67ab04980ac934554399f52a27e54292ab Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Aug 2021 13:54:51 +0100 Subject: Improve event caching code (#10119) Ensure we only load an event from the DB once when the same event is requested multiple times at once. --- changelog.d/10119.misc | 1 + synapse/storage/databases/main/events_worker.py | 144 +++++++++++++++------ synapse/storage/databases/main/roommember.py | 6 +- tests/storage/databases/main/test_events_worker.py | 50 +++++++ 4 files changed, 158 insertions(+), 43 deletions(-) create mode 100644 changelog.d/10119.misc (limited to 'tests') diff --git a/changelog.d/10119.misc b/changelog.d/10119.misc new file mode 100644 index 0000000000..f70dc6496f --- /dev/null +++ b/changelog.d/10119.misc @@ -0,0 +1 @@ +Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3c86adab56..375463e4e9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,7 +14,6 @@ import logging import threading -from collections import namedtuple from typing import ( Collection, Container, @@ -27,6 +26,7 @@ from typing import ( overload, ) +import attr from constantly import NamedConstant, Names from typing_extensions import Literal @@ -42,7 +42,11 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.logging.context import ( + PreserveLoggingContext, + current_context, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +@attr.s(slots=True, auto_attribs=True) +class _EventCacheEntry: + event: EventBase + redacted_event: Optional[EventBase] class EventRedactBehaviour(Names): @@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore): max_size=hs.config.caches.event_cache_size, ) + # Map from event ID to a deferred that will result in a map from event + # ID to cache entry. Note that the returned dict may not have the + # requested event in it if the event isn't in the DB. + self._current_event_fetches: Dict[ + str, ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = {} + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore): return events - async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore): Args: - event_ids (Iterable[str]): The event_ids of the events to fetch + event_ids: The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events. If False, + allow_rejected: Whether to include rejected events. If False, rejected events are omitted from the response. Returns: - Dict[str, _EventCacheEntry]: - map from event id to result + map from event id to result """ event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected + event_ids, ) - missing_events_ids = [e for e in event_ids if e not in event_entry_map] + missing_events_ids = {e for e in event_ids if e not in event_entry_map} + + # We now look up if we're already fetching some of the events in the DB, + # if so we wait for those lookups to finish instead of pulling the same + # events out of the DB multiple times. + already_fetching: Dict[str, defer.Deferred] = {} + + for event_id in missing_events_ids: + deferred = self._current_event_fetches.get(event_id) + if deferred is not None: + # We're already pulling the event out of the DB. Add the deferred + # to the collection of deferreds to wait on. + already_fetching[event_id] = deferred.observe() + + missing_events_ids.difference_update(already_fetching) if missing_events_ids: log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) + for event_id in missing_events_ids: + self._current_event_fetches[event_id] = fetching_deferred + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = await self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) - event_entry_map.update(missing_events) + event_entry_map.update(missing_events) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + if already_fetching: + # Wait for the other event requests to finish and add their results + # to ours. + results = await make_deferred_yieldable( + defer.gatherResults( + already_fetching.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + for result in results: + event_entry_map.update(result) + + if not allow_rejected: + event_entry_map = { + event_id: entry + for event_id, entry in event_entry_map.items() + if not entry.event.rejected_reason + } return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches + def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, _EventCacheEntry]: + """Fetch events from the caches. - Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics + May return rejected events. - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics """ event_map = {} @@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore): if not ret: continue - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None + event_map[event_id] = ret return event_map @@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - async def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db( + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. + May return rejected events. + Returned events will be added to the cache for future lookups. Unknown events are omitted from the response. Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. + event_ids: The event_ids of the events to fetch Returns: - Dict[str, _EventCacheEntry]: - map from event id to result. May return extra events which - weren't asked for. + map from event id to result. May return extra events which + weren't asked for. """ fetched_events = {} events_to_fetch = event_ids @@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason = row["rejected_reason"] - if not allow_rejected and rejected_reason: - continue - # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. try: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 68f1b40ea6..e8157ba3d4 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) + event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) missing_member_event_ids = [] for event_id in member_event_ids: ev_entry = event_map.get(event_id) - if ev_entry: + if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: users_in_room[ev_entry.event.state_key] = ProfileInfo( display_name=ev_entry.event.content.get("displayname", None), diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 932970fd9a..d05d367685 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -14,7 +14,10 @@ import json from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success(self.store.have_seen_events("room1", ["event10"])) self.assertEquals(res, {"event10"}) self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + +class EventCacheTestCase(unittest.HomeserverTestCase): + """Test that the various layers of event cache works.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + def test_simple(self): + """Test that we cache events that we pull from the DB.""" + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + def test_dedupe(self): + """Test that if we request the same event multiple times we only pull it + out once. + """ + + with LoggingContext("test") as ctx: + d = yieldable_gather_results( + self.store.get_event, [self.event_id, self.event_id] + ) + self.get_success(d) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) -- cgit 1.5.1 From e8a3e8140291be0548ad80d0e942a9aaae6c2434 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Aug 2021 16:13:24 +0200 Subject: Don't fail on empty bodies when sending out read receipts (#10531) Fixes a bug introduced in rc1 that would cause Synapse to 400 on read receipts requests with empty bodies. Broken in #10413 --- changelog.d/10531.bugfix | 1 + synapse/rest/client/v2_alpha/receipts.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10531.bugfix (limited to 'tests') diff --git a/changelog.d/10531.bugfix b/changelog.d/10531.bugfix new file mode 100644 index 0000000000..aaa921ee91 --- /dev/null +++ b/changelog.d/10531.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.40.0rc1 that would cause Synapse to respond with an error when clients would update their read receipts. diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 4b98979b47..d9ab836cd8 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -43,7 +43,7 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - body = parse_json_object_from_request(request) + body = parse_json_object_from_request(request, allow_empty_body=True) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index f6ae9ae181..15748ed4fd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -418,6 +418,18 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's hidden read receipt self.assertEqual(self._get_read_receipt(), None) + def test_read_receipt_with_empty_body(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + def _get_read_receipt(self): """Syncs and returns the read receipt.""" -- cgit 1.5.1 From a8a27b2b8bac2995c3edd20518680366eb543ac9 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 5 Aug 2021 13:22:14 +0100 Subject: Only return an appservice protocol if it has a service providing it. (#10532) If there are no services providing a protocol, omit it completely instead of returning an empty dictionary. This fixes a long-standing spec compliance bug. --- changelog.d/10532.bugfix | 1 + synapse/handlers/appservice.py | 7 +-- tests/handlers/test_appservice.py | 122 +++++++++++++++++++++++++++++++++++++- 3 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10532.bugfix (limited to 'tests') diff --git a/changelog.d/10532.bugfix b/changelog.d/10532.bugfix new file mode 100644 index 0000000000..d95e3d9b59 --- /dev/null +++ b/changelog.d/10532.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 21a17cd2e8..4ab4046650 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -392,9 +392,6 @@ class ApplicationServicesHandler: protocols[p].append(info) def _merge_instances(infos: List[JsonDict]) -> JsonDict: - if not infos: - return {} - # Merge the 'instances' lists of multiple results, but just take # the other fields from the first as they ought to be identical # copy the result so as not to corrupt the cached one @@ -406,7 +403,9 @@ class ApplicationServicesHandler: return combined - return {p: _merge_instances(protocols[p]) for p in protocols.keys()} + return { + p: _merge_instances(protocols[p]) for p in protocols.keys() if protocols[p] + } async def _get_services_for_event( self, event: EventBase diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 024c5e963c..43998020b2 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - def _mkservice(self, is_interested): + def test_get_3pe_protocols_no_appservices(self): + self.mock_store.get_app_services.return_value = [] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_no_protocols(self): + service = self._mkservice(False, []) + self.mock_store.get_app_services.return_value = [service] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_protocol_no_response(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals(response, {}) + + def test_get_3pe_protocols_select_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_multiple_protocol(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["other-protocol"]) + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called() + self.assertEquals( + response, + { + "my-protocol": {"x-protocol-data": 42, "instances": []}, + "other-protocol": {"x-protocol-data": 42, "instances": []}, + }, + ) + + def test_get_3pe_protocols_multiple_info(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["my-protocol"]) + + async def get_3pe_protocol(service, unusedProtocol): + if service == service_one: + return { + "x-protocol-data": 42, + "instances": [{"desc": "Alice's service"}], + } + if service == service_two: + return { + "x-protocol-data": 36, + "x-not-used": 45, + "instances": [{"desc": "Bob's service"}], + } + raise Exception("Unexpected service") + + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol = get_3pe_protocol + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + # It's expected that the second service's data doesn't appear in the response + self.assertEquals( + response, + { + "my-protocol": { + "x-protocol-data": 42, + "instances": [ + { + "desc": "Alice's service", + }, + {"desc": "Bob's service"}, + ], + }, + }, + ) + + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" + service.protocols = protocols return service def _mkservice_alias(self, is_interested_in_alias): -- cgit 1.5.1 From 3b354faad0e6b1f41ed5dd0269a1785d3f505465 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 5 Aug 2021 08:39:17 -0400 Subject: Refactoring before implementing the updated spaces summary. (#10527) This should have no user-visible changes, but refactors some pieces of the SpaceSummaryHandler before adding support for the updated MSC2946. --- changelog.d/10527.misc | 1 + synapse/federation/federation_client.py | 23 ++-- synapse/handlers/space_summary.py | 125 ++++++++++++--------- tests/handlers/test_space_summary.py | 185 ++++++++++++++++++-------------- 4 files changed, 198 insertions(+), 136 deletions(-) create mode 100644 changelog.d/10527.misc (limited to 'tests') diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc new file mode 100644 index 0000000000..3cf22f9daf --- /dev/null +++ b/changelog.d/10527.misc @@ -0,0 +1 @@ +Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b7a10da15a..007d1a27dc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1290,7 +1290,7 @@ class FederationClient(FederationBase): ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: """Represents a single event in the result of a successful get_space_summary call. @@ -1299,12 +1299,13 @@ class FederationSpaceSummaryEventResult: object attributes. """ - event_type = attr.ib(type=str) - state_key = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + event_type: str + room_id: str + state_key: str + via: Sequence[str] # the raw data, including the above keys - data = attr.ib(type=JsonDict) + data: JsonDict @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": @@ -1321,6 +1322,10 @@ class FederationSpaceSummaryEventResult: if not isinstance(event_type, str): raise ValueError("Invalid event: 'event_type' must be a str") + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") + state_key = d.get("state_key") if not isinstance(state_key, str): raise ValueError("Invalid event: 'state_key' must be a str") @@ -1335,15 +1340,15 @@ class FederationSpaceSummaryEventResult: if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") - return cls(event_type, state_key, via, d) + return cls(event_type, room_id, state_key, via, d) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms = attr.ib(type=Sequence[JsonDict]) - events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult]) + rooms: Sequence[JsonDict] + events: Sequence[FederationSpaceSummaryEventResult] @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 5f7d4602bd..3eb232c83e 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -16,7 +16,17 @@ import itertools import logging import re from collections import deque -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) import attr @@ -116,20 +126,22 @@ class SpaceSummaryHandler: max_children = max_rooms_per_space if processed_rooms else None if is_in_room: - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( requester, None, room_id, suggested_only, max_children ) + events: Collection[JsonDict] = [] + if room_entry: + rooms_result.append(room_entry.room) + events = room_entry.children + logger.debug( "Query of local room %s returned events %s", room_id, ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) - - if room: - rooms_result.append(room) else: - fed_rooms, fed_events = await self._summarize_remote_room( + fed_rooms = await self._summarize_remote_room( queue_entry, suggested_only, max_children, @@ -141,12 +153,10 @@ class SpaceSummaryHandler: # user is not permitted see. # # Filter the returned results to only what is accessible to the user. - room_ids = set() events = [] - for room in fed_rooms: - fed_room_id = room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue + for room_entry in fed_rooms: + room = room_entry.room + fed_room_id = room_entry.room_id # The room should only be included in the summary if: # a. the user is in the room; @@ -189,21 +199,17 @@ class SpaceSummaryHandler: # The user can see the room, include it! if include_room: rooms_result.append(room) - room_ids.add(fed_room_id) + events.extend(room_entry.children) # All rooms returned don't need visiting again (even if the user # didn't have access to them). processed_rooms.add(fed_room_id) - for event in fed_events: - if event.get("room_id") in room_ids: - events.append(event) - logger.debug( "Query of %s returned rooms %s, events %s", room_id, - [room.get("room_id") for room in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events], + [room_entry.room.get("room_id") for room_entry in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) # the room we queried may or may not have been returned, but don't process @@ -283,20 +289,20 @@ class SpaceSummaryHandler: # already done this room continue - logger.debug("Processing room %s", room_id) - - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( None, origin, room_id, suggested_only, max_rooms_per_space ) processed_rooms.add(room_id) - if room: - rooms_result.append(room) - events_result.extend(events) + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children) - # add any children to the queue - room_queue.extend(edge_event["state_key"] for edge_event in events) + # add any children to the queue + room_queue.extend( + edge_event["state_key"] for edge_event in room_entry.children + ) return {"rooms": rooms_result, "events": events_result} @@ -307,7 +313,7 @@ class SpaceSummaryHandler: room_id: str, suggested_only: bool, max_children: Optional[int], - ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]: + ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -326,21 +332,16 @@ class SpaceSummaryHandler: to a server-set limit. Returns: - A tuple of: - The room information, if the room should be returned to the - user. None, otherwise. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + A room entry if the room should be returned. None, otherwise. """ if not await self._is_room_accessible(room_id, requester, origin): - return None, () + return None room_entry = await self._build_room_entry(room_id) # If the room is not a space, return just the room information. if room_entry.get("room_type") != RoomTypes.SPACE: - return room_entry, () + return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. child_events = await self._get_child_events(room_id) @@ -363,7 +364,7 @@ class SpaceSummaryHandler: ) ) - return room_entry, events_result + return _RoomEntry(room_id, room_entry, events_result) async def _summarize_remote_room( self, @@ -371,7 +372,7 @@ class SpaceSummaryHandler: suggested_only: bool, max_children: Optional[int], exclude_rooms: Iterable[str], - ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + ) -> Iterable["_RoomEntry"]: """ Request room entries and a list of event entries for a given room by querying a remote server. @@ -386,11 +387,7 @@ class SpaceSummaryHandler: Rooms IDs which do not need to be summarized. Returns: - A tuple of: - An iterable of rooms. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + An iterable of room entries. """ room_id = room.room_id logger.info("Requesting summary for %s via %s", room_id, room.via) @@ -414,11 +411,30 @@ class SpaceSummaryHandler: e, exc_info=logger.isEnabledFor(logging.DEBUG), ) - return (), () + return () + + # Group the events by their room. + children_by_room: Dict[str, List[JsonDict]] = {} + for ev in res.events: + if ev.event_type == EventTypes.SpaceChild: + children_by_room.setdefault(ev.room_id, []).append(ev.data) + + # Generate the final results. + results = [] + for fed_room in res.rooms: + fed_room_id = fed_room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue - return res.rooms, tuple( - ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild - ) + results.append( + _RoomEntry( + fed_room_id, + fed_room, + children_by_room.get(fed_room_id, []), + ) + ) + + return results async def _is_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] @@ -606,10 +622,21 @@ class SpaceSummaryHandler: return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class _RoomQueueEntry: - room_id = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + room_id: str + via: Sequence[str] + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomEntry: + room_id: str + # The room summary for this room. + room: JsonDict + # An iterable of the sorted, stripped children events for children of this room. + # + # This may not include all children. + children: Collection[JsonDict] = () def _has_valid_via(e: EventBase) -> bool: diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 3f73ad7f94..f982a8c8b4 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -26,7 +26,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key +from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.server import HomeServer @@ -351,26 +351,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # events before child events). # Note that these entries are brief, but should contain enough info. - rooms = [ - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - { - "room_id": subroom, - "world_readable": True, - }, - ] - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": subroom, - "content": event_content, - }, + return [ + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ), + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), ] - return rooms, events # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) @@ -436,70 +440,95 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ): # Note that these entries are brief, but should contain enough info. rooms = [ - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [], - }, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [self.room], - }, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ] - - # Place each room in the sub-space. - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": room["room_id"], - "content": event_content, - } - for room in rooms + _RoomEntry( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + _RoomEntry( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + _RoomEntry( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [], + }, + ), + _RoomEntry( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + _RoomEntry( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), ] # Also include the subspace. rooms.insert( 0, - { - "room_id": subspace, - "world_readable": True, - }, + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "room_id": subspace, + "state_key": room.room_id, + "content": {"via": [fed_hostname]}, + } + for room in rooms + ], + ), ) - return rooms, events + return rooms # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) -- cgit 1.5.1