diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index a4905162e0..f304b8a9a4 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -15,13 +15,93 @@
import logging
+import attr
+
from synapse.api.constants import RelationTypes
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
+@attr.s
+class PaginationChunk(object):
+ """Returned by relation pagination APIs.
+
+ Attributes:
+ chunk (list): The rows returned by pagination
+ """
+
+ chunk = attr.ib()
+
+ def to_dict(self):
+ d = {"chunk": self.chunk}
+
+ return d
+
+
class RelationsStore(SQLBaseStore):
+ def get_relations_for_event(
+ self, event_id, relation_type=None, event_type=None, limit=5, direction="b"
+ ):
+ """Get a list of relations for an event, ordered by topological ordering.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ relation_type (str|None): Only fetch events with this relation
+ type, if given.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ limit (int): Only fetch the most recent `limit` events.
+ direction (str): Whether to fetch the most recent first (`"b"`) or
+ the oldest first (`"f"`).
+
+ Returns:
+ Deferred[PaginationChunk]: List of event IDs that match relations
+ requested. The rows are of the form `{"event_id": "..."}`.
+ """
+
+ # TODO: Pagination tokens
+
+ where_clause = ["relates_to_id = ?"]
+ where_args = [event_id]
+
+ if relation_type:
+ where_clause.append("relation_type = ?")
+ where_args.append(relation_type)
+
+ if event_type:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ order = "ASC"
+ if direction == "b":
+ order = "DESC"
+
+ sql = """
+ SELECT event_id FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+
+ def _get_recent_references_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ events = [{"event_id": row[0]} for row in txn]
+
+ return PaginationChunk(
+ chunk=list(events[:limit]),
+ )
+
+ return self.runInteraction(
+ "get_recent_references_for_event", _get_recent_references_for_event_txn
+ )
+
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
|