diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/events.py | 9 | ||||
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 5 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 52 |
3 files changed, 66 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0df8ff5395..17e35cf63e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1828,6 +1828,10 @@ class PersistEventsStore: self.store.get_aggregation_groups_for_event.invalidate, (relation.parent_id,), ) + txn.call_after( + self.store.get_mutual_event_relations_for_rel_type.invalidate, + (relation.parent_id,), + ) if relation.rel_type == RelationTypes.REPLACE: txn.call_after( @@ -2004,6 +2008,11 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, + self.store.get_mutual_event_relations_for_rel_type, + (redacted_relates_to,), + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index ad67901cc1..4adabc88cc 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -61,6 +61,11 @@ def _is_experimental_rule_enabled( and not experimental_config.msc3786_enabled ): return False + if ( + rule_id == "global/underride/.org.matrix.msc3772.thread_reply" + and not experimental_config.msc3772_enabled + ): + return False return True diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index fe8fded88b..3b1b2ce6cb 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from collections import defaultdict from typing import ( Collection, Dict, @@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(iterable=True) + async def get_mutual_event_relations_for_rel_type( + self, event_id: str, relation_type: str + ) -> Set[Tuple[str, str]]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_mutual_event_relations_for_rel_type", + list_name="relation_types", + ) + async def get_mutual_event_relations( + self, event_id: str, relation_types: Collection[str] + ) -> Dict[str, Set[Tuple[str, str]]]: + """ + Fetch event metadata for events which related to the same event as the given event. + + If the given event has no relation information, returns an empty dictionary. + + Args: + event_id: The event ID which is targeted by relations. + relation_types: The relation types to check for mutual relations. + + Returns: + A dictionary of relation type to: + A set of tuples of: + The sender + The event type + """ + rel_type_sql, rel_type_args = make_in_list_sql_clause( + self.database_engine, "relation_type", relation_types + ) + + sql = f""" + SELECT DISTINCT relation_type, sender, type FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND {rel_type_sql} + """ + + def _get_event_relations( + txn: LoggingTransaction, + ) -> Dict[str, Set[Tuple[str, str]]]: + txn.execute(sql, [event_id] + rel_type_args) + result = defaultdict(set) + for rel_type, sender, type in txn.fetchall(): + result[rel_type].add((sender, type)) + return result + + return await self.db_pool.runInteraction( + "get_event_relations", _get_event_relations + ) + class RelationsStore(RelationsWorkerStore): pass |