summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-05-09 11:18:23 +0100
committerErik Johnston <erik@matrix.org>2018-05-09 11:34:24 +0100
commit05e0a2462c76be6987c7ec3d9517d500583bac65 (patch)
tree61dffe12138bdf57aa29f7bc70403ea3dc62802e
parentRemove unused from_token param (diff)
downloadsynapse-05e0a2462c76be6987c7ec3d9517d500583bac65.tar.xz
Refactor pagination DB API to return concrete type
This makes it easier to document what is being returned by the storage
functions and what some functions expect as arguments.
-rw-r--r--synapse/storage/stream.py76
1 files changed, 48 insertions, 28 deletions
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index ecd39074b8..772d2c6198 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -47,6 +47,7 @@ import abc
 import logging
 
 from six.moves import range
+from collections import namedtuple
 
 
 logger = logging.getLogger(__name__)
@@ -59,6 +60,12 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple("_EventDictReturn", (
+    "event_id", "topological_ordering", "stream_ordering",
+))
+
+
 def lower_bound(token, engine, inclusive=False):
     inclusive = "=" if inclusive else ""
     if token.topological is None:
@@ -256,9 +263,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                     " ORDER BY stream_ordering %s LIMIT ?"
                 ) % (order,)
                 txn.execute(sql, (room_id, from_id, to_id, limit))
+
+                rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             else:
                 sql = (
-                    "SELECT event_id, stream_ordering FROM events WHERE"
+                    "SELECT event_id, topological_ordering, stream_ordering"
+                    " FROM events"
+                    " WHERE"
                     " room_id = ?"
                     " AND not outlier"
                     " AND stream_ordering <= ?"
@@ -266,14 +277,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 ) % (order, order,)
                 txn.execute(sql, (room_id, to_id, limit))
 
-            rows = self.cursor_to_dict(txn)
+                rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
 
             return rows
 
         rows = yield self.runInteraction("get_room_events_stream_for_room", f)
 
         ret = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
 
@@ -283,7 +294,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = "s%d" % min(r["stream_ordering"] for r in rows)
+            key = "s%d" % min(r.stream_ordering for r in rows)
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -330,14 +341,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                     " ORDER BY stream_ordering ASC"
                 )
                 txn.execute(sql, (user_id, to_id,))
-            rows = self.cursor_to_dict(txn)
+
+            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
 
             return rows
 
         rows = yield self.runInteraction("get_membership_changes_for_user", f)
 
         ret = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
 
@@ -353,14 +365,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         logger.debug("stream before")
         events = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
         logger.debug("stream after")
 
         self._set_before_and_after(events, rows)
 
-        defer.returnValue((events, token))
+        defer.returnValue((events, (token, end_token)))
 
     @defer.inlineCallbacks
     def get_recent_event_ids_for_room(self, room_id, limit, end_token):
@@ -372,15 +384,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             end_token (str): The stream token representing now.
 
         Returns:
-            Deferred[tuple[list[dict], tuple[str, str]]]: Returns a list of
-            dicts (which include event_ids, etc), and a tuple for
-            `(start_token, end_token)` representing the range of rows
-            returned.
-            The returned events are in ascending order.
+            Deferred[tuple[list[_EventDictReturn],  str]]: Returns a list of
+            _EventDictReturn and a token pointint to the start of the returned
+            events.
+            The events returned are in ascending order.
         """
         # Allow a zero limit here, and no-op.
         if limit == 0:
-            defer.returnValue(([], (end_token, end_token)))
+            defer.returnValue(([], end_token))
 
         end_token = RoomStreamToken.parse_stream_token(end_token)
 
@@ -392,7 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         # We want to return the results in ascending order.
         rows.reverse()
 
-        defer.returnValue((rows, (token, str(end_token))))
+        defer.returnValue((rows, token))
 
     def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
         """Gets details of the first event in a room at or after a stream ordering
@@ -496,10 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @staticmethod
     def _set_before_and_after(events, rows, topo_order=True):
+        """Inserts ordering information to events' internal metadata from
+        the DB rows.
+
+        Args:
+            events (list[FrozenEvent])
+            rows (list[_EventDictReturn])
+            topo_order (bool): Whether the events were ordered topologically
+                or by stream ordering
+        """
         for event, row in zip(events, rows):
-            stream = row["stream_ordering"]
-            if topo_order:
-                topo = event.depth
+            stream = row.stream_ordering
+            if topo_order and row.topological_ordering:
+                topo = row.topological_ordering
             else:
                 topo = None
             internal = event.internal_metadata
@@ -586,12 +606,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         rows, start_token = self._paginate_room_events_txn(
             txn, room_id, before_token, direction='b', limit=before_limit,
         )
-        events_before = [r["event_id"] for r in rows]
+        events_before = [r.event_id for r in rows]
 
         rows, end_token = self._paginate_room_events_txn(
             txn, room_id, after_token, direction='f', limit=after_limit,
         )
-        events_after = [r["event_id"] for r in rows]
+        events_after = [r.event_id for r in rows]
 
         return {
             "before": {
@@ -672,9 +692,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 those that match the filter.
 
         Returns:
-            tuple[list[dict], str]: Returns the results as a list of dicts and
-            a token that points to the end of the result set. The dicts have
-            the keys "event_id", "toplogical_ordering" and "stream_ordering".
+            tuple[list[_EventDictReturn], str]: Returns the results as a list
+            of _EventDictReturn and a token that points to the end of the
+            result set.
         """
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
@@ -725,11 +745,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         txn.execute(sql, args)
 
-        rows = self.cursor_to_dict(txn)
+        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
 
         if rows:
-            topo = rows[-1]["topological_ordering"]
-            toke = rows[-1]["stream_ordering"]
+            topo = rows[-1].topological_ordering
+            toke = rows[-1].stream_ordering
             if direction == 'b':
                 # Tokens are positions between events.
                 # This token points *after* the last event in the chunk.
@@ -764,7 +784,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             tuple[list[dict], str]: Returns the results as a list of dicts and
             a token that points to the end of the result set. The dicts have
-            the keys "event_id", "toplogical_ordering" and "stream_orderign".
+            the keys "event_id", "topological_ordering" and "stream_orderign".
         """
 
         from_key = RoomStreamToken.parse(from_key)
@@ -777,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         events = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )