diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/push/bulk_push_rule_evaluator.py | 5 | ||||
-rw-r--r-- | synapse/rest/client/receipts.py | 22 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 36 |
3 files changed, 61 insertions, 2 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 61d952742d..f8c4dd74f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -286,8 +286,13 @@ class BulkPushRuleEvaluator: relation.parent_id, itertools.chain(*(r.rules() for r in rules_by_user.values())), ) + # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id + else: + # Since the event has not yet been persisted we check whether + # the parent is part of a thread. + thread_id = await self.store.get_thread_id(relation.parent_id) or "main" evaluator = PushRuleEvaluator( _flatten_dict(event), diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index f3ff156abe..287dfdd69e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReceiptTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._main_store = hs.get_datastores().main self._known_receipt_types = { ReceiptTypes.READ, @@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet): thread_id = body.get("thread_id") if not thread_id or not isinstance(thread_id, str): raise SynapseError( - 400, "thread_id field must be a non-empty string" + 400, + "thread_id field must be a non-empty string", + Codes.INVALID_PARAM, + ) + + if receipt_type == ReceiptTypes.FULLY_READ: + raise SynapseError( + 400, + f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", + Codes.INVALID_PARAM, + ) + + # Ensure the event ID roughly correlates to the thread ID. + if thread_id != await self._main_store.get_thread_id(event_id): + raise SynapseError( + 400, + f"event_id {event_id} is not related to thread {thread_id}", + Codes.INVALID_PARAM, ) await self.presence_handler.bump_presence_active_time(requester.user) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 898947af95..154385b1e8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_event_relations", _get_event_relations ) + @cached() + async def get_thread_id(self, event_id: str) -> Optional[str]: + """ + Get the thread ID for an event. This considers multi-level relations, + e.g. an annotation to an event which is part of a thread. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. None, otherwise. + """ + # Since event relations form a tree, we should only ever find 0 or 1 + # results from the below query. + sql = """ + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + """ + + def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + txn.execute(sql, (event_id,)) + # TODO Should we ensure there's only a single result here? + row = txn.fetchone() + if row: + return row[0] + return None + + return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + class RelationsStore(RelationsWorkerStore): pass |