summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/__init__.py45
-rw-r--r--synapse/handlers/message.py30
-rw-r--r--synapse/handlers/relations.py20
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py6
-rw-r--r--synapse/storage/databases/main/events.py49
5 files changed, 91 insertions, 59 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index c238376caf..39ad2793d9 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import abc
+import collections.abc
 import os
 from typing import (
     TYPE_CHECKING,
@@ -32,9 +33,11 @@ from typing import (
     overload,
 )
 
+import attr
 from typing_extensions import Literal
 from unpaddedbase64 import encode_base64
 
+from synapse.api.constants import RelationTypes
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
@@ -615,3 +618,45 @@ def make_event_from_dict(
     return event_type(
         event_dict, room_version, internal_metadata_dict or {}, rejected_reason
     )
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRelation:
+    # The target event of the relation.
+    parent_id: str
+    # The relation type.
+    rel_type: str
+    # The aggregation key. Will be None if the rel_type is not m.annotation or is
+    # not a string.
+    aggregation_key: Optional[str]
+
+
+def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
+    """
+    Attempt to parse relation information an event.
+
+    Returns:
+        The event relation information, if it is valid. None, otherwise.
+    """
+    relation = event.content.get("m.relates_to")
+    if not relation or not isinstance(relation, collections.abc.Mapping):
+        # No relation information.
+        return None
+
+    # Relations must have a type and parent event ID.
+    rel_type = relation.get("rel_type")
+    if not isinstance(rel_type, str):
+        return None
+
+    parent_id = relation.get("event_id")
+    if not isinstance(parent_id, str):
+        return None
+
+    # Annotations have a key field.
+    aggregation_key = None
+    if rel_type == RelationTypes.ANNOTATION:
+        aggregation_key = relation.get("key")
+        if not isinstance(aggregation_key, str):
+            aggregation_key = None
+
+    return _EventRelation(parent_id, rel_type, aggregation_key)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4a4b535bae..0951b9c71f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -44,7 +44,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
@@ -1060,20 +1060,11 @@ class EventCreationHandler:
             SynapseError if the event is invalid.
         """
 
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
             return
 
-        relation_type = relation.get("rel_type")
-        if not relation_type:
-            return
-
-        # Ensure the parent is real.
-        relates_to = relation.get("event_id")
-        if not relates_to:
-            return
-
-        parent_event = await self.store.get_event(relates_to, allow_none=True)
+        parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
         if parent_event:
             # And in the same room.
             if parent_event.room_id != event.room_id:
@@ -1082,28 +1073,31 @@ class EventCreationHandler:
         else:
             # There must be some reason that the client knows the event exists,
             # see if there are existing relations. If so, assume everything is fine.
-            if not await self.store.event_is_target_of_relation(relates_to):
+            if not await self.store.event_is_target_of_relation(relation.parent_id):
                 # Otherwise, the client can't know about the parent event!
                 raise SynapseError(400, "Can't send relation to unknown event")
 
         # If this event is an annotation then we check that that the sender
         # can't annotate the same way twice (e.g. stops users from liking an
         # event multiple times).
-        if relation_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation["key"]
+        if relation.rel_type == RelationTypes.ANNOTATION:
+            aggregation_key = relation.aggregation_key
+
+            if aggregation_key is None:
+                raise SynapseError(400, "Missing aggregation key")
 
             if len(aggregation_key) > 500:
                 raise SynapseError(400, "Aggregation key is too long")
 
             already_exists = await self.store.has_user_annotated_event(
-                relates_to, event.type, aggregation_key, event.sender
+                relation.parent_id, event.type, aggregation_key, event.sender
             )
             if already_exists:
                 raise SynapseError(400, "Can't send same reaction twice")
 
         # Don't attempt to start a thread if the parent event is a relation.
-        elif relation_type == RelationTypes.THREAD:
-            if await self.store.event_includes_relation(relates_to):
+        elif relation.rel_type == RelationTypes.THREAD:
+            if await self.store.event_includes_relation(relation.parent_id):
                 raise SynapseError(
                     400, "Cannot start threads from an event with a relation"
                 )
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index c2754ec918..ab7e54857d 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections.abc
 import logging
 from typing import (
     TYPE_CHECKING,
@@ -28,7 +27,7 @@ import attr
 
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.storage.databases.main.relations import _RelatedEvent
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -373,20 +372,21 @@ class RelationsHandler:
             if event.is_state():
                 continue
 
-            relates_to = event.content.get("m.relates_to")
-            relation_type = None
-            if isinstance(relates_to, collections.abc.Mapping):
-                relation_type = relates_to.get("rel_type")
+            relates_to = relation_from_event(event)
+            if relates_to:
                 # An event which is a replacement (ie edit) or annotation (ie,
                 # reaction) may not have any other event related to it.
-                if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                if relates_to.rel_type in (
+                    RelationTypes.ANNOTATION,
+                    RelationTypes.REPLACE,
+                ):
                     continue
 
+                # Track the event's relation information for later.
+                relations_by_id[event.event_id] = relates_to.rel_type
+
             # The event should get bundled aggregations.
             events_by_id[event.event_id] = event
-            # Track the event's relation information for later.
-            if isinstance(relation_type, str):
-                relations_by_id[event.event_id] = relation_type
 
         # event ID -> bundled aggregation in non-serialized form.
         results: Dict[str, BundledAggregations] = {}
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 0ffafc882b..4ac2c546bf 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -21,7 +21,7 @@ from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
 from synapse.event_auth import get_user_power_level
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
@@ -78,8 +78,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
         return False
 
     # Exclude edits.
-    relates_to = event.content.get("m.relates_to", {})
-    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+    relates_to = relation_from_event(event)
+    if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
         return False
 
     # Mark events that have a non-empty string body as unread.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index f544bcfff0..42d484dc98 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,8 +36,8 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -1807,52 +1807,45 @@ class PersistEventsStore:
             txn: The current database transaction.
             event: The event which might have relations.
         """
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
-            # No relations
+            # No relation, nothing to do.
             return
 
-        # Relations must have a type and parent event ID.
-        rel_type = relation.get("rel_type")
-        if not isinstance(rel_type, str):
-            return
-
-        parent_id = relation.get("event_id")
-        if not isinstance(parent_id, str):
-            return
-
-        # Annotations have a key field.
-        aggregation_key = None
-        if rel_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation.get("key")
-
         self.db_pool.simple_insert_txn(
             txn,
             table="event_relations",
             values={
                 "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
+                "relates_to_id": relation.parent_id,
+                "relation_type": relation.rel_type,
+                "aggregation_key": relation.aggregation_key,
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+        )
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate,
+            (relation.parent_id,),
         )
 
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.REPLACE:
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+            )
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.THREAD:
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (relation.parent_id,)
+            )
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.sender),
+                (relation.parent_id, event.sender),
             )
 
     def _handle_insertion_event(