diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/events/utils.py | 11 | ||||
-rw-r--r-- | synapse/rest/client/relations.py | 7 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 8 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 36 |
4 files changed, 45 insertions, 17 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 84ef69df67..3da432c1df 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -454,23 +454,26 @@ class EventClientSerializer: return event_id = event.event_id + room_id = event.room_id # The bundled aggregations to include. aggregations = {} - annotations = await self.store.get_aggregation_groups_for_event(event_id) + annotations = await self.store.get_aggregation_groups_for_event( + event_id, room_id + ) if annotations.chunk: aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" + event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id, room_id) if edit: # If there is an edit replace the content, preserving existing @@ -503,7 +506,7 @@ class EventClientSerializer: ( thread_count, latest_thread_event, - ) = await self.store.get_thread_summary(event_id) + ) = await self.store.get_thread_summary(event_id, room_id) if latest_thread_event: aggregations[RelationTypes.THREAD] = { # Don't bundle aggregations as this could recurse forever. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index fc4e6921c5..ffa37ef06c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, limit=limit, @@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, + room_id=room_id, event_type=event_type, limit=limit, from_token=from_token, @@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. - await self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, aggregation_key=key, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4e528612ea..f1f4ce5e07 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1780,10 +1780,14 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + txn.call_after( + self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) + ) if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + txn.call_after( + self.store.get_thread_summary.invalidate, (parent_id, event.room_id) + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0a43acda07..3368a8b084 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_relations_for_event( self, event_id: str, + room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, aggregation_key: Optional[str] = None, @@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. aggregation_key: Only fetch events with this aggregation key, if given. @@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore): the form `{"event_id": "..."}`. """ - where_clause = ["relates_to_id = ?"] - where_args: List[Union[str, int]] = [event_id] + where_clause = ["relates_to_id = ?", "room_id = ?"] + where_args: List[Union[str, int]] = [event_id, room_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_aggregation_groups_for_event( self, event_id: str, + room_id: str, event_type: Optional[str] = None, limit: int = 5, direction: str = "b", @@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. direction: Whether to fetch the highest count first (`"b"`) or @@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] + where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] + where_args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] if event_type: where_clause.append("type = ?") @@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + async def get_applicable_edit( + self, event_id: str, room_id: str + ) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. @@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: The original event ID + room_id: The original event's room ID Returns: The most recent edit, if any. @@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore): WHERE relates_to_id = ? AND relation_type = ? + AND edit.room_id = ? AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE)) + txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) row = txn.fetchone() if row: return row[0] @@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore): @cached() async def get_thread_summary( - self, event_id: str + self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: """Get the number of threaded replies, the senders of those replies, and the latest reply (if any) for the given event. Args: - event_id: The original event ID + event_id: Summarize the thread related to this event ID. + room_id: The room the event belongs to. Returns: The number of items in the thread and the most recent response, if any. @@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore): INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1 """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) row = txn.fetchone() if row is None: return 0, None @@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore): sql = """ SELECT COALESCE(COUNT(event_id), 0) FROM event_relations + INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) count = txn.fetchone()[0] # type: ignore[index] return count, latest_event_id |