diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c6c4bd18da..2cb5d06c13 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
- """Get the number of threaded replies, the senders of those replies, and
- the latest reply (if any) for the given event.
+ """Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
@@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
- # Fetch the count of threaded events and the latest event ID.
+ # Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
SELECT event_id
@@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0]
+ # Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
@@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
+ @cached()
+ async def get_thread_participated(
+ self, event_id: str, room_id: str, user_id: str
+ ) -> bool:
+ """Get whether the requesting user participated in a thread.
+
+ This is separate from get_thread_summary since that can be cached across
+ all users while this value is specific to the requeser.
+
+ Args:
+ event_id: The thread related to this event ID.
+ room_id: The room the event belongs to.
+ user_id: The user requesting the summary.
+
+ Returns:
+ True if the requesting user participated in the thread, otherwise false.
+ """
+
+ def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+ # Fetch whether the requester has participated or not.
+ sql = """
+ SELECT 1
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND room_id = ?
+ AND relation_type = ?
+ AND sender = ?
+ """
+
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
async def events_have_relations(
self,
parent_ids: List[str],
@@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
async def _get_bundled_aggregation_for_event(
- self, event: EventBase
+ self, event: EventBase, user_id: str
) -> Optional[Dict[str, Any]]:
"""Generate bundled aggregations for an event.
@@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
@@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
- (
- thread_count,
- latest_thread_event,
- ) = await self.get_thread_summary(event_id, room_id)
+ thread_count, latest_thread_event = await self.get_thread_summary(
+ event_id, room_id
+ )
+ participated = await self.get_thread_participated(
+ event_id, room_id, user_id
+ )
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
- # Don't bundle aggregations as this could recurse forever.
"latest_event": latest_thread_event,
"count": thread_count,
+ "current_user_participated": participated,
}
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
- self, events: Iterable[EventBase]
+ self,
+ events: Iterable[EventBase],
+ user_id: str,
) -> Dict[str, Dict[str, Any]]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
@@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
# TODO Parallelize.
results = {}
for event in events:
- event_result = await self._get_bundled_aggregation_for_event(event)
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result is not None:
results[event.event_id] = event_result
|