diff options
author | Patrick Cloke <patrickc@matrix.org> | 2022-07-25 15:01:06 -0400 |
---|---|---|
committer | Patrick Cloke <patrickc@matrix.org> | 2022-07-27 11:46:22 -0400 |
commit | e9a649ec314122d34d1b408c83800e1f1e263554 (patch) | |
tree | d29cd47c03635234d7140189cb4ddcda527a4406 | |
parent | Implement MSC3848: Introduce errcodes for specific event sending failures (#1... (diff) | |
download | synapse-e9a649ec314122d34d1b408c83800e1f1e263554.tar.xz |
Add an API for listing threads in a room.
-rw-r--r-- | changelog.d/13394.feature | 1 | ||||
-rw-r--r-- | synapse/config/experimental.py | 3 | ||||
-rw-r--r-- | synapse/handlers/relations.py | 63 | ||||
-rw-r--r-- | synapse/rest/client/relations.py | 44 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 7 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 87 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 74 |
7 files changed, 277 insertions, 2 deletions
diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature new file mode 100644 index 0000000000..68de079cf3 --- /dev/null +++ b/changelog.d/13394.feature @@ -0,0 +1 @@ +Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1902222d7b..5613c8e038 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -93,3 +93,6 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) + + # MSC3856: Threads list API + self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8f797e3ae9..af83ac1a7b 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -480,3 +480,66 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + limit: int = 5, + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_token = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token, to_token=to_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + now = self._clock.time_msec() + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_token: + return_value["next_batch"] = await next_token.to_string(self._main_store) + + if from_token: + return_value["prev_batch"] = await from_token.to_string(self._main_store) + + return return_value diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index ce97080013..faa962e3a8 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple from synapse.http.server import HttpServer @@ -91,5 +92,48 @@ class RelationPaginationServlet(RestServlet): return 200, result +class ThreadsServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") + + # Return the relations + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + return 200, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + if hs.config.experimental.msc3856_enabled: + ThreadsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f600f1190..cf18dd179c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1594,7 +1594,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1909,6 +1909,7 @@ class PersistEventsStore: self.store.get_thread_participated.invalidate, (relation.parent_id, event.sender), ) + txn.call_after(self.store.get_threads.invalidate, (event.room_id,)) def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase @@ -2033,13 +2034,14 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ @@ -2068,6 +2070,7 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + txn.call_after(self.store.get_threads.invalidate, (room_id,)) self.store._invalidate_cache_and_stream( txn, self.store.get_mutual_event_relations_for_rel_type, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7bd27790eb..57b2f7c188 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore): "get_event_relations", _get_event_relations ) + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, + ) -> Tuple[List[str], Optional[StreamToken]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. + + Args: + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. + + Returns: + A tuple of: + A list of thread root event IDs. + + The next stream token, if one exists. + """ + pagination_clause = generate_pagination_where_clause( + direction="b", + column_names=("topological_ordering", "stream_ordering"), + from_token=from_token.room_key.as_historical_tuple() + if from_token + else None, + to_token=to_token.room_key.as_historical_tuple() if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + pagination_clause = "AND " + pagination_clause + + sql = f""" + SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + room_id = ? AND + relation_type = '{RelationTypes.THREAD}' + {pagination_clause} + GROUP BY relates_to_id + ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC + LIMIT ? + """ + + def _get_threads_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], Optional[StreamToken]]: + txn.execute(sql, [room_id, limit + 1]) + + last_topo_id = None + last_stream_id = None + thread_ids = [] + for thread_id, topo_id, stream_id in txn: + thread_ids.append(thread_id) + last_topo_id = topo_id + last_stream_id = stream_id + + # If there are more events, generate the next pagination key. + next_token = None + if len(thread_ids) > limit and last_topo_id and last_stream_id: + next_key = RoomStreamToken(last_topo_id, last_stream_id) + if from_token: + next_token = from_token.copy_and_replace( + StreamKeyType.ROOM, next_key + ) + else: + next_token = StreamToken( + room_key=next_key, + presence_key=0, + typing_key=0, + receipt_key=0, + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=0, + groups_key=0, + ) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ad03eee17b..0666bec479 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1677,3 +1677,77 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations[RelationTypes.THREAD]["latest_event"]["event_id"], related_event_id, ) + + +class ThreadsTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_threads(self) -> None: + """Create threads and ensure the ordering is due to their latest event.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1]) + + # Update the first thread, the ordering should swap. + self._send_relation(RelationTypes.THREAD, "m.room.test") + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1, thread_2]) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_pagination(self) -> None: + """Create threads and paginate through them.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2]) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + next_batch = channel.json_body.get("next_batch") + self.assertIsInstance(next_batch, str, channel.json_body) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) + + self.assertNotIn("next_batch", channel.json_body, channel.json_body) + + # XXX Test ignoring users. |