diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fe8fded88b..3b1b2ce6cb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from collections import defaultdict
from typing import (
Collection,
Dict,
@@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ @cached(iterable=True)
+ async def get_mutual_event_relations_for_rel_type(
+ self, event_id: str, relation_type: str
+ ) -> Set[Tuple[str, str]]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="get_mutual_event_relations_for_rel_type",
+ list_name="relation_types",
+ )
+ async def get_mutual_event_relations(
+ self, event_id: str, relation_types: Collection[str]
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ """
+ Fetch event metadata for events which related to the same event as the given event.
+
+ If the given event has no relation information, returns an empty dictionary.
+
+ Args:
+ event_id: The event ID which is targeted by relations.
+ relation_types: The relation types to check for mutual relations.
+
+ Returns:
+ A dictionary of relation type to:
+ A set of tuples of:
+ The sender
+ The event type
+ """
+ rel_type_sql, rel_type_args = make_in_list_sql_clause(
+ self.database_engine, "relation_type", relation_types
+ )
+
+ sql = f"""
+ SELECT DISTINCT relation_type, sender, type FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND {rel_type_sql}
+ """
+
+ def _get_event_relations(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ txn.execute(sql, [event_id] + rel_type_args)
+ result = defaultdict(set)
+ for rel_type, sender, type in txn.fetchall():
+ result[rel_type].add((sender, type))
+ return result
+
+ return await self.db_pool.runInteraction(
+ "get_event_relations", _get_event_relations
+ )
+
class RelationsStore(RelationsWorkerStore):
pass
|