diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..4ff6aed253 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, cast
import attr
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_relations_for_event(
self,
event_id: str,
+ room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
the form `{"event_id": "..."}`.
"""
- where_clause = ["relates_to_id = ?"]
- where_args: List[Union[str, int]] = [event_id]
+ where_clause = ["relates_to_id = ?", "room_id = ?"]
+ where_args: List[Union[str, int]] = [event_id, room_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_aggregation_groups_for_event(
self,
event_id: str,
+ room_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+ where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+ where_args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
if event_type:
where_clause.append("type = ?")
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
)
@cached()
- async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+ async def get_applicable_edit(
+ self, event_id: str, room_id: str
+ ) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: The original event ID
+ room_id: The original event's room ID
Returns:
The most recent edit, if any.
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
WHERE
relates_to_id = ?
AND relation_type = ?
+ AND edit.room_id = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
- txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
row = txn.fetchone()
if row:
return row[0]
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
@cached()
async def get_thread_summary(
- self, event_id: str
+ self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
Args:
- event_id: The original event ID
+ event_id: Summarize the thread related to this event ID.
+ room_id: The room the event belongs to.
Returns:
The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
@@ -376,14 +390,16 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0]
sql = """
- SELECT COALESCE(COUNT(event_id), 0)
+ SELECT COUNT(event_id)
FROM event_relations
+ INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
- count = txn.fetchone()[0] # type: ignore[index]
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
+ count = cast(Tuple[int], txn.fetchone())[0]
return count, latest_event_id
|