summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/relations.py58
1 files changed, 37 insertions, 21 deletions
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 9a1cc11bb3..0b63cd2186 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,16 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Collection,
-    Dict,
-    FrozenSet,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
 
 import attr
 
@@ -256,13 +247,19 @@ class RelationsHandler:
 
         return filtered_results
 
-    async def get_threads_for_events(
-        self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+    async def _get_threads_for_events(
+        self,
+        events_by_id: Dict[str, EventBase],
+        relations_by_id: Dict[str, str],
+        user_id: str,
+        ignored_users: FrozenSet[str],
     ) -> Dict[str, _ThreadAggregation]:
         """Get the bundled aggregations for threads for the requested events.
 
         Args:
-            event_ids: Events to get aggregations for threads.
+            events_by_id: A map of event_id to events to get aggregations for threads.
+            relations_by_id: A map of event_id to the relation type, if one exists
+                for that event.
             user_id: The user requesting the bundled aggregations.
             ignored_users: The users ignored by the requesting user.
 
@@ -273,16 +270,34 @@ class RelationsHandler:
         """
         user = UserID.from_string(user_id)
 
+        # It is not valid to start a thread on an event which itself relates to another event.
+        event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
+
         # Fetch thread summaries.
         summaries = await self._main_store.get_thread_summaries(event_ids)
 
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
+        # Limit fetching whether the requester has participated in a thread to
+        # events which are thread roots.
         thread_event_ids = [
             event_id for event_id, summary in summaries.items() if summary
         ]
-        participated = await self._main_store.get_threads_participated(
-            thread_event_ids, user_id
+
+        # Pre-seed thread participation with whether the requester sent the event.
+        participated = {
+            event_id: events_by_id[event_id].sender == user_id
+            for event_id in thread_event_ids
+        }
+        # For events the requester did not send, check the database for whether
+        # the requester sent a threaded reply.
+        participated.update(
+            await self._main_store.get_threads_participated(
+                [
+                    event_id
+                    for event_id in thread_event_ids
+                    if not participated[event_id]
+                ],
+                user_id,
+            )
         )
 
         # Then subtract off the results for any ignored users.
@@ -343,7 +358,8 @@ class RelationsHandler:
                     count=thread_count,
                     # If there's a thread summary it must also exist in the
                     # participated dictionary.
-                    current_user_participated=participated[event_id],
+                    current_user_participated=events_by_id[event_id].sender == user_id
+                    or participated[event_id],
                 )
 
         return results
@@ -401,9 +417,9 @@ class RelationsHandler:
         # events to be fetched. Thus, we check those first!
 
         # Fetch thread summaries (but only for the directly requested events).
-        threads = await self.get_threads_for_events(
-            # It is not valid to start a thread on an event which itself relates to another event.
-            [eid for eid in events_by_id.keys() if eid not in relations_by_id],
+        threads = await self._get_threads_for_events(
+            events_by_id,
+            relations_by_id,
             user_id,
             ignored_users,
         )