summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-05-20 14:06:40 +0100
committerGitHub <noreply@github.com>2019-05-20 14:06:40 +0100
commit57ba3451b6c22633c4aaf1bec19e91b66add65cd (patch)
tree1ffb24e6c61372dc6013fa7828ed8bc665ac2f99 /synapse/storage
parentLimit UserIds to a length that fits in a state key (#5198) (diff)
parentNewsfile (diff)
downloadsynapse-57ba3451b6c22633c4aaf1bec19e91b66add65cd.tar.xz
Merge pull request #5209 from matrix-org/erikj/reactions_base
Land basic reaction and edit support.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/events.py17
-rw-r--r--synapse/storage/relations.py434
-rw-r--r--synapse/storage/schema/delta/54/relations.sql27
4 files changed, 476 insertions, 4 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c432041b4e..7522d3fd57 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,6 +49,7 @@ from .pusher import PusherStore
 from .receipts import ReceiptsStore
 from .registration import RegistrationStore
 from .rejections import RejectionsStore
+from .relations import RelationsStore
 from .room import RoomStore
 from .roommember import RoomMemberStore
 from .search import SearchStore
@@ -99,6 +100,7 @@ class DataStore(
     GroupServerStore,
     UserErasureStore,
     MonthlyActiveUsersStore,
+    RelationsStore,
 ):
     def __init__(self, db_conn, hs):
         self.hs = hs
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7a7f841c6c..881d6d0126 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1325,6 +1325,9 @@ class EventsStore(
                     txn, event.room_id, event.redacts
                 )
 
+                # Remove from relations table.
+                self._handle_redaction(txn, event.redacts)
+
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
@@ -1351,6 +1354,8 @@ class EventsStore(
                 # Insert into the event_search table.
                 self._store_guest_access_txn(txn, event)
 
+            self._handle_event_relations(txn, event)
+
         # Insert into the room_memberships table.
         self._store_room_members_txn(
             txn,
@@ -1655,10 +1660,11 @@ class EventsStore(
         def get_all_new_forward_event_rows(txn):
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
+                " state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? < stream_ordering AND stream_ordering <= ?"
                 " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
@@ -1673,11 +1679,12 @@ class EventsStore(
 
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
+                " state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? < event_stream_ordering"
                 " AND event_stream_ordering <= ?"
                 " ORDER BY event_stream_ordering DESC"
@@ -1698,10 +1705,11 @@ class EventsStore(
         def get_all_new_backfill_event_rows(txn):
             sql = (
                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
+                " state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > stream_ordering AND stream_ordering >= ?"
                 " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
@@ -1716,11 +1724,12 @@ class EventsStore(
 
             sql = (
                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
+                " state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > event_stream_ordering"
                 " AND event_stream_ordering >= ?"
                 " ORDER BY event_stream_ordering DESC"
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
new file mode 100644
index 0000000000..63e6185ee3
--- /dev/null
+++ b/synapse/storage/relations.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from twisted.internet import defer
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.stream import generate_pagination_where_clause
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s
+class PaginationChunk(object):
+    """Returned by relation pagination APIs.
+
+    Attributes:
+        chunk (list): The rows returned by pagination
+        next_batch (Any|None): Token to fetch next set of results with, if
+            None then there are no more results.
+        prev_batch (Any|None): Token to fetch previous set of results with, if
+            None then there are no previous results.
+    """
+
+    chunk = attr.ib()
+    next_batch = attr.ib(default=None)
+    prev_batch = attr.ib(default=None)
+
+    def to_dict(self):
+        d = {"chunk": self.chunk}
+
+        if self.next_batch:
+            d["next_batch"] = self.next_batch.to_string()
+
+        if self.prev_batch:
+            d["prev_batch"] = self.prev_batch.to_string()
+
+        return d
+
+
+@attr.s(frozen=True, slots=True)
+class RelationPaginationToken(object):
+    """Pagination token for relation pagination API.
+
+    As the results are order by topological ordering, we can use the
+    `topological_ordering` and `stream_ordering` fields of the events at the
+    boundaries of the chunk as pagination tokens.
+
+    Attributes:
+        topological (int): The topological ordering of the boundary event
+        stream (int): The stream ordering of the boundary event.
+    """
+
+    topological = attr.ib()
+    stream = attr.ib()
+
+    @staticmethod
+    def from_string(string):
+        try:
+            t, s = string.split("-")
+            return RelationPaginationToken(int(t), int(s))
+        except ValueError:
+            raise SynapseError(400, "Invalid token")
+
+    def to_string(self):
+        return "%d-%d" % (self.topological, self.stream)
+
+    def as_tuple(self):
+        return attr.astuple(self)
+
+
+@attr.s(frozen=True, slots=True)
+class AggregationPaginationToken(object):
+    """Pagination token for relation aggregation pagination API.
+
+    As the results are order by count and then MAX(stream_ordering) of the
+    aggregation groups, we can just use them as our pagination token.
+
+    Attributes:
+        count (int): The count of relations in the boundar group.
+        stream (int): The MAX stream ordering in the boundary group.
+    """
+
+    count = attr.ib()
+    stream = attr.ib()
+
+    @staticmethod
+    def from_string(string):
+        try:
+            c, s = string.split("-")
+            return AggregationPaginationToken(int(c), int(s))
+        except ValueError:
+            raise SynapseError(400, "Invalid token")
+
+    def to_string(self):
+        return "%d-%d" % (self.count, self.stream)
+
+    def as_tuple(self):
+        return attr.astuple(self)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+    @cached(tree=True)
+    def get_relations_for_event(
+        self,
+        event_id,
+        relation_type=None,
+        event_type=None,
+        aggregation_key=None,
+        limit=5,
+        direction="b",
+        from_token=None,
+        to_token=None,
+    ):
+        """Get a list of relations for an event, ordered by topological ordering.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            relation_type (str|None): Only fetch events with this relation
+                type, if given.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            aggregation_key (str|None): Only fetch events with this aggregation
+                key, if given.
+            limit (int): Only fetch the most recent `limit` events.
+            direction (str): Whether to fetch the most recent first (`"b"`) or
+                the oldest first (`"f"`).
+            from_token (RelationPaginationToken|None): Fetch rows from the given
+                token, or from the start if None.
+            to_token (RelationPaginationToken|None): Fetch rows up to the given
+                token, or up to the end if None.
+
+        Returns:
+            Deferred[PaginationChunk]: List of event IDs that match relations
+            requested. The rows are of the form `{"event_id": "..."}`.
+        """
+
+        where_clause = ["relates_to_id = ?"]
+        where_args = [event_id]
+
+        if relation_type is not None:
+            where_clause.append("relation_type = ?")
+            where_args.append(relation_type)
+
+        if event_type is not None:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        if aggregation_key:
+            where_clause.append("aggregation_key = ?")
+            where_args.append(aggregation_key)
+
+        pagination_clause = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=attr.astuple(from_token) if from_token else None,
+            to_token=attr.astuple(to_token) if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if pagination_clause:
+            where_clause.append(pagination_clause)
+
+        if direction == "b":
+            order = "DESC"
+        else:
+            order = "ASC"
+
+        sql = """
+            SELECT event_id, topological_ordering, stream_ordering
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE %s
+            ORDER BY topological_ordering %s, stream_ordering %s
+            LIMIT ?
+        """ % (
+            " AND ".join(where_clause),
+            order,
+            order,
+        )
+
+        def _get_recent_references_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            last_topo_id = None
+            last_stream_id = None
+            events = []
+            for row in txn:
+                events.append({"event_id": row[0]})
+                last_topo_id = row[1]
+                last_stream_id = row[2]
+
+            next_batch = None
+            if len(events) > limit and last_topo_id and last_stream_id:
+                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+            return PaginationChunk(
+                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+            )
+
+        return self.runInteraction(
+            "get_recent_references_for_event", _get_recent_references_for_event_txn
+        )
+
+    @cached(tree=True)
+    def get_aggregation_groups_for_event(
+        self,
+        event_id,
+        event_type=None,
+        limit=5,
+        direction="b",
+        from_token=None,
+        to_token=None,
+    ):
+        """Get a list of annotations on the event, grouped by event type and
+        aggregation key, sorted by count.
+
+        This is used e.g. to get the what and how many reactions have happend
+        on an event.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            limit (int): Only fetch the `limit` groups.
+            direction (str): Whether to fetch the highest count first (`"b"`) or
+                the lowest count first (`"f"`).
+            from_token (AggregationPaginationToken|None): Fetch rows from the
+                given token, or from the start if None.
+            to_token (AggregationPaginationToken|None): Fetch rows up to the
+                given token, or up to the end if None.
+
+
+        Returns:
+            Deferred[PaginationChunk]: List of groups of annotations that
+            match. Each row is a dict with `type`, `key` and `count` fields.
+        """
+
+        where_clause = ["relates_to_id = ?", "relation_type = ?"]
+        where_args = [event_id, RelationTypes.ANNOTATION]
+
+        if event_type:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        having_clause = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("COUNT(*)", "MAX(stream_ordering)"),
+            from_token=attr.astuple(from_token) if from_token else None,
+            to_token=attr.astuple(to_token) if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if direction == "b":
+            order = "DESC"
+        else:
+            order = "ASC"
+
+        if having_clause:
+            having_clause = "HAVING " + having_clause
+        else:
+            having_clause = ""
+
+        sql = """
+            SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE {where_clause}
+            GROUP BY relation_type, type, aggregation_key
+            {having_clause}
+            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+            LIMIT ?
+        """.format(
+            where_clause=" AND ".join(where_clause),
+            order=order,
+            having_clause=having_clause,
+        )
+
+        def _get_aggregation_groups_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            next_batch = None
+            events = []
+            for row in txn:
+                events.append({"type": row[0], "key": row[1], "count": row[2]})
+                next_batch = AggregationPaginationToken(row[2], row[3])
+
+            if len(events) <= limit:
+                next_batch = None
+
+            return PaginationChunk(
+                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+            )
+
+        return self.runInteraction(
+            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+        )
+
+    @cachedInlineCallbacks()
+    def get_applicable_edit(self, event_id):
+        """Get the most recent edit (if any) that has happened for the given
+        event.
+
+        Correctly handles checking whether edits were allowed to happen.
+
+        Args:
+            event_id (str): The original event ID
+
+        Returns:
+            Deferred[EventBase|None]: Returns the most recent edit, if any.
+        """
+
+        # We only allow edits for `m.room.message` events that have the same sender
+        # and event type. We can't assert these things during regular event auth so
+        # we have to do the checks post hoc.
+
+        # Fetches latest edit that has the same type and sender as the
+        # original, and is an `m.room.message`.
+        sql = """
+            SELECT edit.event_id FROM events AS edit
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS original ON
+                original.event_id = relates_to_id
+                AND edit.type = original.type
+                AND edit.sender = original.sender
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND edit.type = 'm.room.message'
+            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+            LIMIT 1
+        """
+
+        def _get_applicable_edit_txn(txn):
+            txn.execute(
+                sql, (event_id, RelationTypes.REPLACES,)
+            )
+            row = txn.fetchone()
+            if row:
+                return row[0]
+
+        edit_id = yield self.runInteraction(
+            "get_applicable_edit", _get_applicable_edit_txn
+        )
+
+        if not edit_id:
+            return
+
+        edit_event = yield self.get_event(edit_id, allow_none=True)
+        defer.returnValue(edit_event)
+
+
+class RelationsStore(RelationsWorkerStore):
+    def _handle_event_relations(self, txn, event):
+        """Handles inserting relation data during peristence of events
+
+        Args:
+            txn
+            event (EventBase)
+        """
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            # No relations
+            return
+
+        rel_type = relation.get("rel_type")
+        if rel_type not in (
+            RelationTypes.ANNOTATION,
+            RelationTypes.REFERENCES,
+            RelationTypes.REPLACES,
+        ):
+            # Unknown relation type
+            return
+
+        parent_id = relation.get("event_id")
+        if not parent_id:
+            # Invalid relation
+            return
+
+        aggregation_key = relation.get("key")
+
+        self._simple_insert_txn(
+            txn,
+            table="event_relations",
+            values={
+                "event_id": event.event_id,
+                "relates_to_id": parent_id,
+                "relation_type": rel_type,
+                "aggregation_key": aggregation_key,
+            },
+        )
+
+        txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(
+            self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+        )
+
+        if rel_type == RelationTypes.REPLACES:
+            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+    def _handle_redaction(self, txn, redacted_event_id):
+        """Handles receiving a redaction and checking whether we need to remove
+        any redacted relations from the database.
+
+        Args:
+            txn
+            redacted_event_id (str): The event that was redacted.
+        """
+
+        self._simple_delete_txn(
+            txn,
+            table="event_relations",
+            keyvalues={
+                "event_id": redacted_event_id,
+            }
+        )
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql
new file mode 100644
index 0000000000..134862b870
--- /dev/null
+++ b/synapse/storage/schema/delta/54/relations.sql
@@ -0,0 +1,27 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks related events, like reactions, replies, edits, etc. Note that things
+-- in this table are not necessarily "valid", e.g. it may contain edits from
+-- people who don't have power to edit other peoples events.
+CREATE TABLE IF NOT EXISTS event_relations (
+    event_id TEXT NOT NULL,
+    relates_to_id TEXT NOT NULL,
+    relation_type TEXT NOT NULL,
+    aggregation_key TEXT
+);
+
+CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id);
+CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key);