summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/rest/client/v2_alpha/relations.py50
-rw-r--r--synapse/storage/relations.py80
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py30
3 files changed, 160 insertions, 0 deletions
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index b504b4a8be..bac9e85c21 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import (
     RestServlet,
+    parse_integer,
     parse_json_object_from_request,
     parse_string,
 )
@@ -106,5 +107,54 @@ class RelationSendServlet(RestServlet):
         defer.returnValue((200, {"event_id": event.event_id}))
 
 
+class RelationPaginationServlet(RestServlet):
+    """API to paginate relations on an event by topological ordering, optionally
+    filtered by relation type and event type.
+    """
+
+    PATTERNS = client_v2_patterns(
+        "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
+        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+        releases=(),
+    )
+
+    def __init__(self, hs):
+        super(RelationPaginationServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+        yield self.auth.check_in_room_or_world_readable(
+            room_id, requester.user.to_string()
+        )
+
+        limit = parse_integer(request, "limit", default=5)
+
+        result = yield self.store.get_relations_for_event(
+            event_id=parent_id,
+            relation_type=relation_type,
+            event_type=event_type,
+            limit=limit,
+        )
+
+        events = yield self.store.get_events_as_list(
+            [c["event_id"] for c in result.chunk]
+        )
+
+        now = self.clock.time_msec()
+        events = yield self._event_serializer.serialize_events(events, now)
+
+        return_value = result.to_dict()
+        return_value["chunk"] = events
+
+        defer.returnValue((200, return_value))
+
+
 def register_servlets(hs, http_server):
     RelationSendServlet(hs).register(http_server)
+    RelationPaginationServlet(hs).register(http_server)
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index a4905162e0..f304b8a9a4 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -15,13 +15,93 @@
 
 import logging
 
+import attr
+
 from synapse.api.constants import RelationTypes
 from synapse.storage._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class PaginationChunk(object):
+    """Returned by relation pagination APIs.
+
+    Attributes:
+        chunk (list): The rows returned by pagination
+    """
+
+    chunk = attr.ib()
+
+    def to_dict(self):
+        d = {"chunk": self.chunk}
+
+        return d
+
+
 class RelationsStore(SQLBaseStore):
+    def get_relations_for_event(
+        self, event_id, relation_type=None, event_type=None, limit=5, direction="b"
+    ):
+        """Get a list of relations for an event, ordered by topological ordering.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            relation_type (str|None): Only fetch events with this relation
+                type, if given.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            limit (int): Only fetch the most recent `limit` events.
+            direction (str): Whether to fetch the most recent first (`"b"`) or
+                the oldest first (`"f"`).
+
+        Returns:
+            Deferred[PaginationChunk]: List of event IDs that match relations
+            requested. The rows are of the form `{"event_id": "..."}`.
+        """
+
+        # TODO: Pagination tokens
+
+        where_clause = ["relates_to_id = ?"]
+        where_args = [event_id]
+
+        if relation_type:
+            where_clause.append("relation_type = ?")
+            where_args.append(relation_type)
+
+        if event_type:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        order = "ASC"
+        if direction == "b":
+            order = "DESC"
+
+        sql = """
+            SELECT event_id FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE %s
+            ORDER BY topological_ordering %s, stream_ordering %s
+            LIMIT ?
+        """ % (
+            " AND ".join(where_clause),
+            order,
+            order,
+        )
+
+        def _get_recent_references_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            events = [{"event_id": row[0]} for row in txn]
+
+            return PaginationChunk(
+                chunk=list(events[:limit]),
+            )
+
+        return self.runInteraction(
+            "get_recent_references_for_event", _get_recent_references_for_event_txn
+        )
+
     def _handle_event_relations(self, txn, event):
         """Handles inserting relation data during peristence of events
 
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 61163d5b26..bcc1c1bb85 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -72,6 +72,36 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_paginate(self):
+        """Tests that calling pagination API corectly the latest relations.
+        """
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_id = channel.json_body["event_id"]
+
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+            % (self.room, self.parent_id),
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the full
+        # relation event we sent above.
+        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
     def _send_relation(self, relation_type, event_type, key=None):
         """Helper function to send a relation pointing at `self.parent_id`