diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2bbf6d6a95..53576ad52f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?"]
- where_args = [event_id]
+ where_args: List[Union[str, int]] = [event_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
order,
)
- def _get_recent_references_for_event_txn(txn):
+ def _get_recent_references_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args = [event_id, RelationTypes.ANNOTATION]
+ where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
if event_type:
where_clause.append("type = ?")
@@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
+ from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
+ to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)
@@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause=having_clause,
)
- def _get_aggregation_groups_for_event_txn(txn):
+ def _get_aggregation_groups_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
next_batch = None
@@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1
"""
- def _get_applicable_edit_txn(txn):
+ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
+ return None
edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,66 @@ class RelationsWorkerStore(SQLBaseStore):
if not edit_id:
return None
- return await self.get_event(edit_id, allow_none=True)
+ return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
+
+ @cached()
+ async def get_thread_summary(
+ self, event_id: str
+ ) -> Tuple[int, Optional[EventBase]]:
+ """Get the number of threaded replies, the senders of those replies, and
+ the latest reply (if any) for the given event.
+
+ Args:
+ event_id: The original event ID
+
+ Returns:
+ The number of items in the thread and the most recent response, if any.
+ """
+
+ def _get_thread_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, Optional[str]]:
+ # Fetch the count of threaded events and the latest event ID.
+ # TODO Should this only allow m.room.message events.
+ sql = """
+ SELECT event_id
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT 1
+ """
+
+ txn.execute(sql, (event_id, RelationTypes.THREAD))
+ row = txn.fetchone()
+ if row is None:
+ return 0, None
+
+ latest_event_id = row[0]
+
+ sql = """
+ SELECT COALESCE(COUNT(event_id), 0)
+ FROM event_relations
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ """
+ txn.execute(sql, (event_id, RelationTypes.THREAD))
+ count = txn.fetchone()[0] # type: ignore[index]
+
+ return count, latest_event_id
+
+ count, latest_event_id = await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
+ latest_event = None
+ if latest_event_id:
+ latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
+
+ return count, latest_event
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
@@ -297,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
LIMIT 1;
"""
- def _get_if_user_has_annotated_event(txn):
+ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
txn.execute(
sql,
(
|