diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 9a1cc11bb3..0b63cd2186 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,16 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- FrozenSet,
- Iterable,
- List,
- Optional,
- Tuple,
-)
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
import attr
@@ -256,13 +247,19 @@ class RelationsHandler:
return filtered_results
- async def get_threads_for_events(
- self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+ async def _get_threads_for_events(
+ self,
+ events_by_id: Dict[str, EventBase],
+ relations_by_id: Dict[str, str],
+ user_id: str,
+ ignored_users: FrozenSet[str],
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
- event_ids: Events to get aggregations for threads.
+ events_by_id: A map of event_id to events to get aggregations for threads.
+ relations_by_id: A map of event_id to the relation type, if one exists
+ for that event.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
@@ -273,16 +270,34 @@ class RelationsHandler:
"""
user = UserID.from_string(user_id)
+ # It is not valid to start a thread on an event which itself relates to another event.
+ event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
+
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
- # Only fetch participated for a limited selection based on what had
- # summaries.
+ # Limit fetching whether the requester has participated in a thread to
+ # events which are thread roots.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
- participated = await self._main_store.get_threads_participated(
- thread_event_ids, user_id
+
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {
+ event_id: events_by_id[event_id].sender == user_id
+ for event_id in thread_event_ids
+ }
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [
+ event_id
+ for event_id in thread_event_ids
+ if not participated[event_id]
+ ],
+ user_id,
+ )
)
# Then subtract off the results for any ignored users.
@@ -343,7 +358,8 @@ class RelationsHandler:
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
- current_user_participated=participated[event_id],
+ current_user_participated=events_by_id[event_id].sender == user_id
+ or participated[event_id],
)
return results
@@ -401,9 +417,9 @@ class RelationsHandler:
# events to be fetched. Thus, we check those first!
# Fetch thread summaries (but only for the directly requested events).
- threads = await self.get_threads_for_events(
- # It is not valid to start a thread on an event which itself relates to another event.
- [eid for eid in events_by_id.keys() if eid not in relations_by_id],
+ threads = await self._get_threads_for_events(
+ events_by_id,
+ relations_by_id,
user_id,
ignored_users,
)
|