summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2021-11-25 15:16:47 +0000
committerBrendan Abolivier <babolivier@matrix.org>2021-11-25 15:16:47 +0000
commitcb79a2b78546ffcecc6b8fad6664535c8dbf4aec (patch)
tree104072b132e9d415d2d6fdd388629bc75d02b402 /synapse/storage/databases/main/relations.py
parentPrevent the media store from writing outside of the configured directory (diff)
parentImprove performance of `remove_{hidden,deleted}_devices_from_device_inbox` (#... (diff)
downloadsynapse-cb79a2b78546ffcecc6b8fad6664535c8dbf4aec.tar.xz
Merge branch 'develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py121
1 files changed, 120 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py

index 53576ad52f..0a43acda07 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -20,7 +20,7 @@ import attr from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def event_includes_relation(self, event_id: str) -> bool: + """Check if the given event relates to another event. + + An event has a relation if it has a valid m.relates_to with a rel_type + and event_id in the content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$other_event_id" + } + } + } + + Args: + event_id: The event to check. + + Returns: + True if the event includes a valid relation. + """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"event_id": event_id}, + retcol="event_id", + allow_none=True, + desc="event_includes_relation", + ) + return result is not None + + async def event_is_target_of_relation(self, parent_id: str) -> bool: + """Check if the given event is the target of another event's relation. + + An event is the target of an event relation if it has a valid + m.relates_to with a rel_type and event_id pointing to parent_id in the + content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$parent_id" + } + } + } + + Args: + parent_id: The event to check. + + Returns: + True if the event is the target of another event's relation. + """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"relates_to_id": parent_id}, + retcol="event_id", + allow_none=True, + desc="event_is_target_of_relation", + ) + return result is not None + @cached(tree=True) async def get_aggregation_groups_for_event( self, @@ -334,6 +397,62 @@ class RelationsWorkerStore(SQLBaseStore): return count, latest_event + async def events_have_relations( + self, + parent_ids: List[str], + relation_senders: Optional[List[str]], + relation_types: Optional[List[str]], + ) -> List[str]: + """Check which events have a relationship from the given senders of the + given types. + + Args: + parent_ids: The events being annotated + relation_senders: The relation senders to check. + relation_types: The relation types to check. + + Returns: + True if the event has at least one relationship from one of the given senders of the given type. + """ + # If no restrictions are given then the event has the required relations. + if not relation_senders and not relation_types: + return parent_ids + + sql = """ + SELECT relates_to_id FROM event_relations + INNER JOIN events USING (event_id) + WHERE + %s; + """ + + def _get_if_events_have_relations(txn) -> List[str]: + clauses: List[str] = [] + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", parent_ids + ) + clauses.append(clause) + + if relation_senders: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "sender", relation_senders + ) + clauses.append(clause) + args.extend(temp_args) + if relation_types: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "relation_type", relation_types + ) + clauses.append(clause) + args.extend(temp_args) + + txn.execute(sql % " AND ".join(clauses), args) + + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_if_events_have_relations", _get_if_events_have_relations + ) + async def has_user_annotated_event( self, parent_id: str, event_type: str, aggregation_key: str, sender: str ) -> bool: