summary refs log tree commit diff
path: root/synapse/storage/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/stream.py')
-rw-r--r--synapse/storage/stream.py223
1 files changed, 169 insertions, 54 deletions
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index fb463c525a..0d32a3a498 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -41,6 +41,7 @@ from synapse.storage.events import EventsWorkerStore
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.storage.chunk_ordered_table import ChunkDBOrderedListStore
 from synapse.storage.engines import PostgresEngine
 
 import abc
@@ -62,24 +63,25 @@ _TOPOLOGICAL_TOKEN = "topological"
 
 # Used as return values for pagination APIs
 _EventDictReturn = namedtuple("_EventDictReturn", (
-    "event_id", "topological_ordering", "stream_ordering",
+    "event_id", "chunk_id", "topological_ordering", "stream_ordering",
 ))
 
 
 def lower_bound(token, engine, inclusive=False):
     inclusive = "=" if inclusive else ""
-    if token.topological is None:
+    if token.chunk is None:
         return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
     else:
         if isinstance(engine, PostgresEngine):
             # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
             # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
             # use the later form when running against postgres.
-            return "((%d,%d) <%s (%s,%s))" % (
-                token.topological, token.stream, inclusive,
+            return "(chunk_id = %d AND (%d,%d) <%s (%s,%s))" % (
+                token.chunk, token.topological, token.stream, inclusive,
                 "topological_ordering", "stream_ordering",
             )
-        return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
+        return "(chunk_id = %d AND (%d < %s OR (%d = %s AND %d <%s %s)))" % (
+            token.chunk,
             token.topological, "topological_ordering",
             token.topological, "topological_ordering",
             token.stream, inclusive, "stream_ordering",
@@ -88,18 +90,19 @@ def lower_bound(token, engine, inclusive=False):
 
 def upper_bound(token, engine, inclusive=True):
     inclusive = "=" if inclusive else ""
-    if token.topological is None:
+    if token.chunk is None:
         return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
     else:
         if isinstance(engine, PostgresEngine):
             # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
             # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
             # use the later form when running against postgres.
-            return "((%d,%d) >%s (%s,%s))" % (
-                token.topological, token.stream, inclusive,
+            return "(chunk_id = %d AND (%d,%d) >%s (%s,%s))" % (
+                token.chunk, token.topological, token.stream, inclusive,
                 "topological_ordering", "stream_ordering",
             )
-        return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
+        return "(chunk_id = %d AND (%d > %s OR (%d = %s AND %d >%s %s)))" % (
+            token.chunk,
             token.topological, "topological_ordering",
             token.topological, "topological_ordering",
             token.stream, inclusive, "stream_ordering",
@@ -275,7 +278,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ) % (order,)
             txn.execute(sql, (room_id, from_id, to_id, limit))
 
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            rows = [_EventDictReturn(row[0], None, None, row[1]) for row in txn]
             return rows
 
         rows = yield self.runInteraction("get_room_events_stream_for_room", f)
@@ -325,7 +328,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             )
             txn.execute(sql, (user_id, from_id, to_id,))
 
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            rows = [_EventDictReturn(row[0], None, None, row[1]) for row in txn]
 
             return rows
 
@@ -392,7 +395,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         end_token = RoomStreamToken.parse(end_token)
 
-        rows, token = yield self.runInteraction(
+        rows, token, _ = yield self.runInteraction(
             "get_recent_event_ids_for_room", self._paginate_room_events_txn,
             room_id, from_token=end_token, limit=limit,
         )
@@ -437,15 +440,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         `room_id` causes it to return the current room specific topological
         token.
         """
-        token = yield self.get_room_max_stream_ordering()
         if room_id is None:
-            defer.returnValue("s%d" % (token,))
+            token = yield self.get_room_max_stream_ordering()
+            defer.returnValue(str(RoomStreamToken(None, None, token)))
         else:
-            topo = yield self.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn,
+            token = yield self.runInteraction(
+                "get_room_events_max_id", self._get_topological_token_for_room_txn,
                 room_id,
             )
-            defer.returnValue("t%d-%d" % (topo, token))
+            if not token:
+                raise Exception("Server not in room")
+            defer.returnValue(str(token))
 
     def get_stream_token_for_event(self, event_id):
         """The stream token for an event
@@ -460,7 +465,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             table="events",
             keyvalues={"event_id": event_id},
             retcol="stream_ordering",
-        ).addCallback(lambda row: "s%d" % (row,))
+        ).addCallback(lambda row: str(RoomStreamToken(None, None, row)))
 
     def get_topological_token_for_event(self, event_id):
         """The stream token for an event
@@ -469,16 +474,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A deferred "t%d-%d" topological token.
+            A deferred topological token.
         """
         return self._simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
-            retcols=("stream_ordering", "topological_ordering"),
+            retcols=("stream_ordering", "topological_ordering", "chunk_id"),
             desc="get_topological_token_for_event",
-        ).addCallback(lambda row: "t%d-%d" % (
-            row["topological_ordering"], row["stream_ordering"],)
-        )
+        ).addCallback(lambda row: str(RoomStreamToken(
+            row["chunk_id"],
+            row["topological_ordering"],
+            row["stream_ordering"],
+        )))
+
+    def _get_topological_token_for_room_txn(self, txn, room_id):
+        sql = """
+            SELECT chunk_id, topological_ordering, stream_ordering
+            FROM events
+            NATURAL JOIN event_forward_extremities
+            WHERE room_id = ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
+        """
+        txn.execute(sql, (room_id,))
+        row = txn.fetchone()
+        if row:
+            c, t, s = row
+            return RoomStreamToken(c, t, s)
+        return None
 
     def get_max_topological_token(self, room_id, stream_key):
         sql = (
@@ -515,18 +538,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 null topological_ordering.
         """
         for event, row in zip(events, rows):
+            chunk = row.chunk_id
+            topo = row.topological_ordering
             stream = row.stream_ordering
-            if topo_order and row.topological_ordering:
-                topo = row.topological_ordering
-            else:
-                topo = None
+
             internal = event.internal_metadata
-            internal.before = str(RoomStreamToken(topo, stream - 1))
-            internal.after = str(RoomStreamToken(topo, stream))
-            internal.order = (
-                int(topo) if topo else 0,
-                int(stream),
-            )
+
+            internal.stream_ordering = stream
+
+            if topo_order:
+                internal.before = str(RoomStreamToken(chunk, topo, stream - 1))
+                internal.after = str(RoomStreamToken(chunk, topo, stream))
+            else:
+                internal.before = str(RoomStreamToken(None, None, stream - 1))
+                internal.after = str(RoomStreamToken(None, None, stream))
 
     @defer.inlineCallbacks
     def get_events_around(self, room_id, event_id, before_limit, after_limit):
@@ -586,27 +611,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 "event_id": event_id,
                 "room_id": room_id,
             },
-            retcols=["stream_ordering", "topological_ordering"],
+            retcols=["stream_ordering", "topological_ordering", "chunk_id"],
         )
 
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
-            results["topological_ordering"] - 1,
-            results["stream_ordering"],
+            results["chunk_id"],
+            results["topological_ordering"],
+            results["stream_ordering"] - 1,
         )
 
         after_token = RoomStreamToken(
+            results["chunk_id"],
             results["topological_ordering"],
             results["stream_ordering"],
         )
 
-        rows, start_token = self._paginate_room_events_txn(
+        rows, start_token, _ = self._paginate_room_events_txn(
             txn, room_id, before_token, direction='b', limit=before_limit,
         )
         events_before = [r.event_id for r in rows]
 
-        rows, end_token = self._paginate_room_events_txn(
+        rows, end_token, _ = self._paginate_room_events_txn(
             txn, room_id, after_token, direction='f', limit=after_limit,
         )
         events_after = [r.event_id for r in rows]
@@ -689,12 +716,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 those that match the filter.
 
         Returns:
-            Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
-            as a list of _EventDictReturn and a token that points to the end
-            of the result set.
+            Deferred[tuple[list[_EventDictReturn], str, list[int]]: Returns
+            the results as a list of _EventDictReturn, a token that points to
+            the end of the result set, and a list of chunks iterated over.
         """
 
-        assert int(limit) >= 0
+        limit = int(limit)  # Sometimes we are passed a string from somewhere
+        assert limit >= 0
+
+        # There are two modes of fetching events: by stream order or by
+        # topological order. This is determined by whether the from_token is a
+        # stream or topological token. If stream then we can simply do a select
+        # ordered by stream_ordering column. If topological, then we need to
+        # fetch events from one chunk at a time until we hit the limit.
 
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
@@ -725,10 +759,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
-        args.append(int(limit))
+        args.append(limit)
 
         sql = (
-            "SELECT event_id, topological_ordering, stream_ordering"
+            "SELECT event_id, chunk_id, topological_ordering, stream_ordering"
             " FROM events"
             " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
             " ORDER BY topological_ordering %(order)s,"
@@ -740,9 +774,65 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         txn.execute(sql, args)
 
-        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+        rows = [_EventDictReturn(*row) for row in txn]
+
+        # If we are paginating topologically and we haven't hit the limit on
+        # number of events then we need to fetch events from the previous or
+        # next chunk.
+
+        iterated_chunks = []
+
+        chunk_id = None
+        if from_token.chunk:  # FIXME: may be topological but no chunk.
+            if rows:
+                chunk_id = rows[-1].chunk_id
+                iterated_chunks = [r.chunk_id for r in rows]
+            else:
+                chunk_id = from_token.chunk
+                iterated_chunks = [chunk_id]
+
+        table = ChunkDBOrderedListStore(
+            txn, room_id, self.clock,
+        )
+
+        if filter_clause:
+            filter_clause = "AND " + filter_clause
+
+        sql = (
+            "SELECT event_id, chunk_id, topological_ordering, stream_ordering"
+            " FROM events"
+            " WHERE outlier = ? AND room_id = ? %(filter_clause)s"
+            " ORDER BY topological_ordering %(order)s,"
+            " stream_ordering %(order)s LIMIT ?"
+        ) % {
+            "filter_clause": filter_clause,
+            "order": order,
+        }
+
+        args = [False, room_id] + filter_args + [limit]
+
+        while chunk_id and (limit <= 0 or len(rows) < limit):
+            if chunk_id not in iterated_chunks:
+                iterated_chunks.append(chunk_id)
+
+            if direction == 'b':
+                chunk_id = table.get_prev(chunk_id)
+            else:
+                chunk_id = table.get_next(chunk_id)
+
+            if chunk_id is None:
+                break
+
+            txn.execute(sql, args)
+            new_rows = [_EventDictReturn(*row) for row in txn]
+
+            rows.extend(new_rows)
+
+        # We may have inserted more rows than necessary in the loop above
+        rows = rows[:limit]
 
         if rows:
+            chunk = rows[-1].chunk_id
             topo = rows[-1].topological_ordering
             toke = rows[-1].stream_ordering
             if direction == 'b':
@@ -752,12 +842,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 # when we are going backwards so we subtract one from the
                 # stream part.
                 toke -= 1
-            next_token = RoomStreamToken(topo, toke)
+            next_token = RoomStreamToken(chunk, topo, toke)
         else:
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
 
-        return rows, str(next_token),
+        return rows, str(next_token), iterated_chunks,
 
     @defer.inlineCallbacks
     def paginate_room_events(self, room_id, from_key, to_key=None,
@@ -777,18 +867,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 those that match the filter.
 
         Returns:
-            tuple[list[dict], str]: Returns the results as a list of dicts and
-            a token that points to the end of the result set. The dicts have
-            the keys "event_id", "topological_ordering" and "stream_orderign".
+            tuple[list[dict], str, list[str]]: Returns the results as a list of
+            dicts, a token that points to the end of the result set, and a list
+            of backwards extremities. The dicts have the keys "event_id",
+            "topological_ordering" and "stream_ordering".
         """
 
         from_key = RoomStreamToken.parse(from_key)
         if to_key:
             to_key = RoomStreamToken.parse(to_key)
 
-        rows, token = yield self.runInteraction(
-            "paginate_room_events", self._paginate_room_events_txn,
-            room_id, from_key, to_key, direction, limit, event_filter,
+        def _do_paginate_room_events(txn):
+            rows, token, chunks = self._paginate_room_events_txn(
+                txn, room_id, from_key, to_key, direction, limit, event_filter,
+            )
+
+            # We now fetch the extremities by fetching the extremities for
+            # each chunk we iterated over.
+            extremities = []
+            seen = set()
+            for chunk_id in chunks:
+                if chunk_id in seen:
+                    continue
+                seen.add(chunk_id)
+
+                event_ids = self._simple_select_onecol_txn(
+                    txn,
+                    table="chunk_backwards_extremities",
+                    keyvalues={"chunk_id": chunk_id},
+                    retcol="event_id"
+                )
+
+                extremities.extend(e for e in event_ids if e not in extremities)
+
+            return rows, token, extremities
+
+        rows, token, extremities = yield self.runInteraction(
+            "paginate_room_events", _do_paginate_room_events,
         )
 
         events = yield self._get_events(
@@ -798,7 +913,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         self._set_before_and_after(events, rows)
 
-        defer.returnValue((events, token))
+        defer.returnValue((events, token, extremities))
 
 
 class StreamStore(StreamWorkerStore):