summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/events.py9
-rw-r--r--synapse/storage/databases/main/push_rule.py5
-rw-r--r--synapse/storage/databases/main/relations.py52
3 files changed, 66 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 0df8ff5395..17e35cf63e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1828,6 +1828,10 @@ class PersistEventsStore:
             self.store.get_aggregation_groups_for_event.invalidate,
             (relation.parent_id,),
         )
+        txn.call_after(
+            self.store.get_mutual_event_relations_for_rel_type.invalidate,
+            (relation.parent_id,),
+        )
 
         if relation.rel_type == RelationTypes.REPLACE:
             txn.call_after(
@@ -2004,6 +2008,11 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_thread_participated, (redacted_relates_to,)
             )
+            self.store._invalidate_cache_and_stream(
+                txn,
+                self.store.get_mutual_event_relations_for_rel_type,
+                (redacted_relates_to,),
+            )
 
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index ad67901cc1..4adabc88cc 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -61,6 +61,11 @@ def _is_experimental_rule_enabled(
         and not experimental_config.msc3786_enabled
     ):
         return False
+    if (
+        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+        and not experimental_config.msc3772_enabled
+    ):
+        return False
     return True
 
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fe8fded88b..3b1b2ce6cb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from collections import defaultdict
 from typing import (
     Collection,
     Dict,
@@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    @cached(iterable=True)
+    async def get_mutual_event_relations_for_rel_type(
+        self, event_id: str, relation_type: str
+    ) -> Set[Tuple[str, str]]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="get_mutual_event_relations_for_rel_type",
+        list_name="relation_types",
+    )
+    async def get_mutual_event_relations(
+        self, event_id: str, relation_types: Collection[str]
+    ) -> Dict[str, Set[Tuple[str, str]]]:
+        """
+        Fetch event metadata for events which related to the same event as the given event.
+
+        If the given event has no relation information, returns an empty dictionary.
+
+        Args:
+            event_id: The event ID which is targeted by relations.
+            relation_types: The relation types to check for mutual relations.
+
+        Returns:
+            A dictionary of relation type to:
+                A set of tuples of:
+                    The sender
+                    The event type
+        """
+        rel_type_sql, rel_type_args = make_in_list_sql_clause(
+            self.database_engine, "relation_type", relation_types
+        )
+
+        sql = f"""
+            SELECT DISTINCT relation_type, sender, type FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE relates_to_id = ? AND {rel_type_sql}
+        """
+
+        def _get_event_relations(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Set[Tuple[str, str]]]:
+            txn.execute(sql, [event_id] + rel_type_args)
+            result = defaultdict(set)
+            for rel_type, sender, type in txn.fetchall():
+                result[rel_type].add((sender, type))
+            return result
+
+        return await self.db_pool.runInteraction(
+            "get_event_relations", _get_event_relations
+        )
+
 
 class RelationsStore(RelationsWorkerStore):
     pass