diff --git a/changelog.d/11577.feature b/changelog.d/11577.feature
new file mode 100644
index 0000000000..f9c8a0d5f4
--- /dev/null
+++ b/changelog.d/11577.feature
@@ -0,0 +1 @@
+Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 472688f045..973f262964 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -537,7 +537,7 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
- aggregations = await self.store.get_bundled_aggregations(events)
+ aggregations = await self.store.get_bundled_aggregations(events, user_id)
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3d47163f25..f963078e59 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1182,12 +1182,18 @@ class RoomContextHandler:
results["event"] = filtered[0]
# Fetch the aggregations.
- aggregations = await self.store.get_bundled_aggregations([results["event"]])
+ aggregations = await self.store.get_bundled_aggregations(
+ [results["event"]], user.to_string()
+ )
aggregations.update(
- await self.store.get_bundled_aggregations(results["events_before"])
+ await self.store.get_bundled_aggregations(
+ results["events_before"], user.to_string()
+ )
)
aggregations.update(
- await self.store.get_bundled_aggregations(results["events_after"])
+ await self.store.get_bundled_aggregations(
+ results["events_after"], user.to_string()
+ )
)
results["aggregations"] = aggregations
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e1df9b3106..ffc6b748e8 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -637,7 +637,9 @@ class SyncHandler:
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
- bundled_aggregations = await self.store.get_bundled_aggregations(recents)
+ bundled_aggregations = await self.store.get_bundled_aggregations(
+ recents, sync_config.user.to_string()
+ )
return TimelineBatch(
events=recents,
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 37d949a71e..8cf5ebaa07 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -118,7 +118,9 @@ class RelationPaginationServlet(RestServlet):
)
# The relations returned for the requested event do include their
# bundled aggregations.
- aggregations = await self.store.get_bundled_aggregations(events)
+ aggregations = await self.store.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index da6014900a..31fd329a38 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -663,7 +663,9 @@ class RoomEventServlet(RestServlet):
if event:
# Ensure there are bundled aggregations available.
- aggregations = await self._store.get_bundled_aggregations([event])
+ aggregations = await self._store.get_bundled_aggregations(
+ [event], requester.user.to_string()
+ )
time_now = self.clock.time_msec()
event_dict = self._event_serializer.serialize_event(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2be36a741a..7278002322 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1793,6 +1793,13 @@ class PersistEventsStore:
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
+ # It should be safe to only invalidate the cache if the user has not
+ # previously participated in the thread, but that's difficult (and
+ # potentially error-prone) so it is always invalidated.
+ txn.call_after(
+ self.store.get_thread_participated.invalidate,
+ (parent_id, event.room_id, event.sender),
+ )
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c6c4bd18da..2cb5d06c13 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
- """Get the number of threaded replies, the senders of those replies, and
- the latest reply (if any) for the given event.
+ """Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
@@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
- # Fetch the count of threaded events and the latest event ID.
+ # Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
SELECT event_id
@@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0]
+ # Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
@@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
+ @cached()
+ async def get_thread_participated(
+ self, event_id: str, room_id: str, user_id: str
+ ) -> bool:
+ """Get whether the requesting user participated in a thread.
+
+ This is separate from get_thread_summary since that can be cached across
+ all users while this value is specific to the requeser.
+
+ Args:
+ event_id: The thread related to this event ID.
+ room_id: The room the event belongs to.
+ user_id: The user requesting the summary.
+
+ Returns:
+ True if the requesting user participated in the thread, otherwise false.
+ """
+
+ def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+ # Fetch whether the requester has participated or not.
+ sql = """
+ SELECT 1
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND room_id = ?
+ AND relation_type = ?
+ AND sender = ?
+ """
+
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
async def events_have_relations(
self,
parent_ids: List[str],
@@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
async def _get_bundled_aggregation_for_event(
- self, event: EventBase
+ self, event: EventBase, user_id: str
) -> Optional[Dict[str, Any]]:
"""Generate bundled aggregations for an event.
@@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
@@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
- (
- thread_count,
- latest_thread_event,
- ) = await self.get_thread_summary(event_id, room_id)
+ thread_count, latest_thread_event = await self.get_thread_summary(
+ event_id, room_id
+ )
+ participated = await self.get_thread_participated(
+ event_id, room_id, user_id
+ )
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
- # Don't bundle aggregations as this could recurse forever.
"latest_event": latest_thread_event,
"count": thread_count,
+ "current_user_participated": participated,
}
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
- self, events: Iterable[EventBase]
+ self,
+ events: Iterable[EventBase],
+ user_id: str,
) -> Dict[str, Dict[str, Any]]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
@@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
# TODO Parallelize.
results = {}
for event in events:
- event_result = await self._get_bundled_aggregation_for_event(event)
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result is not None:
results[event.event_id] = event_result
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index ee26751430..4b20ab0e3e 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -515,6 +515,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
2,
actual[RelationTypes.THREAD].get("count"),
)
+ self.assertTrue(
+ actual[RelationTypes.THREAD].get("current_user_participated")
+ )
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
|