diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c432041b4e..7522d3fd57 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,6 +49,7 @@ from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
from .rejections import RejectionsStore
+from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
@@ -99,6 +100,7 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
+ RelationsStore,
):
def __init__(self, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7a7f841c6c..881d6d0126 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1325,6 +1325,9 @@ class EventsStore(
txn, event.room_id, event.redacts
)
+ # Remove from relations table.
+ self._handle_redaction(txn, event.redacts)
+
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -1351,6 +1354,8 @@ class EventsStore(
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
+ self._handle_event_relations(txn, event)
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1655,10 +1660,11 @@ class EventsStore(
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1673,11 +1679,12 @@ class EventsStore(
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
@@ -1698,10 +1705,11 @@ class EventsStore(
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1716,11 +1724,12 @@ class EventsStore(
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
new file mode 100644
index 0000000000..63e6185ee3
--- /dev/null
+++ b/synapse/storage/relations.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from twisted.internet import defer
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.stream import generate_pagination_where_clause
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s
+class PaginationChunk(object):
+ """Returned by relation pagination APIs.
+
+ Attributes:
+ chunk (list): The rows returned by pagination
+ next_batch (Any|None): Token to fetch next set of results with, if
+ None then there are no more results.
+ prev_batch (Any|None): Token to fetch previous set of results with, if
+ None then there are no previous results.
+ """
+
+ chunk = attr.ib()
+ next_batch = attr.ib(default=None)
+ prev_batch = attr.ib(default=None)
+
+ def to_dict(self):
+ d = {"chunk": self.chunk}
+
+ if self.next_batch:
+ d["next_batch"] = self.next_batch.to_string()
+
+ if self.prev_batch:
+ d["prev_batch"] = self.prev_batch.to_string()
+
+ return d
+
+
+@attr.s(frozen=True, slots=True)
+class RelationPaginationToken(object):
+ """Pagination token for relation pagination API.
+
+ As the results are order by topological ordering, we can use the
+ `topological_ordering` and `stream_ordering` fields of the events at the
+ boundaries of the chunk as pagination tokens.
+
+ Attributes:
+ topological (int): The topological ordering of the boundary event
+ stream (int): The stream ordering of the boundary event.
+ """
+
+ topological = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ t, s = string.split("-")
+ return RelationPaginationToken(int(t), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.topological, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+@attr.s(frozen=True, slots=True)
+class AggregationPaginationToken(object):
+ """Pagination token for relation aggregation pagination API.
+
+ As the results are order by count and then MAX(stream_ordering) of the
+ aggregation groups, we can just use them as our pagination token.
+
+ Attributes:
+ count (int): The count of relations in the boundar group.
+ stream (int): The MAX stream ordering in the boundary group.
+ """
+
+ count = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ c, s = string.split("-")
+ return AggregationPaginationToken(int(c), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.count, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+ @cached(tree=True)
+ def get_relations_for_event(
+ self,
+ event_id,
+ relation_type=None,
+ event_type=None,
+ aggregation_key=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of relations for an event, ordered by topological ordering.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ relation_type (str|None): Only fetch events with this relation
+ type, if given.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ aggregation_key (str|None): Only fetch events with this aggregation
+ key, if given.
+ limit (int): Only fetch the most recent `limit` events.
+ direction (str): Whether to fetch the most recent first (`"b"`) or
+ the oldest first (`"f"`).
+ from_token (RelationPaginationToken|None): Fetch rows from the given
+ token, or from the start if None.
+ to_token (RelationPaginationToken|None): Fetch rows up to the given
+ token, or up to the end if None.
+
+ Returns:
+ Deferred[PaginationChunk]: List of event IDs that match relations
+ requested. The rows are of the form `{"event_id": "..."}`.
+ """
+
+ where_clause = ["relates_to_id = ?"]
+ where_args = [event_id]
+
+ if relation_type is not None:
+ where_clause.append("relation_type = ?")
+ where_args.append(relation_type)
+
+ if event_type is not None:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ if aggregation_key:
+ where_clause.append("aggregation_key = ?")
+ where_args.append(aggregation_key)
+
+ pagination_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ where_clause.append(pagination_clause)
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT event_id, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+
+ def _get_recent_references_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ events = []
+ for row in txn:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[1]
+ last_stream_id = row[2]
+
+ next_batch = None
+ if len(events) > limit and last_topo_id and last_stream_id:
+ next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_recent_references_for_event", _get_recent_references_for_event_txn
+ )
+
+ @cached(tree=True)
+ def get_aggregation_groups_for_event(
+ self,
+ event_id,
+ event_type=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of annotations on the event, grouped by event type and
+ aggregation key, sorted by count.
+
+ This is used e.g. to get the what and how many reactions have happend
+ on an event.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ limit (int): Only fetch the `limit` groups.
+ direction (str): Whether to fetch the highest count first (`"b"`) or
+ the lowest count first (`"f"`).
+ from_token (AggregationPaginationToken|None): Fetch rows from the
+ given token, or from the start if None.
+ to_token (AggregationPaginationToken|None): Fetch rows up to the
+ given token, or up to the end if None.
+
+
+ Returns:
+ Deferred[PaginationChunk]: List of groups of annotations that
+ match. Each row is a dict with `type`, `key` and `count` fields.
+ """
+
+ where_clause = ["relates_to_id = ?", "relation_type = ?"]
+ where_args = [event_id, RelationTypes.ANNOTATION]
+
+ if event_type:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ having_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("COUNT(*)", "MAX(stream_ordering)"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ if having_clause:
+ having_clause = "HAVING " + having_clause
+ else:
+ having_clause = ""
+
+ sql = """
+ SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE {where_clause}
+ GROUP BY relation_type, type, aggregation_key
+ {having_clause}
+ ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ LIMIT ?
+ """.format(
+ where_clause=" AND ".join(where_clause),
+ order=order,
+ having_clause=having_clause,
+ )
+
+ def _get_aggregation_groups_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ next_batch = None
+ events = []
+ for row in txn:
+ events.append({"type": row[0], "key": row[1], "count": row[2]})
+ next_batch = AggregationPaginationToken(row[2], row[3])
+
+ if len(events) <= limit:
+ next_batch = None
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ )
+
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
+ """Get the most recent edit (if any) that has happened for the given
+ event.
+
+ Correctly handles checking whether edits were allowed to happen.
+
+ Args:
+ event_id (str): The original event ID
+
+ Returns:
+ Deferred[EventBase|None]: Returns the most recent edit, if any.
+ """
+
+ # We only allow edits for `m.room.message` events that have the same sender
+ # and event type. We can't assert these things during regular event auth so
+ # we have to do the checks post hoc.
+
+ # Fetches latest edit that has the same type and sender as the
+ # original, and is an `m.room.message`.
+ sql = """
+ SELECT edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+ LIMIT 1
+ """
+
+ def _get_applicable_edit_txn(txn):
+ txn.execute(
+ sql, (event_id, RelationTypes.REPLACES,)
+ )
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ edit_id = yield self.runInteraction(
+ "get_applicable_edit", _get_applicable_edit_txn
+ )
+
+ if not edit_id:
+ return
+
+ edit_event = yield self.get_event(edit_id, allow_none=True)
+ defer.returnValue(edit_event)
+
+
+class RelationsStore(RelationsWorkerStore):
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
+
+ Args:
+ txn
+ event (EventBase)
+ """
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
+
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCES,
+ RelationTypes.REPLACES,
+ ):
+ # Unknown relation type
+ return
+
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
+
+ aggregation_key = relation.get("key")
+
+ self._simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
+ )
+
+ txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+ )
+
+ if rel_type == RelationTypes.REPLACES:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
+
+ Args:
+ txn
+ redacted_event_id (str): The event that was redacted.
+ """
+
+ self._simple_delete_txn(
+ txn,
+ table="event_relations",
+ keyvalues={
+ "event_id": redacted_event_id,
+ }
+ )
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql
new file mode 100644
index 0000000000..134862b870
--- /dev/null
+++ b/synapse/storage/schema/delta/54/relations.sql
@@ -0,0 +1,27 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks related events, like reactions, replies, edits, etc. Note that things
+-- in this table are not necessarily "valid", e.g. it may contain edits from
+-- people who don't have power to edit other peoples events.
+CREATE TABLE IF NOT EXISTS event_relations (
+ event_id TEXT NOT NULL,
+ relates_to_id TEXT NOT NULL,
+ relation_type TEXT NOT NULL,
+ aggregation_key TEXT
+);
+
+CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id);
+CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key);
|