From d56202b0383627fdb4e04404d62922dce16868f8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 4 Mar 2022 10:25:18 +0000 Subject: Fix type of `events` in `StateGroupStorage` and `StateHandler` (#12156) We make multiple passes over this, so a regular iterable won't do. --- synapse/storage/state.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e79ecf64a0..86f1a5373b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -561,7 +561,7 @@ class StateGroupStorage: return state_group_delta.prev_group, state_group_delta.delta_ids async def get_state_groups_ids( - self, _room_id: str, event_ids: Iterable[str] + self, _room_id: str, event_ids: Collection[str] ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events @@ -596,7 +596,7 @@ class StateGroupStorage: return group_to_state[state_group] async def get_state_groups( - self, room_id: str, event_ids: Iterable[str] + self, room_id: str, event_ids: Collection[str] ) -> Dict[int, List[EventBase]]: """Get the state groups for the given list of event_ids @@ -648,7 +648,7 @@ class StateGroupStorage: return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -684,7 +684,7 @@ class StateGroupStorage: return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids -- cgit 1.5.1 From cd1ae3d0b438ff453b7d4750c4fe901f266fcbb6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Mar 2022 07:10:10 -0500 Subject: Remove backwards compatibility with RelationPaginationToken. (#12138) --- changelog.d/12138.removal | 1 + synapse/rest/client/relations.py | 55 +++++++--------------------- synapse/storage/relations.py | 31 ---------------- tests/rest/client/test_relations.py | 73 +------------------------------------ 4 files changed, 16 insertions(+), 144 deletions(-) create mode 100644 changelog.d/12138.removal (limited to 'synapse/storage') diff --git a/changelog.d/12138.removal b/changelog.d/12138.removal new file mode 100644 index 0000000000..6ed84d476c --- /dev/null +++ b/changelog.d/12138.removal @@ -0,0 +1 @@ +Remove backwards compatibilty with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 487ea38b55..07fa1cdd4c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,50 +27,15 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -async def _parse_token( - store: "DataStore", token: Optional[str] -) -> Optional[StreamToken]: - """ - For backwards compatibility support RelationPaginationToken, but new pagination - tokens are generated as full StreamTokens, to be compatible with /sync and /messages. - """ - if not token: - return None - # Luckily the format for StreamToken and RelationPaginationToken differ enough - # that they can easily be separated. An "_" appears in the serialization of - # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses - # "-" only for separators. - if "_" in token: - return await StreamToken.from_string(store, token) - else: - relation_token = RelationPaginationToken.from_string(token) - return StreamToken( - room_key=RoomStreamToken(relation_token.topological, relation_token.stream), - presence_key=0, - typing_key=0, - receipt_key=0, - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=0, - groups_key=0, - ) - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -122,8 +87,12 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -317,8 +286,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - from_token = await _parse_token(self.store, from_token_str) - to_token = await _parse_token(self.store, to_token_str) + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 36ca2b8273..fba270150b 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -54,37 +54,6 @@ class PaginationChunk: return d -@attr.s(frozen=True, slots=True, auto_attribs=True) -class RelationPaginationToken: - """Pagination token for relation pagination API. - - As the results are in topological order, we can use the - `topological_ordering` and `stream_ordering` fields of the events at the - boundaries of the chunk as pagination tokens. - - Attributes: - topological: The topological ordering of the boundary event - stream: The stream ordering of the boundary event. - """ - - topological: int - stream: int - - @staticmethod - def from_string(string: str) -> "RelationPaginationToken": - try: - t, s = string.split("-") - return RelationPaginationToken(int(t), int(s)) - except ValueError: - raise SynapseError(400, "Invalid relation pagination token") - - async def to_string(self, store: "DataStore") -> str: - return "%d-%d" % (self.topological, self.stream) - - def as_tuple(self) -> Tuple[Any, ...]: - return attr.astuple(self) - - @attr.s(frozen=True, slots=True, auto_attribs=True) class AggregationPaginationToken: """Pagination token for relation aggregation pagination API. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 53062b41de..274f9c44c1 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -24,8 +24,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync from synapse.server import HomeServer -from synapse.storage.relations import RelationPaginationToken -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -281,15 +280,6 @@ class RelationsTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) - def _stream_token_to_relation_token(self, token: str) -> str: - """Convert a StreamToken into a legacy token (RelationPaginationToken).""" - room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key - return self.get_success( - RelationPaginationToken( - topological=room_key.topological, stream=room_key.stream - ).to_string(self.store) - ) - def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -330,34 +320,6 @@ class RelationsTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") @@ -543,39 +505,6 @@ class RelationsTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" -- cgit 1.5.1 From 0752ab7a3621b90073f9332fbfdc8afe16a3be01 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 4 Mar 2022 17:57:27 +0000 Subject: Reduce to-device queries for /sync. (#12163) --- changelog.d/12163.misc | 1 + synapse/storage/databases/main/deviceinbox.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/12163.misc (limited to 'synapse/storage') diff --git a/changelog.d/12163.misc b/changelog.d/12163.misc new file mode 100644 index 0000000000..13de0895f5 --- /dev/null +++ b/changelog.d/12163.misc @@ -0,0 +1 @@ +Reduce number of DB queries made during processing of `/sync`. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1392363de1..b4a1b041b1 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -298,6 +298,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): # This user has new messages sent to them. Query messages for them user_ids_to_query.add(user_id) + if not user_ids_to_query: + return {}, to_stream_id + def get_device_messages_txn(txn: LoggingTransaction): # Build a query to select messages from any of the given devices that # are between the given stream id bounds. -- cgit 1.5.1 From f63bedef07360216a8de71dc38f00f1aea503903 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Mar 2022 09:00:05 -0500 Subject: Invalidate caches when an event with a relation is redacted. (#12121) The caches for the target of the relation must be cleared so that the bundled aggregations are re-calculated after the redaction is processed. --- changelog.d/12113.bugfix | 1 + changelog.d/12113.misc | 1 - changelog.d/12121.bugfix | 1 + synapse/storage/databases/main/cache.py | 2 + synapse/storage/databases/main/events.py | 38 +++++- tests/rest/client/test_relations.py | 207 ++++++++++++++++++++++++------- 6 files changed, 202 insertions(+), 48 deletions(-) create mode 100644 changelog.d/12113.bugfix delete mode 100644 changelog.d/12113.misc create mode 100644 changelog.d/12121.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/12113.bugfix b/changelog.d/12113.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12113.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12113.misc b/changelog.d/12113.misc deleted file mode 100644 index 102e064053..0000000000 --- a/changelog.d/12113.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the tests for event relations. diff --git a/changelog.d/12121.bugfix b/changelog.d/12121.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12121.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c428dd5596..abd54c7dc7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -200,6 +200,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_relations_for_event.invalidate((relates_to,)) self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) + self.get_thread_summary.invalidate((relates_to,)) + self.get_thread_participated.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ca2a9ba9d1..1dc83aa5e3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1518,7 +1518,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redaction(txn, event.redacts) + self._handle_redact_relations(txn, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1943,15 +1943,43 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. + def _handle_redact_relations( + self, txn: LoggingTransaction, redacted_event_id: str + ) -> None: + """Handles receiving a redaction and checking whether the redacted event + has any relations which must be removed from the database. Args: txn - redacted_event_id (str): The event that was redacted. + redacted_event_id: The event that was redacted. """ + # Fetch the current relation of the event being redacted. + redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_relations", + keyvalues={"event_id": redacted_event_id}, + retcol="relates_to_id", + allow_none=True, + ) + # Any relation information for the related event must be cleared. + if redacted_relates_to is not None: + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_applicable_edit, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_summary, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_participated, (redacted_relates_to,) + ) + self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 274f9c44c1..a40a5de399 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1273,7 +1273,21 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationRedactionTestCase(BaseRelationsTestCase): - """Test the behaviour of relations when the parent or child event is redacted.""" + """ + Test the behaviour of relations when the parent or child event is redacted. + + The behaviour of each relation type is subtly different which causes the tests + to be a bit repetitive, they follow a naming scheme of: + + test_redact_(relation|parent)_{relation_type} + + The first bit of "relation" means that the event with the relation defined + on it (the child event) is to be redacted. A "parent" means that the target + of the relation (the parent event) is to be redacted. + + The relation_type describes which type of relation is under test (i.e. it is + related to the value of rel_type in the event content). + """ def _redact(self, event_id: str) -> None: channel = self.make_request( @@ -1284,9 +1298,53 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: + """ + Makes requests and ensures they result in a 200 response, returns a + tuple of results: + + 1. `/relations` -> Returns a list of event IDs. + 2. `/event` -> Returns the response's m.relations field (from unsigned), + if it exists. + """ + + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) + + return event_ids, bundled_relations + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + def test_redact_relation_annotation(self) -> None: - """Test that annotations of an event are properly handled after the + """ + Test that annotations of an event are properly handled after the annotation is redacted. + + The redacted relation should not be included in bundled aggregations or + the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEqual(200, channel.code, channel.json_body) @@ -1296,24 +1354,97 @@ class RelationRedactionTestCase(BaseRelationsTestCase): RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + + # Both relations should exist. + event_ids, relations = self._make_relation_requests() + self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, + ) + + # Both relations appear in the aggregation. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) # Redact one of the reactions. self._redact(to_redact_event_id) - # Ensure that the aggregations are correct. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + # The unredacted aggregation should still exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_relation_thread(self) -> None: + """ + Test that thread replies are properly handled after the thread reply redacted. + + The redacted event should not be included in bundled aggregations or + the response to relations. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, ) self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + # Note that the *last* event in the thread is redacted, as that gets + # included in the bundled aggregation. + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 2", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + # Both relations exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 2, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event returned is the event that will be redacted. self.assertEqual( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + to_redact_event_id, ) - def test_redact_relation_edit(self) -> None: + # Redact one of the reactions. + self._redact(to_redact_event_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event is now the unredacted event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + unredacted_event_id, + ) + + def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event is redacted. """ @@ -1331,34 +1462,19 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.REPLACE, relations) # Redact the original event self._redact(self.parent_id) - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # The relations are not returned. + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 0) + self.assertEqual(relations, {}) - def test_redact_parent(self) -> None: + def test_redact_parent_annotation(self) -> None: """Test that annotations of an event are redacted when the original event is redacted. """ @@ -1366,16 +1482,23 @@ class RelationRedactionTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + # The relations should exist. + event_ids, relations = self._make_relation_requests() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.ANNOTATION, relations) + + # The aggregation should exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) + # Redact the original event. self._redact(self.parent_id) - # Check that aggregations returns zero - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + # The relations are not returned. + event_ids, relations = self._make_relation_requests() + self.assertEqual(event_ids, []) + self.assertEqual(relations, {}) - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # There's nothing to aggregate. + chunk = self._get_aggregations() + self.assertEqual(chunk, []) -- cgit 1.5.1 From 26211fec24d8d0a967de33147e148166359ec8cb Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 7 Mar 2022 09:44:33 -0800 Subject: Fix a bug in background updates wherein background updates are never run using the default batch size (#12157) --- changelog.d/12157.bugfix | 1 + synapse/storage/background_updates.py | 8 +++++--- tests/rest/admin/test_background_updates.py | 18 ++++++++---------- tests/storage/test_background_update.py | 4 ++-- 4 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/12157.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/12157.bugfix b/changelog.d/12157.bugfix new file mode 100644 index 0000000000..c3d2e700bb --- /dev/null +++ b/changelog.d/12157.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in #4864 whereby background updates are never run with the default background batch size. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index d64910aded..4acc2c997d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -102,10 +102,12 @@ class BackgroundUpdatePerformance: Returns: A duration in ms as a float """ - if self.avg_duration_ms == 0: - return 0 - elif self.total_item_count == 0: + # We want to return None if this is the first background update item + if self.total_item_count == 0: return None + # Avoid dividing by zero + elif self.avg_duration_ms == 0: + return 0 else: # Use the exponential moving average so that we can adapt to # changes in how long the update process takes. diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index fb36aa9940..becec84524 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -155,10 +155,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -210,10 +210,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -239,10 +239,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE ), } }, @@ -278,11 +278,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.05263157894736842, "total_duration_ms": 2000.0, - "total_item_count": ( - 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE - ), + "total_item_count": (110), } }, "enabled": True, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 39dcc094bd..9fdf54ea31 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -66,13 +66,13 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), - by=0.01, + by=0.02, ) self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update -- cgit 1.5.1 From 032688854babeea832cbb4f762fc70fe31e73cc6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 10:29:39 -0500 Subject: Remove some unused variables/parameters. (#12187) --- changelog.d/12187.misc | 1 + synapse/storage/databases/main/roommember.py | 14 +++++--------- 2 files changed, 6 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12187.misc (limited to 'synapse/storage') diff --git a/changelog.d/12187.misc b/changelog.d/12187.misc new file mode 100644 index 0000000000..c53e68faa5 --- /dev/null +++ b/changelog.d/12187.misc @@ -0,0 +1 @@ +Remove unused variables. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e48ec5f495..bef675b845 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -46,7 +46,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id +from synapse.types import PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -273,7 +273,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (room_id,)) res = {} for count, membership in txn: - summary = res.setdefault(membership, MemberSummary([], count)) + res.setdefault(membership, MemberSummary([], count)) # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent @@ -839,18 +839,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry + room_id, state_group, state_entry=state_entry ) @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, - room_id: str, - state_group: int, - current_state_ids: StateMap[str], - state_entry: "_StateCacheEntry", + self, room_id: str, state_group: int, state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: - # We don't use `state_group`, its there so that we can cache based on + # We don't use `state_group`, it's there so that we can cache based on # it. However, its important that its never None, since two # current_state's with a state_group of None are likely to be different. # -- cgit 1.5.1 From 690cb4f3b32938f5ced5590abe9429733040a129 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 13:07:41 -0500 Subject: Allow for ignoring some arguments when caching. (#12189) * `@cached` can now take an `uncached_args` which is an iterable of names to not use in the cache key. * Requires `@cached`, @cachedList` and `@lru_cache` to use keyword arguments for clarity. * Asserts that keyword-only arguments in cached functions are not accepted. (I tested this briefly and I don't believe this works properly.) --- changelog.d/12189.misc | 1 + synapse/storage/databases/main/events_worker.py | 4 +- synapse/util/caches/descriptors.py | 74 +++++++++++++++++----- tests/util/caches/test_descriptors.py | 84 ++++++++++++++++++++++++- 4 files changed, 142 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12189.misc (limited to 'synapse/storage') diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc new file mode 100644 index 0000000000..015e808e63 --- /dev/null +++ b/changelog.d/12189.misc @@ -0,0 +1 @@ +Support skipping some arguments when generating cache keys. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 26784f755e..59454a47df 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1286,7 +1286,7 @@ class EventsWorkerStore(SQLBaseStore): ) return {eid for ((_rid, eid), have_event) in res.items() if have_event} - @cachedList("have_seen_event", "keys") + @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( self, keys: Iterable[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: @@ -1954,7 +1954,7 @@ class EventsWorkerStore(SQLBaseStore): get_event_id_for_timestamp_txn, ) - @cachedList("is_partial_state_event", list_name="event_ids") + @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] ) -> Dict[str, bool]: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1cdead02f1..c3c5c16db9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -20,6 +20,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, @@ -69,6 +70,7 @@ class _CacheDescriptorBase: self, orig: Callable[..., Any], num_args: Optional[int], + uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, ): self.orig = orig @@ -76,6 +78,13 @@ class _CacheDescriptorBase: arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args + # There's no reason that keyword-only arguments couldn't be supported, + # but right now they're buggy so do not allow them. + if arg_spec.kwonlyargs: + raise ValueError( + "_CacheDescriptorBase does not support keyword-only arguments." + ) + if "cache_context" in all_args: if not cache_context: raise ValueError( @@ -88,6 +97,9 @@ class _CacheDescriptorBase: " named `cache_context`" ) + if num_args is not None and uncached_args is not None: + raise ValueError("Cannot provide both num_args and uncached_args") + if num_args is None: num_args = len(all_args) - 1 if cache_context: @@ -105,6 +117,12 @@ class _CacheDescriptorBase: # list of the names of the args used as the cache key self.arg_names = all_args[1 : num_args + 1] + # If there are args to not cache on, filter them out (and fix the size of num_args). + if uncached_args is not None: + include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names] + else: + include_arg_in_cache_key = [True] * len(self.arg_names) + # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: @@ -119,8 +137,8 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context - self.cache_key_builder = get_cache_key_builder( - self.arg_names, self.arg_defaults + self.cache_key_builder = _get_cache_key_builder( + self.arg_names, include_arg_in_cache_key, self.arg_defaults ) @@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, - cache_context: bool = False, + *, max_entries: int = 1000, cache_context: bool = False ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -186,7 +203,9 @@ class LruCacheDescriptor(_CacheDescriptorBase): max_entries: int = 1000, cache_context: bool = False, ): - super().__init__(orig, num_args=None, cache_context=cache_context) + super().__init__( + orig, num_args=None, uncached_args=None, cache_context=cache_context + ) self.max_entries = max_entries def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: @@ -260,6 +279,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. + uncached_args: a list of argument names to not use as the cache key. + (``self`` and ``cache_context`` are always ignored.) Cannot be used + with num_args. tree: cache_context: iterable: @@ -273,12 +295,18 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): orig: Callable[..., Any], max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) + super().__init__( + orig, + num_args=num_args, + uncached_args=uncached_args, + cache_context=cache_context, + ) if tree and self.num_args < 2: raise RuntimeError( @@ -369,7 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args) + super().__init__(orig, num_args=num_args, uncached_args=None) self.list_name = list_name @@ -530,8 +558,10 @@ class _CacheContext: def cached( + *, max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, @@ -541,6 +571,7 @@ def cached( orig, max_entries=max_entries, num_args=num_args, + uncached_args=uncached_args, tree=tree, cache_context=cache_context, iterable=iterable, @@ -551,7 +582,7 @@ def cached( def cachedList( - cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. @@ -590,13 +621,16 @@ def cachedList( return cast(Callable[[F], _CachedFunction[F]], func) -def get_cache_key_builder( - param_names: Sequence[str], param_defaults: Mapping[str, Any] +def _get_cache_key_builder( + param_names: Sequence[str], + include_params: Sequence[bool], + param_defaults: Mapping[str, Any], ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: """Construct a function which will build cache keys suitable for a cached function Args: param_names: list of formal parameter names for the cached function + include_params: list of bools of whether to include the parameter name in the cache key param_defaults: a mapping from parameter name to default value for that param Returns: @@ -608,6 +642,7 @@ def get_cache_key_builder( if len(param_names) == 1: nm = param_names[0] + assert include_params[0] is True def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: if nm in kwargs: @@ -620,13 +655,18 @@ def get_cache_key_builder( else: def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: - return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + return tuple( + _get_cache_key_gen( + param_names, include_params, param_defaults, args, kwargs + ) + ) return get_cache_key def _get_cache_key_gen( param_names: Iterable[str], + include_params: Iterable[bool], param_defaults: Mapping[str, Any], args: Sequence[Any], kwargs: Mapping[str, Any], @@ -637,16 +677,18 @@ def _get_cache_key_gen( This is essentially the same operation as `inspect.getcallargs`, but optimised so that we don't need to inspect the target function for each call. """ - # We loop through each arg name, looking up if its in the `kwargs`, # otherwise using the next argument in `args`. If there are no more # args then we try looking the arg name up in the defaults. pos = 0 - for nm in param_names: + for nm, inc in zip(param_names, include_params): if nm in kwargs: - yield kwargs[nm] + if inc: + yield kwargs[nm] elif pos < len(args): - yield args[pos] + if inc: + yield args[pos] pos += 1 else: - yield param_defaults[nm] + if inc: + yield param_defaults[nm] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 19741ffcda..6a4b17527a 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "chips") obj.mock.assert_not_called() + @defer.inlineCallbacks + def test_cache_uncached_args(self): + """ + Only the arguments not named in uncached_args should matter to the cache + + Note that this is identical to test_cache_num_args, but provides the + arguments differently. + """ + + class Cls: + # Note that it is important that this is not the last argument to + # test behaviour of skipping arguments properly. + @descriptors.cached(uncached_args=("arg2",)) + def fn(self, arg1, arg2, arg3): + return self.mock(arg1, arg2, arg3) + + def __init__(self): + self.mock = mock.Mock() + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, 2, 3) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2, 3) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(2, 3, 4) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(2, 3, 4) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. + r = yield obj.fn(1, 4, 3) + self.assertEqual(r, "fish") + r = yield obj.fn(2, 5, 4) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + @defer.inlineCallbacks + def test_cache_kwargs(self): + """Test that keyword arguments are treated properly""" + + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, kwarg1=2): + return self.mock(arg1, kwarg1=kwarg1) + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, kwarg1=2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(1, kwarg1=3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, kwarg1=3) + obj.mock.reset_mock() + + # the values should now be cached. + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + # We should be able to not provide kwarg1 and get the cached value back. + r = yield obj.fn(1) + self.assertEqual(r, "fish") + # Keyword arguments can be in any order. + r = yield obj.fn(kwarg1=2, arg1=1) + self.assertEqual(r, "fish") + obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): """If the wrapped function throws synchronously, things should continue to work""" @@ -656,7 +734,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): assert current_context().name == "c1" # we want this to behave like an asynchronous function @@ -715,7 +793,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") def list_fn(self, args1) -> "Deferred[dict]": return self.mock(args1) @@ -758,7 +836,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function await run_on_reactor() -- cgit 1.5.1 From 88cd6f937807e64c05458cec86ef0ba0c1c656b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 09:03:59 -0500 Subject: Allow retrieving the relations of a redacted event. (#12130) This is allowed per MSC2675, although the original implementation did not allow for it and would return an empty chunk / not bundle aggregations. The main thing to improve is that the various caches get cleared properly when an event is redacted, and that edits must not leak if the original event is redacted (as that would presumably leak something similar to the original event content). --- changelog.d/12130.bugfix | 1 + changelog.d/12189.bugfix | 1 + changelog.d/12189.misc | 1 - synapse/rest/client/relations.py | 82 +++++++++++++---------------- synapse/storage/databases/main/cache.py | 4 ++ synapse/storage/databases/main/events.py | 11 ++-- synapse/storage/databases/main/relations.py | 60 +++++++++++---------- tests/rest/client/test_relations.py | 45 ++++++++++++++-- 8 files changed, 122 insertions(+), 83 deletions(-) create mode 100644 changelog.d/12130.bugfix create mode 100644 changelog.d/12189.bugfix delete mode 100644 changelog.d/12189.misc (limited to 'synapse/storage') diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12130.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12189.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc deleted file mode 100644 index 015e808e63..0000000000 --- a/changelog.d/12189.misc +++ /dev/null @@ -1 +0,0 @@ -Support skipping some arguments when generating cache keys. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 07fa1cdd4c..d9a6be43f7 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,7 +27,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -82,28 +82,25 @@ class RelationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - pagination_chunk = await self.store.get_relations_for_event( - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) + + pagination_chunk = await self.store.get_relations_for_event( + event_id=parent_id, + event=event, + room_id=room_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + direction=direction, + from_token=from_token, + to_token=to_token, + ) events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] @@ -193,27 +190,23 @@ class RelationAggregationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) + + pagination_chunk = await self.store.get_aggregation_groups_for_event( + event_id=parent_id, + room_id=room_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) return 200, await pagination_chunk.to_dict(self.store) @@ -295,6 +288,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index abd54c7dc7..d6a2df1afe 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -191,6 +191,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if redacts: self._invalidate_get_event_cache(redacts) + # Caches which might leak edits must be invalidated for the event being + # redacted. + self.get_relations_for_event.invalidate((redacts,)) + self.get_applicable_edit.invalidate((redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1dc83aa5e3..1a322882bf 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1619,9 +1619,12 @@ class PersistEventsStore: txn.call_after(prefill) - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: + # Invalidate the caches for the redacted event, note that these caches + # are also cleared as part of event replication in _invalidate_caches_for_event. txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) + txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) self.db_pool.simple_upsert_txn( txn, @@ -1812,9 +1815,7 @@ class PersistEventsStore: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) if rel_type == RelationTypes.THREAD: - txn.call_after( - self.store.get_thread_summary.invalidate, (parent_id, event.room_id) - ) + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 36aa1092f6..be1500092b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore): self._msc3440_enabled = hs.config.experimental.msc3440_enabled - @cached(tree=True) + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, event_id: str, + event: EventBase, room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore): List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ + # We don't use `event_id`, it's there so that we can cache based on + # it. The `event_id` must match the `event.event_id`. + assert event.event_id == event_id where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event_id, room_id] + where_args: List[Union[str, int]] = [event.event_id, room_id] + is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: where_clause.append("relation_type = ?") @@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, topological_ordering, stream_ordering + SELECT event_id, relation_type, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore): last_stream_id = None events = [] for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] + # Do not include edits for redacted events as they leak event + # content. + if not is_redacted or row[1] != RelationTypes.REPLACE: + events.append({"event_id": row[0]}) + last_topo_id = row[2] + last_stream_id = row[3] # If there are more events, generate the next pagination key. next_token = None @@ -776,7 +785,7 @@ class RelationsWorkerStore(SQLBaseStore): ) references = await self.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations.references = await references.to_dict(cast("DataStore", self)) @@ -797,41 +806,36 @@ class RelationsWorkerStore(SQLBaseStore): A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ - # The already processed event IDs. Tracked separately from the result - # since the result omits events which do not have bundled aggregations. - seen_event_ids = set() - - # State events and redacted events do not get bundled aggregations. - events = [ - event - for event in events - if not event.is_state() and not event.internal_metadata.is_redacted() - ] + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} # Fetch other relations per event. - for event in events: - # De-duplicate events by ID to handle the same event requested multiple - # times. The caches that _get_bundled_aggregation_for_event use should - # capture this, but best to reduce work. - if event.event_id in seen_event_ids: - continue - seen_event_ids.add(event.event_id) - + for event in events_by_id.values(): event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result - # Fetch any edits. - edits = await self._get_applicable_edits(seen_event_ids) + # Fetch any edits (but not for redacted events). + edits = await self._get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. if self._msc3440_enabled: - summaries = await self._get_thread_summaries(seen_event_ids) + summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. participated = await self._get_threads_participated( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index a40a5de399..f9ae6e663f 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1475,12 +1475,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(relations, {}) def test_redact_parent_annotation(self) -> None: - """Test that annotations of an event are redacted when the original event + """Test that annotations of an event are viewable when the original event is redacted. """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids, relations = self._make_relation_requests() @@ -1494,11 +1495,45 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # Redact the original event. self._redact(self.parent_id) - # The relations are not returned. + # The relations are returned. event_ids, relations = self._make_relation_requests() - self.assertEqual(event_ids, []) - self.assertEqual(relations, {}) + self.assertEquals(event_ids, [related_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + ) # There's nothing to aggregate. chunk = self._get_aggregations() - self.assertEqual(chunk, []) + self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_parent_thread(self) -> None: + """ + Test that thread replies are still available when the root event is redacted. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] + + # Redact one of the reactions. + self._redact(self.parent_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(len(event_ids), 1) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + related_event_id, + ) -- cgit 1.5.1 From ea27528b5d177dcfc5a4e38b463baeace916dc8e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 10:36:13 -0500 Subject: Support stable identifiers for MSC3440: Threading (#12151) The unstable identifiers are still supported if the experimental configuration flag is enabled. The unstable identifiers will be removed in a future release. --- changelog.d/12151.feature | 1 + synapse/api/constants.py | 4 +- synapse/api/filtering.py | 23 ++++----- synapse/events/utils.py | 9 +++- synapse/handlers/message.py | 5 +- synapse/rest/client/versions.py | 1 + synapse/server.py | 2 +- synapse/storage/databases/main/events.py | 5 +- synapse/storage/databases/main/relations.py | 77 ++++++++++++++++++----------- synapse/storage/databases/main/stream.py | 18 ++++--- tests/rest/client/test_relations.py | 7 +-- tests/rest/client/test_rooms.py | 18 +++---- tests/storage/test_stream.py | 20 ++++---- 13 files changed, 109 insertions(+), 81 deletions(-) create mode 100644 changelog.d/12151.feature (limited to 'synapse/storage') diff --git a/changelog.d/12151.feature b/changelog.d/12151.feature new file mode 100644 index 0000000000..18432b2da9 --- /dev/null +++ b/changelog.d/12151.feature @@ -0,0 +1 @@ +Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 36ace7c613..b0c08a074d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -178,7 +178,9 @@ class RelationTypes: ANNOTATION: Final = "m.annotation" REPLACE: Final = "m.replace" REFERENCE: Final = "m.reference" - THREAD: Final = "io.element.thread" + THREAD: Final = "m.thread" + # TODO Remove this in Synapse >= v1.57.0. + UNSTABLE_THREAD: Final = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cb532d7238..27e97d6f37 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -88,7 +88,9 @@ ROOM_EVENT_FILTER_SCHEMA = { "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, # MSC3440, filtering by event relations. + "related_by_senders": {"type": "array", "items": {"type": "string"}}, "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "related_by_rel_types": {"type": "array", "items": {"type": "string"}}, "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -318,19 +320,18 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - # Ideally these would be rejected at the endpoint if they were provided - # and not supported, but that would involve modifying the JSON schema - # based on the homeserver configuration. + self.related_by_senders = self.filter_json.get("related_by_senders", None) + self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + + # Fallback to the unstable prefix if the stable version is not given. if hs.config.experimental.msc3440_enabled: - self.relation_senders = self.filter_json.get( + self.related_by_senders = self.related_by_senders or self.filter_json.get( "io.element.relation_senders", None ) - self.relation_types = self.filter_json.get( - "io.element.relation_types", None + self.related_by_rel_types = ( + self.related_by_rel_types + or self.filter_json.get("io.element.relation_types", None) ) - else: - self.relation_senders = None - self.relation_types = None def filters_all_types(self) -> bool: return "*" in self.not_types @@ -461,7 +462,7 @@ class Filter: event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( - event_ids, self.relation_senders, self.relation_types + event_ids, self.related_by_senders, self.related_by_rel_types ) ) @@ -474,7 +475,7 @@ class Filter: async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: result = [event for event in events if self._check(event)] - if self.relation_senders or self.relation_types: + if self.related_by_senders or self.related_by_rel_types: return await self._check_event_relations(result) return result diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ee34cb46e4..b2a237c1e0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,6 +38,7 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.storage.databases.main.relations import BundledAggregations @@ -395,6 +396,9 @@ class EventClientSerializer: clients. """ + def __init__(self, hs: "HomeServer"): + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -515,11 +519,14 @@ class EventClientSerializer: thread.latest_event, serialized_latest_event, thread.latest_edit ) - serialized_aggregations[RelationTypes.THREAD] = { + thread_summary = { "latest_event": serialized_latest_event, "count": thread.count, "current_user_participated": thread.current_user_participated, } + serialized_aggregations[RelationTypes.THREAD] = thread_summary + if self._msc3440_enabled: + serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary # Include the bundled aggregations in the event. if serialized_aggregations: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0799ec9a84..f9544fe7fb 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1079,7 +1079,10 @@ class EventCreationHandler: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: + elif ( + relation_type == RelationTypes.THREAD + or relation_type == RelationTypes.UNSTABLE_THREAD + ): if await self.store.event_includes_relation(relates_to): raise SynapseError( 400, "Cannot start threads from an event with a relation" diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2e5d0e4e22..9a65aa4843 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,6 +101,7 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440": self.config.experimental.msc3440_enabled, + "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 1270abb5a3..7741ff29dc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -754,7 +754,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1a322882bf..1f60aef180 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1814,7 +1814,10 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - if rel_type == RelationTypes.THREAD: + if ( + rel_type == RelationTypes.THREAD + or rel_type == RelationTypes.UNSTABLE_THREAD + ): txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index be1500092b..c4869d64e6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -508,7 +508,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC """ else: @@ -523,16 +523,22 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY child.topological_ordering DESC, child.stream_ordering DESC """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) latest_event_ids = {} for parent_event_id, child_event_id in txn: # Only consider the latest threaded reply (by topological ordering). @@ -552,7 +558,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s GROUP BY parent.event_id """ @@ -561,9 +567,15 @@ class RelationsWorkerStore(SQLBaseStore): clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", latest_event_ids.keys() ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) return counts, latest_event_ids @@ -626,16 +638,24 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s AND child.sender = ? """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.extend((RelationTypes.THREAD, user_id)) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + args.append(user_id) + + txn.execute(sql % (clause, relations_clause), args) return {row[0] for row in txn.fetchall()} participated_threads = await self.db_pool.runInteraction( @@ -834,26 +854,23 @@ class RelationsWorkerStore(SQLBaseStore): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. - if self._msc3440_enabled: - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - summaries.keys(), user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) + summaries = await self._get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated(summaries.keys(), user_id) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a898f847e7..39e1efe373 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args.extend(event_filter.labels) # Filter on relation_senders / relation types from the joined tables. - if event_filter.relation_senders: + if event_filter.related_by_senders: clauses.append( "(%s)" % " OR ".join( - "related_event.sender = ?" for _ in event_filter.relation_senders + "related_event.sender = ?" for _ in event_filter.related_by_senders ) ) - args.extend(event_filter.relation_senders) + args.extend(event_filter.related_by_senders) - if event_filter.relation_types: + if event_filter.related_by_rel_types: clauses.append( "(%s)" - % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + % " OR ".join( + "relation_type = ?" for _ in event_filter.related_by_rel_types + ) ) - args.extend(event_filter.relation_types) + args.extend(event_filter.related_by_rel_types) return " AND ".join(clauses), args @@ -1203,7 +1205,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # If there is a filter on relation_senders and relation_types join to the # relations table. if event_filter and ( - event_filter.relation_senders or event_filter.relation_types + event_filter.related_by_senders or event_filter.related_by_rel_types ): # Filtering by relations could cause the same event to appear multiple # times (since there's no limit on the number of relations to an event). @@ -1211,7 +1213,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): join_clause += """ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) """ - if event_filter.relation_senders: + if event_filter.related_by_senders: join_clause += """ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f9ae6e663f..0cbe6c0cf7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -547,9 +547,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config( - {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -758,7 +756,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -1065,7 +1062,6 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_edit_thread(self) -> None: """Test that editing a thread works.""" @@ -1383,7 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): chunk = self._get_aggregations() self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 37866ee330..3a9617d6da 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_type(self) -> None: # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 6a1cf33054..eaa0d7d749 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders(self): # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_type(self): # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) @@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.second_tok, ) - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) -- cgit 1.5.1 From bc9dff1d9597251a15a15475cb8e8194b2d14910 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 11 Mar 2022 07:06:21 -0500 Subject: Remove unnecessary pass statements. (#12206) --- changelog.d/12206.misc | 1 + synapse/handlers/device.py | 2 -- synapse/handlers/presence.py | 2 -- synapse/http/matrixfederationclient.py | 2 -- synapse/http/server.py | 1 - synapse/rest/media/v1/_base.py | 1 - synapse/server.py | 1 - synapse/storage/databases/main/registration.py | 2 -- synapse/storage/schema/main/delta/30/as_users.py | 1 - synapse/util/caches/treecache.py | 2 -- tests/handlers/test_password_providers.py | 1 - 11 files changed, 1 insertion(+), 15 deletions(-) create mode 100644 changelog.d/12206.misc (limited to 'synapse/storage') diff --git a/changelog.d/12206.misc b/changelog.d/12206.misc new file mode 100644 index 0000000000..df59bb56cd --- /dev/null +++ b/changelog.d/12206.misc @@ -0,0 +1 @@ +Remove unnecessary `pass` statements. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d90cb259a6..d5ccaa0c37 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -371,7 +371,6 @@ class DeviceHandler(DeviceWorkerHandler): log_kv( {"reason": "User doesn't have device id.", "device_id": device_id} ) - pass else: raise @@ -414,7 +413,6 @@ class DeviceHandler(DeviceWorkerHandler): # no match set_tag("error", True) set_tag("reason", "User doesn't have that device id.") - pass else: raise diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9927a30e6e..34d9411bbf 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -267,7 +267,6 @@ class BasePresenceHandler(abc.ABC): is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ - pass async def update_external_syncs_clear(self, process_id: str) -> None: """Marks all users that had been marked as syncing by a given process @@ -277,7 +276,6 @@ class BasePresenceHandler(abc.ABC): This is a no-op when presence is handled by a different worker. """ - pass async def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: list diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 40bf1e06d6..6b98d865f5 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -120,7 +120,6 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): """Called when response has finished streaming and the parser should return the final result (or error). """ - pass @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -601,7 +600,6 @@ class MatrixFederationHttpClient: response.code, response_phrase, ) - pass else: logger.info( "{%s} [%s] Got response headers: %d %s", diff --git a/synapse/http/server.py b/synapse/http/server.py index 09b4125489..31ca841889 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -233,7 +233,6 @@ class HttpServer(Protocol): servlet_classname (str): The name of the handler to be used in prometheus and opentracing logs. """ - pass class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 9b40fd8a6c..c35d42fab8 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -298,7 +298,6 @@ class Responder: Returns: Resolves once the response has finished being written """ - pass def __enter__(self) -> None: pass diff --git a/synapse/server.py b/synapse/server.py index 7741ff29dc..2fcf18a7a6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -328,7 +328,6 @@ class HomeServer(metaclass=abc.ABCMeta): Does nothing in this base class; overridden in derived classes to start the appropriate listeners. """ - pass def setup_background_tasks(self) -> None: """ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index dc6665237a..a698d10cc5 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -48,8 +48,6 @@ class ExternalIDReuseException(Exception): """Exception if writing an external id for a user fails, because this external id is given to an other user.""" - pass - @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py index 22a7901e15..4b4b166e37 100644 --- a/synapse/storage/schema/main/delta/30/as_users.py +++ b/synapse/storage/schema/main/delta/30/as_users.py @@ -36,7 +36,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): config_files = config.appservice.app_service_config_files except AttributeError: logger.warning("Could not get app_service_config_files from config") - pass appservices = load_appservices(config.server.server_name, config_files) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 563845f867..e78305f787 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -22,8 +22,6 @@ class TreeCacheNode(dict): leaves. """ - pass - class TreeCache: """ diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 49d832de81..d401fda938 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -124,7 +124,6 @@ class PasswordCustomAuthProvider: ("m.login.password", ("password",)): self.check_auth, } ) - pass def check_auth(self, *args): return mock_password_provider.check_auth(*args) -- cgit 1.5.1 From ef3619e61d84493d98470eb2a69131d15eb1166b Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 11 Mar 2022 10:46:45 -0800 Subject: Add config settings for background update parameters (#11980) --- changelog.d/11980.misc | 1 + docs/sample_config.yaml | 32 ++++ synapse/config/_base.pyi | 2 + synapse/config/background_updates.py | 68 ++++++++ synapse/config/homeserver.py | 2 + synapse/storage/background_updates.py | 39 +++-- tests/config/test_background_update.py | 58 +++++++ tests/rest/admin/test_background_updates.py | 9 +- tests/storage/test_background_update.py | 253 ++++++++++++++++++++++++++-- 9 files changed, 430 insertions(+), 34 deletions(-) create mode 100644 changelog.d/11980.misc create mode 100644 synapse/config/background_updates.py create mode 100644 tests/config/test_background_update.py (limited to 'synapse/storage') diff --git a/changelog.d/11980.misc b/changelog.d/11980.misc new file mode 100644 index 0000000000..36e992e645 --- /dev/null +++ b/changelog.d/11980.misc @@ -0,0 +1 @@ +Add config settings for background update parameters. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d634fd8ff5..36c6c56e58 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2735,3 +2735,35 @@ redis: # Optional password if configured on the Redis instance # #password: + + +## Background Updates ## + +# Background updates are database updates that are run in the background in batches. +# The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to +# sleep can all be configured. This is helpful to speed up or slow down the updates. +# +background_updates: + # How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set + # a time to change the default. + # + #background_update_duration_ms: 500 + + # Whether to sleep between updates. Defaults to True. Uncomment to change the default. + # + #sleep_enabled: false + + # If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment + # and set a duration to change the default. + # + #sleep_duration_ms: 300 + + # Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and + # set a size to change the default. + # + #min_batch_size: 10 + + # The batch size to use for the first iteration of a new background update. The default is 100. + # Uncomment and set a size to change the default. + # + #default_batch_size: 50 diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 1eb5f5a68c..363d8b4554 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -19,6 +19,7 @@ from synapse.config import ( api, appservice, auth, + background_updates, cache, captcha, cas, @@ -113,6 +114,7 @@ class RootConfig: caches: cache.CacheConfig federation: federation.FederationConfig retention: retention.RetentionConfig + background_updates: background_updates.BackgroundUpdateConfig config_classes: List[Type["Config"]] = ... def __init__(self) -> None: ... diff --git a/synapse/config/background_updates.py b/synapse/config/background_updates.py new file mode 100644 index 0000000000..f6cdeacc4b --- /dev/null +++ b/synapse/config/background_updates.py @@ -0,0 +1,68 @@ +# Copyright 2022 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class BackgroundUpdateConfig(Config): + section = "background_updates" + + def generate_config_section(self, **kwargs) -> str: + return """\ + ## Background Updates ## + + # Background updates are database updates that are run in the background in batches. + # The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to + # sleep can all be configured. This is helpful to speed up or slow down the updates. + # + background_updates: + # How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set + # a time to change the default. + # + #background_update_duration_ms: 500 + + # Whether to sleep between updates. Defaults to True. Uncomment to change the default. + # + #sleep_enabled: false + + # If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment + # and set a duration to change the default. + # + #sleep_duration_ms: 300 + + # Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and + # set a size to change the default. + # + #min_batch_size: 10 + + # The batch size to use for the first iteration of a new background update. The default is 100. + # Uncomment and set a size to change the default. + # + #default_batch_size: 50 + """ + + def read_config(self, config, **kwargs) -> None: + bg_update_config = config.get("background_updates") or {} + + self.update_duration_ms = bg_update_config.get( + "background_update_duration_ms", 100 + ) + + self.sleep_enabled = bg_update_config.get("sleep_enabled", True) + + self.sleep_duration_ms = bg_update_config.get("sleep_duration_ms", 1000) + + self.min_batch_size = bg_update_config.get("min_batch_size", 1) + + self.default_batch_size = bg_update_config.get("default_batch_size", 100) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 001605c265..a4ec706908 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -16,6 +16,7 @@ from .account_validity import AccountValidityConfig from .api import ApiConfig from .appservice import AppServiceConfig from .auth import AuthConfig +from .background_updates import BackgroundUpdateConfig from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig @@ -99,4 +100,5 @@ class HomeServerConfig(RootConfig): WorkerConfig, RedisConfig, ExperimentalConfig, + BackgroundUpdateConfig, ] diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 4acc2c997d..08c6eabc6d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -60,18 +60,19 @@ class _BackgroundUpdateHandler: class _BackgroundUpdateContextManager: - BACKGROUND_UPDATE_INTERVAL_MS = 1000 - BACKGROUND_UPDATE_DURATION_MS = 100 - - def __init__(self, sleep: bool, clock: Clock): + def __init__( + self, sleep: bool, clock: Clock, sleep_duration_ms: int, update_duration: int + ): self._sleep = sleep self._clock = clock + self._sleep_duration_ms = sleep_duration_ms + self._update_duration_ms = update_duration async def __aenter__(self) -> int: if self._sleep: - await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000) + await self._clock.sleep(self._sleep_duration_ms / 1000) - return self.BACKGROUND_UPDATE_DURATION_MS + return self._update_duration_ms async def __aexit__(self, *exc) -> None: pass @@ -133,9 +134,6 @@ class BackgroundUpdater: process and autotuning the batch size. """ - MINIMUM_BACKGROUND_BATCH_SIZE = 1 - DEFAULT_BACKGROUND_BATCH_SIZE = 100 - def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database @@ -160,6 +158,14 @@ class BackgroundUpdater: # enable/disable background updates via the admin API. self.enabled = True + self.minimum_background_batch_size = hs.config.background_updates.min_batch_size + self.default_background_batch_size = ( + hs.config.background_updates.default_batch_size + ) + self.update_duration_ms = hs.config.background_updates.update_duration_ms + self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms + self.sleep_enabled = hs.config.background_updates.sleep_enabled + def register_update_controller_callbacks( self, on_update: ON_UPDATE_CALLBACK, @@ -216,7 +222,9 @@ class BackgroundUpdater: if self._on_update_callback is not None: return self._on_update_callback(update_name, database_name, oneshot) - return _BackgroundUpdateContextManager(sleep, self._clock) + return _BackgroundUpdateContextManager( + sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms + ) async def _default_batch_size(self, update_name: str, database_name: str) -> int: """The batch size to use for the first iteration of a new background @@ -225,7 +233,7 @@ class BackgroundUpdater: if self._default_batch_size_callback is not None: return await self._default_batch_size_callback(update_name, database_name) - return self.DEFAULT_BACKGROUND_BATCH_SIZE + return self.default_background_batch_size async def _min_batch_size(self, update_name: str, database_name: str) -> int: """A lower bound on the batch size of a new background update. @@ -235,7 +243,7 @@ class BackgroundUpdater: if self._min_batch_size_callback is not None: return await self._min_batch_size_callback(update_name, database_name) - return self.MINIMUM_BACKGROUND_BATCH_SIZE + return self.minimum_background_batch_size def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: """Returns the current background update, if any.""" @@ -254,9 +262,12 @@ class BackgroundUpdater: if self.enabled: # if we start a new background update, not all updates are done. self._all_done = False - run_as_background_process("background_updates", self.run_background_updates) + sleep = self.sleep_enabled + run_as_background_process( + "background_updates", self.run_background_updates, sleep + ) - async def run_background_updates(self, sleep: bool = True) -> None: + async def run_background_updates(self, sleep: bool) -> None: if self._running or not self.enabled: return diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py new file mode 100644 index 0000000000..0c32c1ca29 --- /dev/null +++ b/tests/config/test_background_update.py @@ -0,0 +1,58 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import yaml + +from synapse.storage.background_updates import BackgroundUpdater + +from tests.unittest import HomeserverTestCase, override_config + + +class BackgroundUpdateConfigTestCase(HomeserverTestCase): + # Tests that the default values in the config are correctly loaded. Note that the default + # values are loaded when the corresponding config options are commented out, which is why there isn't + # a config specified here. + def test_default_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 1) + self.assertEqual(background_updater.default_background_batch_size, 100) + self.assertEqual(background_updater.sleep_enabled, True) + self.assertEqual(background_updater.sleep_duration_ms, 1000) + self.assertEqual(background_updater.update_duration_ms, 100) + + # Tests that non-default values for the config options are properly picked up and passed on. + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 1000 + sleep_enabled: false + sleep_duration_ms: 600 + min_batch_size: 5 + default_batch_size: 50 + """ + ) + ) + def test_custom_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 5) + self.assertEqual(background_updater.default_background_batch_size, 50) + self.assertEqual(background_updater.sleep_enabled, False) + self.assertEqual(background_updater.sleep_duration_ms, 600) + self.assertEqual(background_updater.update_duration_ms, 1000) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index becec84524..6cf56b1e35 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") + self.updater = BackgroundUpdater(hs, self.store.db_pool) @parameterized.expand( [ @@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): """Test the status API works with a background update.""" # Create a new background update - self._register_bg_update() self.store.db_pool.updates.start_doing_background_updates() + self.reactor.pump([1.0, 1.0, 1.0]) channel = self.make_request( @@ -158,7 +159,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -213,7 +214,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -242,7 +243,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fdf54ea31..5cf18b690e 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -14,12 +14,15 @@ from unittest.mock import Mock +import yaml + from twisted.internet.defer import Deferred, ensureDeferred from synapse.storage.background_updates import BackgroundUpdater from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config class BackgroundUpdateTestCase(unittest.HomeserverTestCase): @@ -34,6 +37,19 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.updates.register_background_update_handler( "test_update", self.update_handler ) + self.store = self.hs.get_datastores().main + + async def update(self, progress, count): + duration_ms = 10 + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count def test_do_background_update(self): # the time we claim it takes to update one item when running the update @@ -42,27 +58,14 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastores().main self.get_success( - store.db_pool.simple_insert( + self.store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) ) - # first step: make a bit of progress - async def update(progress, count): - await self.clock.sleep((count * duration_ms) / 1000) - progress = {"my_key": progress["my_key"] + 1} - await store.db_pool.runInteraction( - "update_progress", - self.updates._background_update_progress_txn, - "test_update", - progress, - ) - return count - - self.update_handler.side_effect = update + self.update_handler.side_effect = self.update self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), @@ -72,7 +75,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.default_background_batch_size ) # second step: complete the update @@ -99,6 +102,224 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(result) self.assertFalse(self.update_handler.called) + @override_config( + yaml.safe_load( + """ + background_updates: + default_batch_size: 20 + """ + ) + ) + def test_background_update_default_batch_set_by_config(self): + """ + Test that the background update is run with the default_batch_size set by the config + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.01, + ) + self.assertFalse(res) + + # on the first call, we should get run with the default background update size specified in the config + self.update_handler.assert_called_once_with({"my_key": 1}, 20) + + def test_background_update_default_sleep_behavior(self): + """ + Test default background update behavior, which is to sleep + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor less than the default sleep duration (1000ms) + self.reactor.pump([0.5]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past default sleep duration + self.reactor.pump([1]) + # check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_duration_ms: 500 + """ + ) + ) + def test_background_update_sleep_set_in_config(self): + """ + Test that changing the sleep time in the config changes how long it sleeps + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor less than the configured sleep duration (500ms) + self.reactor.pump([0.45]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past config sleep duration but less than default duration + self.reactor.pump([0.75]) + # check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_enabled: false + """ + ) + ) + def test_disabling_background_update_sleep(self): + """ + Test that disabling sleep in the config results in bg update not sleeping + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates(), + + # 2: advance the reactor very little + self.reactor.pump([0.025]) + # check that an update has run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 500 + """ + ) + ) + def test_background_update_duration_set_in_config(self): + """ + Test that the desired duration set in the config is used in determining batch size + """ + # Duration of one background update item + duration_ms = 10 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.02, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with 500ms as the + # desired duration + async def update(progress, count): + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, + 500 / duration_ms, + places=0, + ) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update + self.get_success(self.updates.do_next_background_update(False)) + + @override_config( + yaml.safe_load( + """ + background_updates: + min_batch_size: 5 + """ + ) + ) + def test_background_update_min_batch_set_in_config(self): + """ + Test that the minimum batch size set in the config is used + """ + # a very long-running individual update + duration_ms = 50 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + # Run the update with the long-running update item + async def update(progress, count): + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count + + self.update_handler.side_effect = update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=1, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with minimum batch size + # as the first items took a very long time + async def update(progress, count): + self.assertEqual(progress, {"my_key": 2}) + self.assertEqual(count, 5) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update + self.get_success(self.updates.do_next_background_update(False)) + class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): -- cgit 1.5.1 From 8e5706d14448c0fe8d1c55eaca38a672c701d7a9 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 17:52:58 +0000 Subject: Fix broken background updates when using sqlite with `enable_search` off (#12215) Signed-off-by: Sean Quah --- changelog.d/12215.bugfix | 1 + synapse/storage/databases/main/search.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12215.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/12215.bugfix b/changelog.d/12215.bugfix new file mode 100644 index 0000000000..593b12556b --- /dev/null +++ b/changelog.d/12215.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.54.0 that broke background updates on sqlite homeservers while search was disabled. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e23b119072..c5e9010c83 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -125,9 +125,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ): super().__init__(database, db_conn, hs) - if not hs.config.server.enable_search: - return - self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) @@ -243,9 +240,13 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return len(event_search_rows) - result = await self.db_pool.runInteraction( - self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn - ) + if self.hs.config.server.enable_search: + result = await self.db_pool.runInteraction( + self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn + ) + else: + # Don't index anything if search is not enabled. + result = 0 if not result: await self.db_pool.updates._end_background_update( -- cgit 1.5.1 From dda9b7fc4d2e6ca84a1a994a7ff1943b590e71df Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 15 Mar 2022 14:06:05 -0400 Subject: Use the ignored_users table to test event visibility & sync. (#12225) Instead of fetching the raw account data and re-parsing it. The ignored_users table is a denormalised version of the account data for quick searching. --- changelog.d/12225.misc | 1 + synapse/handlers/sync.py | 30 ++----------------- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/databases/main/account_data.py | 41 ++++++++++++++++++++++++-- synapse/visibility.py | 18 ++--------- tests/storage/test_account_data.py | 17 +++++++++++ 6 files changed, 62 insertions(+), 47 deletions(-) create mode 100644 changelog.d/12225.misc (limited to 'synapse/storage') diff --git a/changelog.d/12225.misc b/changelog.d/12225.misc new file mode 100644 index 0000000000..23105c727c --- /dev/null +++ b/changelog.d/12225.misc @@ -0,0 +1 @@ +Use the `ignored_users` table in additional places instead of re-parsing the account data. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0aa3052fd6..c9d6a18bd7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -1601,7 +1601,7 @@ class SyncHandler: return set(), set(), set(), set() # 3. Work out which rooms need reporting in the sync response. - ignored_users = await self._get_ignored_users(user_id) + ignored_users = await self.store.ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( sync_result_builder, ignored_users @@ -1627,7 +1627,6 @@ class SyncHandler: logger.debug("Generating room entry for %s", room_entry.room_id) await self._generate_room_entry( sync_result_builder, - ignored_users, room_entry, ephemeral=ephemeral_by_room.get(room_entry.room_id, []), tags=tags_by_room.get(room_entry.room_id), @@ -1657,29 +1656,6 @@ class SyncHandler: newly_left_users, ) - async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]: - """Retrieve the users ignored by the given user from their global account_data. - - Returns an empty set if - - there is no global account_data entry for ignored_users - - there is such an entry, but it's not a JSON object. - """ - # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead? - ignored_account_data = ( - await self.store.get_global_account_data_by_type_for_user( - user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST - ) - ) - - # If there is ignored users account data and it matches the proper type, - # then use it. - ignored_users: FrozenSet[str] = frozenset() - if ignored_account_data: - ignored_users_data = ignored_account_data.get("ignored_users", {}) - if isinstance(ignored_users_data, dict): - ignored_users = frozenset(ignored_users_data.keys()) - return ignored_users - async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" ) -> bool: @@ -2022,7 +1998,6 @@ class SyncHandler: async def _generate_room_entry( self, sync_result_builder: "SyncResultBuilder", - ignored_users: FrozenSet[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], tags: Optional[Dict[str, Dict[str, Any]]], @@ -2051,7 +2026,6 @@ class SyncHandler: Args: sync_result_builder - ignored_users: Set of users ignored by user. room_builder ephemeral: List of new ephemeral events for room tags: List of *all* tags for room, or None if there has been diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8140afcb6b..030898e4d0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -213,7 +213,7 @@ class BulkPushRuleEvaluator: if not event.is_state(): ignorers = await self.store.ignored_by(event.sender) else: - ignorers = set() + ignorers = frozenset() for uid, rules in rules_by_user.items(): if event.sender == uid: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 52146aacc8..9af9f4f18e 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, + cast, +) from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) @cached(max_entries=5000, iterable=True) - async def ignored_by(self, user_id: str) -> Set[str]: + async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ Get users which ignore the given user. @@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) Return: The user IDs which ignore the given user. """ - return set( + return frozenset( await self.db_pool.simple_select_onecol( table="ignored_users", keyvalues={"ignored_user_id": user_id}, @@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) ) + @cached(max_entries=5000, iterable=True) + async def ignored_users(self, user_id: str) -> FrozenSet[str]: + """ + Get users which the given user ignores. + + Params: + user_id: The user ID which is making the request. + + Return: + The user IDs which are ignored by the given user. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + desc="ignored_users", + ) + ) + def process_replication_rows( self, stream_name: str, @@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: currently_ignored_users = set() + # If the data has not changed, nothing to do. + if previously_ignored_users == currently_ignored_users: + return + # Delete entries which are no longer ignored. self.db_pool.simple_delete_many_txn( txn, @@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) async def purge_account_data_for_user(self, user_id: str) -> None: """ diff --git a/synapse/visibility.py b/synapse/visibility.py index 281cbe4d88..49519eb8f5 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -14,12 +14,7 @@ import logging from typing import Dict, FrozenSet, List, Optional -from synapse.api.constants import ( - AccountDataTypes, - EventTypes, - HistoryVisibility, - Membership, -) +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase from synapse.events.utils import prune_event from synapse.storage import Storage @@ -87,15 +82,8 @@ async def filter_events_for_client( state_filter=StateFilter.from_types(types), ) - ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - user_id, AccountDataTypes.IGNORED_USER_LIST - ) - - ignore_list: FrozenSet[str] = frozenset() - if ignore_dict_content: - ignored_users_dict = ignore_dict_content.get("ignored_users", {}) - if isinstance(ignored_users_dict, dict): - ignore_list = frozenset(ignored_users_dict.keys()) + # Get the users who are ignored by the requesting user. + ignore_list = await storage.main.ignored_users(user_id) erased_senders = await storage.main.are_users_erased(e.sender for e in events) diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 272cd35402..72bf5b3d31 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): expected_ignorer_user_ids, ) + def assert_ignored( + self, ignorer_user_id: str, expected_ignored_user_ids: Set[str] + ) -> None: + self.assertEqual( + self.get_success(self.store.ignored_users(ignorer_user_id)), + expected_ignored_user_ids, + ) + def test_ignoring_users(self): """Basic adding/removing of users from the ignore list.""" self._update_ignore_list("@other:test", "@another:remote") + self.assert_ignored(self.user, {"@other:test", "@another:remote"}) # Check a user which no one ignores. self.assert_ignorers("@user:test", set()) @@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # Add one user, remove one user, and leave one user. self._update_ignore_list("@foo:test", "@another:remote") + self.assert_ignored(self.user, {"@foo:test", "@another:remote"}) # Check the removed user. self.assert_ignorers("@other:test", set()) @@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): """Ensure that caching works properly between different users.""" # The first user ignores a user. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # The second user ignores them. self._update_ignore_list("@other:test", ignorer_user_id="@second:test") + self.assert_ignored("@second:test", {"@other:test"}) self.assert_ignorers("@other:test", {self.user, "@second:test"}) # The first user un-ignores them. self._update_ignore_list() + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", {"@second:test"}) def test_invalid_data(self): """Invalid data ends up clearing out the ignored users list.""" # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # No ignored_users key. @@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # Invalid data. @@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) -- cgit 1.5.1 From c486fa5fd9082643e40a55ffa59d902aa6db4c2b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 10:37:04 -0400 Subject: Add some missing type hints to cache datastore. (#12216) --- changelog.d/12216.misc | 1 + synapse/storage/databases/main/cache.py | 57 +++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12216.misc (limited to 'synapse/storage') diff --git a/changelog.d/12216.misc b/changelog.d/12216.misc new file mode 100644 index 0000000000..dc398ac1e0 --- /dev/null +++ b/changelog.d/12216.misc @@ -0,0 +1 @@ +Add missing type hints for cache storage. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index d6a2df1afe..2d7511d613 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, + EventsStreamRow, ) from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -31,6 +32,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import _CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_caches_txn(txn): + def get_all_updated_caches_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. @@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "get_all_updated_caches", get_all_updated_caches_txn ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) @@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def _process_event_stream_row(self, token, row): + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data if row.type == EventsStreamEventRow.TypeId: + assert isinstance(data, EventsStreamEventRow) self._invalidate_caches_for_event( token, data.event_id, @@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) + assert isinstance(data, EventsStreamCurrentStateRow) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( @@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_caches_for_event( self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): + stream_ordering: int, + event_id: str, + room_id: str, + etype: str, + state_key: Optional[str], + redacts: Optional[str], + relates_to: Optional[str], + backfilled: bool, + ) -> None: self._invalidate_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) @@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_thread_summary.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,)) - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + async def invalidate_cache_and_stream( + self, cache_name: str, keys: Tuple[Any, ...] + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore): keys, ) - def _invalidate_cache_and_stream(self, txn, cache_func, keys): + def _invalidate_cache_and_stream( + self, + txn: LoggingTransaction, + cache_func: _CachedFunction, + keys: Tuple[Any, ...], + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - def _invalidate_all_cache_and_stream(self, txn, cache_func): + def _invalidate_all_cache_and_stream( + self, txn: LoggingTransaction, cache_func: _CachedFunction + ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. """ @@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def _send_invalidation_to_replication( - self, txn, cache_name: str, keys: Optional[Iterable[Any]] - ): + self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] + ) -> None: """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. @@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self.clock.time_msec(), + "invalidation_ts": self._clock.time_msec(), }, ) -- cgit 1.5.1 From 61210567405b1ac7efaa23d5513cc0b443da0a3a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 16 Mar 2022 15:07:41 +0000 Subject: Handle cancellation in `DatabasePool.runInteraction()` (#12199) To handle cancellation, we ensure that `after_callback`s and `exception_callback`s are always run, since the transaction will complete on another thread regardless of cancellation. We also wait until everything is done before releasing the `CancelledError`, so that logging contexts won't get used after they have been finished. Signed-off-by: Sean Quah --- changelog.d/12199.misc | 1 + synapse/storage/database.py | 61 +++++++++++++++++++++++++----------------- tests/storage/test_database.py | 58 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 24 deletions(-) create mode 100644 changelog.d/12199.misc (limited to 'synapse/storage') diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc new file mode 100644 index 0000000000..16dec1d26d --- /dev/null +++ b/changelog.d/12199.misc @@ -0,0 +1 @@ +Handle cancellation in `DatabasePool.runInteraction()`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 99802228c9..9749f0c06e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -41,6 +41,7 @@ from prometheus_client import Histogram from typing_extensions import Literal from twisted.enterprise import adbapi +from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -732,34 +734,45 @@ class DatabasePool: Returns: The result of func """ - after_callbacks: List[_CallbackListEntry] = [] - exception_callbacks: List[_CallbackListEntry] = [] - if not current_context(): - logger.warning("Starting db txn '%s' from sentinel context", desc) + async def _runInteraction() -> R: + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] - try: - with opentracing.start_active_span(f"db.{desc}"): - result = await self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - db_autocommit=db_autocommit, - isolation_level=isolation_level, - **kwargs, - ) + if not current_context(): + logger.warning("Starting db txn '%s' from sentinel context", desc) - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise + try: + with opentracing.start_active_span(f"db.{desc}"): + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + db_autocommit=db_autocommit, + isolation_level=isolation_level, + **kwargs, + ) - return cast(R, result) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + + return cast(R, result) + except Exception: + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + # To handle cancellation, we ensure that `after_callback`s and + # `exception_callback`s are always run, since the transaction will complete + # on another thread regardless of cancellation. + # + # We also wait until everything above is done before releasing the + # `CancelledError`, so that logging contexts won't get used after they have been + # finished. + return await delay_cancellation(defer.ensureDeferred(_runInteraction())) async def runWithConnection( self, diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index ae13bed086..a40fc20ef9 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -15,6 +15,8 @@ from typing import Callable, Tuple from unittest.mock import Mock, call +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer @@ -124,3 +126,59 @@ class CallbacksTestCase(unittest.HomeserverTestCase): ) self.assertEqual(after_callback.call_count, 2) # no additional calls exception_callback.assert_not_called() + + +class CancellationTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + # Simulate a retryable failure on every attempt. + raise self.db_pool.engine.module.OperationalError() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls -- cgit 1.5.1 From 872dbb0181714e201be082c4e8bd9b727c73f177 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 18 Mar 2022 13:51:41 +0000 Subject: Correct `check_username_for_spam` annotations and docs (#12246) * Formally type the UserProfile in user searches * export UserProfile in synapse.module_api * Update docs Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12246.doc | 1 + docs/modules/spam_checker_callbacks.md | 10 ++++++---- synapse/events/spamcheck.py | 7 +++---- synapse/handlers/user_directory.py | 4 ++-- synapse/module_api/__init__.py | 2 ++ synapse/rest/client/user_directory.py | 4 ++-- synapse/storage/databases/main/user_directory.py | 23 +++++++++++++++++++---- synapse/types.py | 11 +++++++++++ 8 files changed, 46 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12246.doc (limited to 'synapse/storage') diff --git a/changelog.d/12246.doc b/changelog.d/12246.doc new file mode 100644 index 0000000000..e7fcc1b99c --- /dev/null +++ b/changelog.d/12246.doc @@ -0,0 +1 @@ +Correct `check_username_for_spam` annotations and docs. \ No newline at end of file diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2b672b78f9..472d957180 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -172,7 +172,7 @@ any of the subsequent implementations of this callback. _First introduced in Synapse v1.37.0_ ```python -async def check_username_for_spam(user_profile: Dict[str, str]) -> bool +async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool ``` Called when computing search results in the user directory. The module must return a @@ -182,9 +182,11 @@ search results; otherwise return `False`. The profile is represented as a dictionary with the following keys: -* `user_id`: The Matrix ID for this user. -* `display_name`: The user's display name. -* `avatar_url`: The `mxc://` URL to the user's avatar. +* `user_id: str`. The Matrix ID for this user. +* `display_name: Optional[str]`. The user's display name, or `None` if this user + has not set a display name. +* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None` + if this user has not set an avatar. The module is given a copy of the original dictionary, so modifying it from within the module cannot modify a user's profile when included in user directory search results. diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 60904a55f5..cd80fcf9d1 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -21,7 +21,6 @@ from typing import ( Awaitable, Callable, Collection, - Dict, List, Optional, Tuple, @@ -31,7 +30,7 @@ from typing import ( from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserProfile from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: @@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] -CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] +CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]] LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ Optional[dict], @@ -383,7 +382,7 @@ class SpamChecker: return True - async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: UserProfile) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d27ed2be6a..048fd4bb82 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -19,8 +19,8 @@ import synapse.metrics from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.roommember import ProfileInfo -from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler): async def search_users( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d735c1d461..aa8256b36f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,6 +111,7 @@ from synapse.types import ( StateMap, UserID, UserInfo, + UserProfile, create_requester, ) from synapse.util import Clock @@ -150,6 +151,7 @@ __all__ = [ "EventBase", "StateMap", "ProfileInfo", + "UserProfile", ] logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index a47d9bd01d..116c982ce6 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict +from synapse.types import JsonMapping from ._base import client_patterns @@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]: """Searches for users in directory Returns: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index e7fddd2426..55cc9178f0 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,8 @@ from typing import ( cast, ) +from typing_extensions import TypedDict + from synapse.api.errors import StoreError if TYPE_CHECKING: @@ -40,7 +42,12 @@ from synapse.storage.database import ( from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id +from synapse.types import ( + JsonDict, + UserProfile, + get_domain_from_id, + get_localpart_from_id, +) from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) +class SearchResult(TypedDict): + limited: bool + results: List[UserProfile] + + class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): async def search_user_dir( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: @@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = await self.db_pool.execute( - "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + results = cast( + List[UserProfile], + await self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + ), ) limited = len(results) > limit diff --git a/synapse/types.py b/synapse/types.py index 53be3583a0..5ce2a5b0a5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -34,6 +34,7 @@ from typing import ( import attr from frozendict import frozendict from signedjson.key import decode_verify_key_bytes +from typing_extensions import TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. # A JSON-serialisable dict. JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] # A JSON-serialisable object. JsonSerializable = object @@ -791,3 +796,9 @@ class UserInfo: is_deactivated: bool is_guest: bool is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] -- cgit 1.5.1 From c46065fa3d6ad000f5da6e196c769371e0e76ec5 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 18 Mar 2022 16:24:18 +0100 Subject: Add some type hints to datastore (#12248) * inherit `MonthlyActiveUsersStore` from `RegistrationWorkerStore` Co-authored-by: Patrick Cloke --- changelog.d/12248.misc | 1 + mypy.ini | 6 - synapse/storage/databases/main/group_server.py | 156 +++++++++++++-------- .../storage/databases/main/monthly_active_users.py | 38 ++--- 4 files changed, 117 insertions(+), 84 deletions(-) create mode 100644 changelog.d/12248.misc (limited to 'synapse/storage') diff --git a/changelog.d/12248.misc b/changelog.d/12248.misc new file mode 100644 index 0000000000..2b1290d1e1 --- /dev/null +++ b/changelog.d/12248.misc @@ -0,0 +1 @@ +Add missing type hints for storage. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 51f47ff5be..d8b3b3f9e5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -42,9 +42,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/event_federation.py - |synapse/storage/databases/main/group_server.py - |synapse/storage/databases/main/metrics.py - |synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py @@ -87,9 +84,6 @@ exclude = (?x) |tests/state/test_v2.py |tests/storage/test_background_update.py |tests/storage/test_base.py - |tests/storage/test_client_ips.py - |tests/storage/test_database.py - |tests/storage/test_event_federation.py |tests/storage/test_id_generators.py |tests/storage/test_roommember.py |tests/test_metrics.py diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 3f6086050b..0aef121d83 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.types import JsonDict from synapse.util import json_encoder @@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore): ) -> List[Dict[str, Any]]: # TODO: Pagination - keyvalues = {"group_id": group_id} + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore): # TODO: Pagination - def _get_rooms_in_group_txn(txn): + def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]: sql = """ SELECT room_id, is_public FROM group_rooms WHERE group_id = ? @@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore): * "order": int, the sort order of rooms in this category """ - def _get_rooms_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_rooms_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - async def get_group_categories(self, group_id): + async def get_group_categories(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, @@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_category(self, group_id, category_id): + async def get_group_category(self, group_id: str, category_id: str) -> JsonDict: category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore): return category - async def get_group_roles(self, group_id): + async def get_group_roles(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, @@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_role(self, group_id, role_id): + async def get_group_role(self, group_id: str, role_id: str) -> JsonDict: role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_local_groups_for_room", ) - async def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role( + self, group_id: str, include_private: bool = False + ) -> Tuple[List[JsonDict], JsonDict]: """Get the users and roles that should be included in a summary request Returns: ([users], [roles]) """ - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_users_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], JsonDict]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - async def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group( + self, group_id: str, user_id: str + ) -> JsonDict: """Get a dict describing the membership of a user in a group. Example if joined: @@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore): An empty dict if the user is not join/invite/etc """ - def _get_users_membership_in_group_txn(txn): + def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict: row = self.db_pool.simple_select_one_txn( txn, table="group_users", @@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_publicised_groups_for_user", ) - async def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals( + self, valid_until_ms: int + ) -> List[Dict[str, Any]]: """Get all attestations that need to be renewed until givent time""" - def _get_attestations_need_renewals_txn(txn): + def _get_attestations_need_renewals_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: sql = """ SELECT group_id, user_id FROM group_attestations_renewals WHERE valid_until_ms <= ? @@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - async def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation( + self, group_id: str, user_id: str + ) -> Optional[JsonDict]: """Get the attestation that proves the remote agrees that the user is in the group. """ @@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - async def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): + async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: + def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, type, membership, u.content FROM local_group_updates AS u @@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - async def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( + async def get_groups_changes_for_user( + self, user_id: str, from_token: int, to_token: int + ) -> List[JsonDict]: + has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined] user_id, from_token ) if not has_changed: return [] - def _get_groups_changes_for_user_txn(txn): + def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, membership, type, u.content FROM local_group_updates AS u @@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore): """ last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined] if not has_changed: return [], current_id, False - def _get_all_groups_changes_txn(txn): + def _get_all_groups_changes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, group_id, user_id, type, content FROM local_group_updates @@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ] + updates = cast( + List[Tuple[int, tuple]], + [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) + for stream_id, group_id, user_id, gtype, content_json in txn + ], + ) limited = False upto_token = current_id @@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore): def _add_room_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore): "category_id": category_id, "room_id": room_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_summary( - self, group_id: str, room_id: str, category_id: str + self, group_id: str, room_id: str, category_id: Optional[str] ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID @@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/update room category for group""" - insertion_values = {} - update_values = {"category_id": category_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"category_id": category_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/remove user role""" - insertion_values = {} - update_values = {"role_id": role_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"role_id": role_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) user's entry in summary. @@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore): def _add_user_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], - ): + ) -> None: """Add (or update) user's entry in summary. Args: @@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore): "role_id": role_id, "user_id": user_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_user_from_summary( - self, group_id: str, user_id: str, role_id: str + self, group_id: str, user_id: str, role_id: Optional[str] ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID @@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore): Optional if the user and group are on the same server """ - def _add_user_to_group_txn(txn): + def _add_user_to_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_insert_txn( txn, table="group_users", @@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore): await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) async def remove_user_from_group(self, group_id: str, user_id: str) -> None: - def _remove_user_from_group_txn(txn): + def _remove_user_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_users", @@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_group(self, group_id: str, room_id: str) -> None: - def _remove_room_from_group_txn(txn): + def _remove_room_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_rooms", @@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore): content = content or {} - def _register_user_group_membership_txn(txn, next_id): + def _register_user_group_membership_txn( + txn: LoggingTransaction, next_id: int + ) -> int: # TODO: Upsert? self.db_pool.simple_delete_txn( txn, @@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore): ), }, ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined] # TODO: Insert profile to ensure it comes down stream if its a join. @@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - async with self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined] res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, @@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore): return res async def create_group( - self, group_id, user_id, name, avatar_url, short_description, long_description + self, + group_id: str, + user_id: str, + name: str, + avatar_url: str, + short_description: str, + long_description: str, ) -> None: await self.db_pool.simple_insert( table="groups", @@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - async def update_group_profile(self, group_id, profile): + async def update_group_profile(self, group_id: str, profile: JsonDict) -> None: await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, @@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_attestation_renewal", ) - def get_group_stream_token(self): - return self._group_updates_id_gen.get_current_token() + def get_group_stream_token(self) -> int: + return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined] async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. @@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore): group_id: The group ID to delete. """ - def _delete_group_txn(txn): + def _delete_group_txn(txn: LoggingTransaction) -> None: tables = [ "groups", "group_users", diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e9a0cdc6be..216622964a 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, + LoggingTransaction, make_in_list_sql_clause, ) +from synapse.storage.databases.main.registration import RegistrationWorkerStore from synapse.util.caches.descriptors import cached from synapse.util.threepids import canonicalise_email @@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): Number of current monthly active users """ - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: # Exclude app service users sql = """ SELECT COUNT(*) @@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); """ txn.execute(sql) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_users", _count_users) @@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ - def _count_users_by_service(txn): + def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]: sql = """ SELECT COALESCE(appservice_id, 'native'), COUNT(*) FROM monthly_active_users @@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ txn.execute(sql) - result = txn.fetchall() + result = cast(List[Tuple[str, int]], txn.fetchall()) return dict(result) return await self.db_pool.runInteraction( @@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): ) @wrap_as_background_process("reap_monthly_active_users") - async def reap_monthly_active_users(self): + async def reap_monthly_active_users(self) -> None: """Cleans out monthly active user table to ensure that no stale entries exist. """ - def _reap_users(txn, reserved_users): + def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: reserved_users (tuple): reserved users to preserve @@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._invalidate_all_cache_and_stream( + self._invalidate_all_cache_and_stream( # type: ignore[attr-defined] txn, self.user_last_seen_monthly_active ) - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined] reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( @@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): ) -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore): def __init__( self, database: DatabasePool, @@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) - def _initialise_reserved_users(self, txn, threepids): + def _initialise_reserved_users( + self, txn: LoggingTransaction, threepids: List[dict] + ) -> None: """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve + txn: + threepids: List of threepid dicts to reserve """ # XXX what is this function trying to achieve? It upserts into @@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - def upsert_monthly_active_user_txn(self, txn, user_id): + def upsert_monthly_active_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> None: """Updates or inserts monthly active user member We consciously do not call is_support_txn from this method because it @@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): txn, self.user_last_seen_monthly_active, (user_id,) ) - async def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id: str) -> None: """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = await self.is_guest(user_id) + is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] if is_guest: return is_trial = await self.is_trial_user(user_id) -- cgit 1.5.1 From 80e0e1f35e6b1cdfa0267f9c40a6f212b7d774de Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:15:45 -0400 Subject: Only fetch thread participation for events with threads. (#12228) We fetch the thread summary in two phases: 1. The summary that is shared by all users (count of messages and latest event). 2. Whether the requesting user has participated in the thread. There's no use in attempting step 2 for events which did not return a summary from step 1. --- changelog.d/12228.bugfix | 1 + synapse/storage/databases/main/relations.py | 4 +- tests/rest/client/test_relations.py | 509 +++++++++++++++------------- tests/server.py | 20 +- 4 files changed, 289 insertions(+), 245 deletions(-) create mode 100644 changelog.d/12228.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix new file mode 100644 index 0000000000..4755777139 --- /dev/null +++ b/changelog.d/12228.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c4869d64e6..af2334a65e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -857,7 +857,9 @@ class RelationsWorkerStore(SQLBaseStore): summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. - participated = await self._get_threads_participated(summaries.keys(), user_id) + participated = await self._get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) for event_id, summary in summaries.items(): if summary: thread_count, latest_thread_event, edit = summary diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f3741b3001..329690f8f7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -155,6 +155,16 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.json_body) return channel.json_body["chunk"] + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: @@ -291,202 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_2 = channel.json_body["event_id"] - - self._send_relation(RelationTypes.THREAD, "m.room.test") - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) - - # Request search. - channel = self.make_request( - "POST", - "/search", - # Search term matches the parent message. - content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - chunk = [ - result["result"] - for result in channel.json_body["search_categories"]["room_events"][ - "results" - ] - ] - assert_bundle(self._find_event_in_chunk(chunk)) - - def test_aggregation_get_event_for_annotation(self) -> None: - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - - def test_aggregation_get_event_for_thread(self) -> None: - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEqual( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -796,7 +610,7 @@ class RelationsTestCase(BaseRelationsTestCase): threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, @@ -836,7 +650,7 @@ class RelationsTestCase(BaseRelationsTestCase): edit_event_id = channel.json_body["event_id"] # Edit the edit event. - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -912,16 +726,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - return event - - raise AssertionError(f"Event {self.parent_id} not found in chunk") - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -981,34 +785,6 @@ class RelationsTestCase(BaseRelationsTestCase): [annotation_event_id_good, thread_event_id], ) - def test_bundled_aggregations_with_filter(self) -> None: - """ - If "unsigned" is an omitted field (due to filtering), adding the bundled - aggregations should not break. - - Note that the spec allows for a server to return additional fields beyond - what is specified. - """ - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - - # Note that the sync filter does not include "unsigned" as a field. - filter = urllib.parse.quote_plus( - b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' - ) - channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Ensure the timeline is limited, find the parent event. - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - parent_event = self._find_event_in_chunk(room_timeline["events"]) - - # Ensure there's bundled aggregations on it. - self.assertIn("unsigned", parent_event) - self.assertIn("m.relations", parent_event["unsigned"]) - class RelationPaginationTestCase(BaseRelationsTestCase): def test_basic_paginate_relations(self) -> None: @@ -1255,7 +1031,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): idx += 1 # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") prev_token = "" found_event_ids: List[str] = [] @@ -1291,6 +1067,263 @@ class RelationPaginationTestCase(BaseRelationsTestCase): self.assertEqual(found_event_ids, expected_event_ids) +class BundledAggregationsTestCase(BaseRelationsTestCase): + """ + See RelationsTestCase.test_edit for a similar test for edits. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + """ + + def _test_bundled_aggregations( + self, + relation_type: str, + assertion_callable: Callable[[JsonDict], None], + expected_db_txn_for_event: int, + ) -> None: + """ + Makes requests to various endpoints which should include bundled aggregations + and then calls an assertion function on the bundled aggregations. + + Args: + relation_type: The field to search for in the `m.relations` field in unsigned. + assertion_callable: Called with the contents of unsigned["m.relations"][relation_type] + for relation-specific assertions. + expected_db_txn_for_event: The number of database transactions which + are expected for a call to /event/. + """ + + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + + # Ensure the fields are as expected. + self.assertCountEqual(relations_dict.keys(), (relation_type,)) + assertion_callable(relations_dict[relation_type]) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) + + # Request sync. + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_annotation(self) -> None: + """ + Test that annotations get correctly bundled. + """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_reference(self) -> None: + """ + Test that references get correctly bundled. + """ + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_thread(self) -> None: + """ + Test that threads get correctly bundled. + """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9) + + def test_aggregation_get_event_for_annotation(self) -> None: + """Test that annotations do not get bundled aggregations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self) -> None: + """Test that threads get bundled aggregations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEqual( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + + class RelationRedactionTestCase(BaseRelationsTestCase): """ Test the behaviour of relations when the parent or child event is redacted. diff --git a/tests/server.py b/tests/server.py index 82990c2eb9..6ce2a17bf4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -54,13 +54,18 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import ( + AccumulatingProtocol, + MemoryReactor, + MemoryReactorClock, +) from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -88,18 +93,19 @@ class TimedOutException(Exception): """ -@attr.s +@attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). """ - site = attr.ib(type=Union[Site, "FakeSite"]) - _reactor = attr.ib() - result = attr.ib(type=dict, default=attr.Factory(dict)) - _ip = attr.ib(type=str, default="127.0.0.1") + site: Union[Site, "FakeSite"] + _reactor: MemoryReactor + result: dict = attr.Factory(dict) + _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None + resource_usage: Optional[ContextResourceUsage] = None @property def json_body(self): @@ -168,6 +174,8 @@ class FakeChannel: def requestDone(self, _self): self.result["done"] = True + if isinstance(_self, SynapseRequest): + self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): # We give an address so that getClientIP returns a non null entry, -- cgit 1.5.1 From 8fe930c215f69913fbcd96d609ec6950644e4ec4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:49:32 -0400 Subject: Move get_bundled_aggregations to relations handler. (#12237) The get_bundled_aggregations code is fairly high-level and uses a lot of store methods, we move it into the handler as that seems like a better fit. --- changelog.d/12237.misc | 1 + synapse/events/utils.py | 2 +- synapse/handlers/pagination.py | 5 +- synapse/handlers/relations.py | 151 +++++++++++++++++++++++++++- synapse/handlers/room.py | 5 +- synapse/handlers/search.py | 3 +- synapse/handlers/sync.py | 9 +- synapse/rest/client/room.py | 3 +- synapse/storage/databases/main/relations.py | 151 +--------------------------- 9 files changed, 173 insertions(+), 157 deletions(-) create mode 100644 changelog.d/12237.misc (limited to 'synapse/storage') diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc new file mode 100644 index 0000000000..41c9dcbd37 --- /dev/null +++ b/changelog.d/12237.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0520068e0..7120062127 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations from synapse.server import HomeServer - from synapse.storage.databases.main.relations import BundledAggregations # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 41679f7f86..876b879483 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -134,6 +134,7 @@ class PaginationHandler: self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() + self._relations_handler = hs.get_relations_handler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. @@ -539,7 +540,9 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() - aggregations = await self.store.get_bundled_aggregations(events, user_id) + aggregations = await self._relations_handler.get_bundled_aggregations( + events, user_id + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e475475ad..57135d4519 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,18 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +import attr +from frozendict import frozendict + +from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.types import JsonDict, Requester, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + # The latest event in the thread. + latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. + count: int + # True if the current user has sent an event to the thread. + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main @@ -103,7 +138,7 @@ class RelationsHandler: ) # The relations returned for the requested event do include their # bundled aggregations. - aggregations = await self._main_store.get_bundled_aggregations( + aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) serialized_events = self._event_serializer.serialize_events( @@ -115,3 +150,115 @@ class RelationsHandler: return_value["original_event"] = original_event return return_value + + async def _get_bundled_aggregation_for_event( + self, event: EventBase, user_id: str + ) -> Optional[BundledAggregations]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations = BundledAggregations() + + annotations = await self._main_store.get_aggregation_groups_for_event( + event_id, room_id + ) + if annotations.chunk: + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) + + references = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations.references = await references.to_dict(cast("DataStore", self)) + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase], user_id: str + ) -> Dict[str, BundledAggregations]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. + for event in events_by_id.values(): + event_result = await self._get_bundled_aggregation_for_event(event, user_id) + if event_result: + results[event.event_id] = event_result + + # Fetch any edits (but not for redacted events). + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + summaries = await self._main_store.get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._main_store.get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + + return results diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9735631fc..092e185c99 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,8 +60,8 @@ from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state +from synapse.handlers.relations import BundledAggregations from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -1118,6 +1118,7 @@ class RoomContextHandler: self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state + self._relations_handler = hs.get_relations_handler() async def get_event_context( self, @@ -1190,7 +1191,7 @@ class RoomContextHandler: event = filtered[0] # Fetch the aggregations. - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( itertools.chain(events_before, (event,), events_after), user.to_string(), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index aa16e417eb..30eddda65f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,6 +54,7 @@ class SearchHandler: self.clock = hs.get_clock() self.hs = hs self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.storage = hs.get_storage() self.state_store = self.storage.state self.auth = hs.get_auth() @@ -354,7 +355,7 @@ class SearchHandler: aggregations = None if self._msc3666_enabled: - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( # Generate an iterable of EventBase for all the events that will be # returned, including contextual events. itertools.chain( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c9d6a18bd7..6c569cfb1c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -269,6 +269,7 @@ class SyncHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() + self._relations_handler = hs.get_relations_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -638,8 +639,10 @@ class SyncHandler: # as clients will have all the necessary information. bundled_aggregations = None if limited or newly_joined_room: - bundled_aggregations = await self.store.get_bundled_aggregations( - recents, sync_config.user.to_string() + bundled_aggregations = ( + await self._relations_handler.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) ) return TimelineBatch( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 8a06ab8c5f..47e152c8cc 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet): self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.auth = hs.get_auth() async def on_GET( @@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet): if event: # Ensure there are bundled aggregations available. - aggregations = await self._store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( [event], requester.user.to_string() ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index af2334a65e..b2295fd51f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.events import EventBase @@ -41,45 +40,15 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. - """ - - annotations: Optional[JsonDict] = None - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None - - def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) - - class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") - async def _get_applicable_edits( + async def get_applicable_edits( self, event_ids: Collection[str] ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given @@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") - async def _get_thread_summaries( + async def get_thread_summaries( self, event_ids: Collection[str] ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. @@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] # Check to see if any of those events are edited. - latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + latest_edits = await self.get_applicable_edits(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") - async def _get_threads_participated( + async def get_threads_participated( self, event_ids: Collection[str], user_id: str ) -> Dict[str, bool]: """Get whether the requesting user participated in the given threads. @@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. - aggregations = BundledAggregations() - - annotations = await self.get_aggregation_groups_for_event(event_id, room_id) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) - - references = await self.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - - async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str - ) -> Dict[str, BundledAggregations]: - """Generate bundled aggregations for events. - - Args: - events: The iterable of events to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. - """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } - - # event ID -> bundled aggregation in non-serialized form. - results: Dict[str, BundledAggregations] = {} - - # Fetch other relations per event. - for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result: - results[event.event_id] = event_result - - # Fetch any edits (but not for redacted events). - edits = await self._get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] - ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit - - # Fetch thread summaries. - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - [event_id for event_id, summary in summaries.items() if summary], user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) - - return results - class RelationsStore(RelationsWorkerStore): pass -- cgit 1.5.1 From 516d092ff95d02c0bb2133c9316a1fb4ff2f5072 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 23 Mar 2022 12:19:20 +0100 Subject: Rename shared_rooms to mutual_rooms (#12036) Co-authored-by: reivilibre --- changelog.d/12036.misc | 1 + synapse/rest/__init__.py | 4 +- synapse/rest/client/mutual_rooms.py | 76 ++++++++++++ synapse/rest/client/shared_rooms.py | 75 ------------ synapse/storage/databases/main/user_directory.py | 6 +- tests/rest/client/test_mutual_rooms.py | 146 +++++++++++++++++++++++ tests/rest/client/test_shared_rooms.py | 146 ----------------------- 7 files changed, 228 insertions(+), 226 deletions(-) create mode 100644 changelog.d/12036.misc create mode 100644 synapse/rest/client/mutual_rooms.py delete mode 100644 synapse/rest/client/shared_rooms.py create mode 100644 tests/rest/client/test_mutual_rooms.py delete mode 100644 tests/rest/client/test_shared_rooms.py (limited to 'synapse/storage') diff --git a/changelog.d/12036.misc b/changelog.d/12036.misc new file mode 100644 index 0000000000..d2996730cc --- /dev/null +++ b/changelog.d/12036.misc @@ -0,0 +1 @@ +Rename `shared_rooms` to `mutual_rooms` (MSC2666), as per proposal changes. \ No newline at end of file diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 762808a571..57c4773edc 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -32,6 +32,7 @@ from synapse.rest.client import ( knock, login as v1_login, logout, + mutual_rooms, notifications, openid, password_policy, @@ -49,7 +50,6 @@ from synapse.rest.client import ( room_keys, room_upgrade_rest_servlet, sendtodevice, - shared_rooms, sync, tags, thirdparty, @@ -132,4 +132,4 @@ class ClientRestResource(JsonResource): admin.register_servlets_for_client_rest_resource(hs, client_resource) # unstable - shared_rooms.register_servlets(hs, client_resource) + mutual_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py new file mode 100644 index 0000000000..d3872a76c8 --- /dev/null +++ b/synapse/rest/client/mutual_rooms.py @@ -0,0 +1,76 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UserMutualRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/mutual_rooms/(?P[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.user_directory_active = hs.config.server.update_user_directory + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + + rooms = await self.store.get_mutual_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + UserMutualRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py deleted file mode 100644 index e669fa7890..0000000000 --- a/synapse/rest/client/shared_rooms.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import TYPE_CHECKING, Tuple - -from synapse.api.errors import Codes, SynapseError -from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class UserSharedRoomsServlet(RestServlet): - """ - GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 - """ - - PATTERNS = client_patterns( - "/uk.half-shot.msc2666/user/shared_rooms/(?P[^/]*)", - releases=(), # This is an unstable feature - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.user_directory_active = hs.config.server.update_user_directory - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - - if not self.user_directory_active: - raise SynapseError( - code=400, - msg="The user directory is disabled on this server. Cannot determine shared rooms.", - errcode=Codes.FORBIDDEN, - ) - - UserID.from_string(user_id) - - requester = await self.auth.get_user_by_req(request) - if user_id == requester.user.to_string(): - raise SynapseError( - code=400, - msg="You cannot request a list of shared rooms with yourself", - errcode=Codes.FORBIDDEN, - ) - rooms = await self.store.get_shared_rooms_for_users( - requester.user.to_string(), user_id - ) - - return 200, {"joined": list(rooms)} - - -def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 55cc9178f0..0595df01d3 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -730,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_shared_rooms_for_users( + async def get_mutual_rooms_for_users( self, user_id: str, other_user_id: str ) -> Set[str]: """ @@ -744,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): A set of room ID's that the users share. """ - def _get_shared_rooms_for_users_txn( + def _get_mutual_rooms_for_users_txn( txn: LoggingTransaction, ) -> List[Dict[str, str]]: txn.execute( @@ -768,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return rows rows = await self.db_pool.runInteraction( - "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn + "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn ) return {row["room_id"] for row in rows} diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py new file mode 100644 index 0000000000..7b7d283bb6 --- /dev/null +++ b/tests/rest/client/test_mutual_rooms.py @@ -0,0 +1,146 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, mutual_rooms, room +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.server import FakeChannel + + +class UserMutualRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserMutualRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + mutual_rooms.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.handler = hs.get_user_directory_handler() + + def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel: + return self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s" + % other_user, + access_token=token, + ) + + def test_shared_room_list_public(self) -> None: + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True) + + def test_shared_room_list_private(self) -> None: + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + self._check_mutual_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) + + def test_shared_room_list_mixed(self) -> None: + """ + The shared room list between two users should contain both public and private + rooms. + """ + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_mutual_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ) -> None: + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 2) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) + + def test_shared_room_list_after_leave(self) -> None: + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + # Check user1's view of shared rooms with user2 + channel = self._get_mutual_rooms(u1_token, u2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) + + # Check user2's view of shared rooms with user1 + channel = self._get_mutual_rooms(u2_token, u1) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py deleted file mode 100644 index 3818b7b14b..0000000000 --- a/tests/rest/client/test_shared_rooms.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.rest.client import login, room, shared_rooms -from synapse.server import HomeServer -from synapse.util import Clock - -from tests import unittest -from tests.server import FakeChannel - - -class UserSharedRoomsTest(unittest.HomeserverTestCase): - """ - Tests the UserSharedRoomsServlet. - """ - - servlets = [ - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - shared_rooms.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.handler = hs.get_user_directory_handler() - - def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: - return self.make_request( - "GET", - "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" - % other_user, - access_token=token, - ) - - def test_shared_room_list_public(self) -> None: - """ - A room should show up in the shared list of rooms between two users - if it is public. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - - def test_shared_room_list_private(self) -> None: - """ - A room should show up in the shared list of rooms between two users - if it is private. - """ - self._check_shared_rooms_with( - room_one_is_public=False, room_two_is_public=False - ) - - def test_shared_room_list_mixed(self) -> None: - """ - The shared room list between two users should contain both public and private - rooms. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) - - def _check_shared_rooms_with( - self, room_one_is_public: bool, room_two_is_public: bool - ) -> None: - """Checks that shared public or private rooms between two users appear in - their shared room lists - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - # Create a room. user1 invites user2, who joins - room_id_one = self.helper.create_room_as( - u1, is_public=room_one_is_public, tok=u1_token - ) - self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_one, user=u2, tok=u2_token) - - # Check shared rooms from user1's perspective. - # We should see the one room in common - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 1) - self.assertEqual(channel.json_body["joined"][0], room_id_one) - - # Create another room and invite user2 to it - room_id_two = self.helper.create_room_as( - u1, is_public=room_two_is_public, tok=u1_token - ) - self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_two, user=u2, tok=u2_token) - - # Check shared rooms again. We should now see both rooms. - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 2) - for room_id_id in channel.json_body["joined"]: - self.assertIn(room_id_id, [room_id_one, room_id_two]) - - def test_shared_room_list_after_leave(self) -> None: - """ - A room should no longer be considered shared if the other - user has left it. - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - # Assert user directory is not empty - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 1) - self.assertEqual(channel.json_body["joined"][0], room) - - self.helper.leave(room, user=u1, tok=u1_token) - - # Check user1's view of shared rooms with user2 - channel = self._get_shared_rooms(u1_token, u2) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 0) - - # Check user2's view of shared rooms with user1 - channel = self._get_shared_rooms(u2_token, u1) - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(len(channel.json_body["joined"]), 0) -- cgit 1.5.1 From f4c5e5864cdc04aa61ad13d6f6ba870df811a881 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 23 Mar 2022 14:03:24 +0000 Subject: Use psycopg2 type stubs (#12269) --- changelog.d/12269.misc | 1 + setup.py | 1 + synapse/storage/database.py | 14 +++++++++++--- synapse/storage/engines/__init__.py | 2 +- 4 files changed, 14 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12269.misc (limited to 'synapse/storage') diff --git a/changelog.d/12269.misc b/changelog.d/12269.misc new file mode 100644 index 0000000000..ed79cbb528 --- /dev/null +++ b/changelog.d/12269.misc @@ -0,0 +1 @@ +Use type stubs for `psycopg2`. diff --git a/setup.py b/setup.py index 439ed75d72..63da71ad7b 100755 --- a/setup.py +++ b/setup.py @@ -108,6 +108,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [ "types-jsonschema>=3.2.0", "types-opentracing>=2.4.2", "types-Pillow>=8.3.4", + "types-psycopg2>=2.9.9", "types-pyOpenSSL>=20.0.7", "types-PyYAML>=5.4.10", "types-requests>=2.26.0", diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 9749f0c06e..367709a1a7 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -288,7 +288,7 @@ class LoggingTransaction: """ if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch # type: ignore + from psycopg2.extras import execute_batch self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: @@ -302,10 +302,18 @@ class LoggingTransaction: rows (e.g. INSERTs). """ assert isinstance(self.database_engine, PostgresEngine) - from psycopg2.extras import execute_values # type: ignore + from psycopg2.extras import execute_values return self._do_execute( - lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args + # Type ignore: mypy is unhappy because if `x` is a 5-tuple, then there will + # be two values for `fetch`: one given positionally, and another given + # as a keyword argument. We might be able to fix this by + # - propagating the signature of psycopg2.extras.execute_values to this + # function, or + # - changing `*args: Any` to `values: T` for some appropriate T. + lambda *x: execute_values(self.txn, *x, fetch=fetch), # type: ignore[misc] + sql, + *args, ) def execute(self, sql: str, *args: Any) -> None: diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 9abc02046e..afb7d5054d 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine: if name == "psycopg2": # Note that psycopg2cffi-compat provides the psycopg2 module on pypy. - import psycopg2 # type: ignore + import psycopg2 return PostgresEngine(psycopg2, database_config) -- cgit 1.5.1 From e78d4f61fc881851ab35e9a889239a61cf9805e5 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 23 Mar 2022 10:23:05 -0700 Subject: Refuse to start if DB has an unsafe locale (#12262) --- changelog.d/12262.misc | 1 + docs/postgres.md | 7 +++--- docs/sample_config.yaml | 6 +++++ synapse/config/database.py | 6 +++++ synapse/storage/engines/postgres.py | 45 ++++++++++++++++++++++++------------ tests/storage/test_unsafe_locale.py | 46 +++++++++++++++++++++++++++++++++++++ 6 files changed, 93 insertions(+), 18 deletions(-) create mode 100644 changelog.d/12262.misc create mode 100644 tests/storage/test_unsafe_locale.py (limited to 'synapse/storage') diff --git a/changelog.d/12262.misc b/changelog.d/12262.misc new file mode 100644 index 0000000000..574ac4752c --- /dev/null +++ b/changelog.d/12262.misc @@ -0,0 +1 @@ +Refuse to start if DB has non-`C` locale, unless config flag `allow_unsafe_db_locale` is set to true. \ No newline at end of file diff --git a/docs/postgres.md b/docs/postgres.md index de4e2ba4b7..cbc32e1836 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -234,12 +234,13 @@ host all all ::1/128 ident ### Fixing incorrect `COLLATE` or `CTYPE` Synapse will refuse to set up a new database if it has the wrong values of -`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using -different locales can cause issues if the locale library is updated from +`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values +of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the +`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from underneath the database, or if a different version of the locale is used on any replicas. -The safest way to fix the issue is to dump the database and recreate it with +If you have a databse with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with the correct locale parameter (as shown above). It is also possible to change the parameters on a live database and run a `REINDEX` on the entire database, however extreme care must be taken to avoid database corruption. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 36c6c56e58..9c2359ed8e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -783,6 +783,12 @@ caches: # 'txn_limit' gives the maximum number of transactions to run per connection # before reconnecting. Defaults to 0, which means no limit. # +# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to +# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended) +# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information +# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here: +# https://wiki.postgresql.org/wiki/Locale_data_changes +# # 'args' gives options which are passed through to the database engine, # except for options starting 'cp_', which are used to configure the Twisted # connection pool. For a reference to valid arguments, see: diff --git a/synapse/config/database.py b/synapse/config/database.py index 06ccf15cd9..d7f2219f53 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -37,6 +37,12 @@ DEFAULT_CONFIG = """\ # 'txn_limit' gives the maximum number of transactions to run per connection # before reconnecting. Defaults to 0, which means no limit. # +# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to +# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended) +# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information +# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here: +# https://wiki.postgresql.org/wiki/Locale_data_changes +# # 'args' gives options which are passed through to the database engine, # except for options starting 'cp_', which are used to configure the Twisted # connection pool. For a reference to valid arguments, see: diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 808342fafb..e8d29e2870 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine): self.default_isolation_level = ( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) + self.config = database_config @property def single_threaded(self) -> bool: return False + def get_db_locale(self, txn): + txn.execute( + "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" + ) + collation, ctype = txn.fetchone() + return collation, ctype + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them # together. For example, version 8.1.5 will be returned as 80105 self._version = db_conn.server_version + allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? if not allow_outdated_version and self._version < 100000: @@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine): "See docs/postgres.md for more information." % (rows[0][0],) ) - txn.execute( - "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" - ) - collation, ctype = txn.fetchone() + collation, ctype = self.get_db_locale(txn) if collation != "C": logger.warning( - "Database has incorrect collation of %r. Should be 'C'\n" - "See docs/postgres.md for more information.", + "Database has incorrect collation of %r. Should be 'C'", collation, ) + if not allow_unsafe_locale: + raise IncorrectDatabaseSetup( + "Database has incorrect collation of %r. Should be 'C'\n" + "See docs/postgres.md for more information. You can override this check by" + "setting 'allow_unsafe_locale' to true in the database config.", + collation, + ) if ctype != "C": - logger.warning( - "Database has incorrect ctype of %r. Should be 'C'\n" - "See docs/postgres.md for more information.", - ctype, - ) + if not allow_unsafe_locale: + logger.warning( + "Database has incorrect ctype of %r. Should be 'C'", + ctype, + ) + raise IncorrectDatabaseSetup( + "Database has incorrect ctype of %r. Should be 'C'\n" + "See docs/postgres.md for more information. You can override this check by" + "setting 'allow_unsafe_locale' to true in the database config.", + ctype, + ) def check_new_database(self, txn): """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ - txn.execute( - "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" - ) - collation, ctype = txn.fetchone() + collation, ctype = self.get_db_locale(txn) errors = [] diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py new file mode 100644 index 0000000000..ba53c22818 --- /dev/null +++ b/tests/storage/test_unsafe_locale.py @@ -0,0 +1,46 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import MagicMock, patch + +from synapse.storage.database import make_conn +from synapse.storage.engines._base import IncorrectDatabaseSetup + +from tests.unittest import HomeserverTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class UnsafeLocaleTest(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + @patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale") + def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None: + mock_db_locale.return_value = ("B", "B") + database = self.hs.get_datastores().databases[0] + + db_conn = make_conn(database._database_config, database.engine, "test_unsafe") + with self.assertRaises(IncorrectDatabaseSetup): + database.engine.check_database(db_conn) + with self.assertRaises(IncorrectDatabaseSetup): + database.engine.check_new_database(db_conn) + db_conn.close() + + def test_safe_locale(self) -> None: + database = self.hs.get_datastores().databases[0] + + db_conn = make_conn(database._database_config, database.engine, "test_unsafe") + with db_conn.cursor() as txn: + res = database.engine.get_db_locale(txn) + self.assertEqual(res, ("C", "C")) + db_conn.close() -- cgit 1.5.1 From 7ca8ee67a5165e33f03454218c81be96397e7591 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 25 Mar 2022 14:58:56 +0000 Subject: Add cache for `get_membership_from_event_ids` (#12272) This should speed up push rule calculations for rooms with large numbers of local users when the main push rule cache fails. Co-authored-by: reivilibre --- changelog.d/12272.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 30 +++++++++++----------- synapse/storage/databases/main/cache.py | 4 +++ synapse/storage/databases/main/events.py | 7 ++++++ synapse/storage/databases/main/roommember.py | 37 +++++++++++++++++++++++++--- synapse/storage/persist_events.py | 15 ++++++++--- 6 files changed, 72 insertions(+), 22 deletions(-) create mode 100644 changelog.d/12272.misc (limited to 'synapse/storage') diff --git a/changelog.d/12272.misc b/changelog.d/12272.misc new file mode 100644 index 0000000000..95589f3361 --- /dev/null +++ b/changelog.d/12272.misc @@ -0,0 +1 @@ +Add a new cache `_get_membership_from_event_id` to speed up push rule calculations in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 030898e4d0..a402a3e403 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY +from synapse.storage.databases.main.roommember import EventIdMembership from synapse.util.async_helpers import Linearizer from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.descriptors import lru_cache @@ -292,7 +293,7 @@ def _condition_checker( return True -MemberMap = Dict[str, Tuple[str, str]] +MemberMap = Dict[str, Optional[EventIdMembership]] Rule = Dict[str, dict] RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] @@ -306,7 +307,7 @@ class RulesForRoomData: *only* include data, and not references to e.g. the data stores. """ - # event_id -> (user_id, state) + # event_id -> EventIdMembership member_map: MemberMap = attr.Factory(dict) # user_id -> rules rules_by_user: RulesByUser = attr.Factory(dict) @@ -447,11 +448,10 @@ class RulesForRoom: res = self.data.member_map.get(event_id, None) if res: - user_id, state = res - if state == Membership.JOIN: - rules = self.data.rules_by_user.get(user_id, None) + if res.membership == Membership.JOIN: + rules = self.data.rules_by_user.get(res.user_id, None) if rules: - ret_rules_by_user[user_id] = rules + ret_rules_by_user[res.user_id] = rules continue # If a user has left a room we remove their push rule. If they @@ -502,24 +502,26 @@ class RulesForRoom: """ sequence = self.data.sequence - rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) - - members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} + members = await self.store.get_membership_from_event_ids( + member_event_ids.values() + ) - # If the event is a join event then it will be in current state evnts + # If the event is a join event then it will be in current state events # map but not in the DB, so we have to explicitly insert it. if event.type == EventTypes.Member: for event_id in member_event_ids.values(): if event_id == event.event_id: - members[event_id] = (event.state_key, event.membership) + members[event_id] = EventIdMembership( + user_id=event.state_key, membership=event.membership + ) if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) joined_user_ids = { - user_id - for user_id, membership in members.values() - if membership == Membership.JOIN + entry.user_id + for entry in members.values() + if entry and entry.membership == Membership.JOIN } logger.debug("Joined: %r", joined_user_ids) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2d7511d613..dd4e83a2ad 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + self._get_membership_from_event_id.invalidate((event_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f60aef180..d253243125 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1745,6 +1745,13 @@ class PersistEventsStore: (event.state_key,), ) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + txn.call_after( + self.store._get_membership_from_event_id.invalidate, + (event.event_id,), + ) + # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. # diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bef675b845..3248da5356 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +@attr.s(frozen=True, slots=True, auto_attribs=True) +class EventIdMembership: + """Returned by `get_membership_from_event_ids`""" + + user_id: str + membership: str + + class RoomMemberWorkerStore(EventsWorkerStore): def __init__( self, @@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): retcols=("user_id", "display_name", "avatar_url", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=500, - desc="_get_membership_from_event_ids", + desc="_get_joined_profiles_from_event_ids", ) return { @@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) + @cached(max_entries=5000) + async def _get_membership_from_event_id( + self, member_event_id: str + ) -> Optional[EventIdMembership]: + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_membership_from_event_id", list_name="member_event_ids" + ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> List[dict]: - """Get user_id and membership of a set of event IDs.""" + ) -> Dict[str, Optional[EventIdMembership]]: + """Get user_id and membership of a set of event IDs. + + Returns: + Mapping from event ID to `EventIdMembership` if the event is a + membership event, otherwise the value is None. + """ - return await self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_membership_from_event_ids", ) + return { + row["event_id"]: EventIdMembership( + membership=row["membership"], user_id=row["user_id"] + ) + for row in rows + } + async def is_local_host_in_room_ignoring_users( self, room_id: str, ignore_users: Collection[str] ) -> bool: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 7d543fdbe0..b402922817 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -1023,8 +1023,13 @@ class EventsPersistenceStorage: # Check if any of the changes that we don't have events for are joins. if events_to_check: - rows = await self.main_store.get_membership_from_event_ids(events_to_check) - is_still_joined = any(row["membership"] == Membership.JOIN for row in rows) + members = await self.main_store.get_membership_from_event_ids( + events_to_check + ) + is_still_joined = any( + member and member.membership == Membership.JOIN + for member in members.values() + ) if is_still_joined: return True @@ -1060,9 +1065,11 @@ class EventsPersistenceStorage: ), event_id in current_state.items() if typ == EventTypes.Member and not self.is_mine_id(state_key) ] - rows = await self.main_store.get_membership_from_event_ids(remote_event_ids) + members = await self.main_store.get_membership_from_event_ids(remote_event_ids) potentially_left_users.update( - row["user_id"] for row in rows if row["membership"] == Membership.JOIN + member.user_id + for member in members.values() + if member and member.membership == Membership.JOIN ) return False -- cgit 1.5.1