summary refs log tree commit diff
path: root/synapse/storage/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/stream.py')
-rw-r--r--synapse/storage/stream.py284
1 files changed, 151 insertions, 133 deletions
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 580fafeb3a..9cd1e0f9fe 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -59,9 +59,9 @@ _TOPOLOGICAL_TOKEN = "topological"
 
 
 # Used as return values for pagination APIs
-_EventDictReturn = namedtuple("_EventDictReturn", (
-    "event_id", "topological_ordering", "stream_ordering",
-))
+_EventDictReturn = namedtuple(
+    "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
+)
 
 
 def lower_bound(token, engine, inclusive=False):
@@ -74,13 +74,20 @@ def lower_bound(token, engine, inclusive=False):
             # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
             # use the later form when running against postgres.
             return "((%d,%d) <%s (%s,%s))" % (
-                token.topological, token.stream, inclusive,
-                "topological_ordering", "stream_ordering",
+                token.topological,
+                token.stream,
+                inclusive,
+                "topological_ordering",
+                "stream_ordering",
             )
         return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
-            token.topological, "topological_ordering",
-            token.topological, "topological_ordering",
-            token.stream, inclusive, "stream_ordering",
+            token.topological,
+            "topological_ordering",
+            token.topological,
+            "topological_ordering",
+            token.stream,
+            inclusive,
+            "stream_ordering",
         )
 
 
@@ -94,13 +101,20 @@ def upper_bound(token, engine, inclusive=True):
             # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
             # use the later form when running against postgres.
             return "((%d,%d) >%s (%s,%s))" % (
-                token.topological, token.stream, inclusive,
-                "topological_ordering", "stream_ordering",
+                token.topological,
+                token.stream,
+                inclusive,
+                "topological_ordering",
+                "stream_ordering",
             )
         return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
-            token.topological, "topological_ordering",
-            token.topological, "topological_ordering",
-            token.stream, inclusive, "stream_ordering",
+            token.topological,
+            "topological_ordering",
+            token.topological,
+            "topological_ordering",
+            token.stream,
+            inclusive,
+            "stream_ordering",
         )
 
 
@@ -116,9 +130,7 @@ def filter_to_clause(event_filter):
     args = []
 
     if event_filter.types:
-        clauses.append(
-            "(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
-        )
+        clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
         args.extend(event_filter.types)
 
     for typ in event_filter.not_types:
@@ -126,9 +138,7 @@ def filter_to_clause(event_filter):
         args.append(typ)
 
     if event_filter.senders:
-        clauses.append(
-            "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
-        )
+        clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
         args.extend(event_filter.senders)
 
     for sender in event_filter.not_senders:
@@ -136,9 +146,7 @@ def filter_to_clause(event_filter):
         args.append(sender)
 
     if event_filter.rooms:
-        clauses.append(
-            "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
-        )
+        clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
         args.extend(event_filter.rooms)
 
     for room_id in event_filter.not_rooms:
@@ -165,17 +173,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         events_max = self.get_room_max_stream_ordering()
         event_cache_prefill, min_event_val = self._get_cache_dict(
-            db_conn, "events",
+            db_conn,
+            "events",
             entity_column="room_id",
             stream_column="stream_ordering",
             max_value=events_max,
         )
         self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache", min_event_val,
+            "EventsRoomStreamChangeCache",
+            min_event_val,
             prefilled_cache=event_cache_prefill,
         )
         self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max,
+            "MembershipStreamChangeCache", events_max
         )
 
         self._stream_order_on_start = self.get_room_max_stream_ordering()
@@ -189,8 +199,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         raise NotImplementedError()
 
     @defer.inlineCallbacks
-    def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
-                                         order='DESC'):
+    def get_room_events_stream_for_rooms(
+        self, room_ids, from_key, to_key, limit=0, order='DESC'
+    ):
         """Get new room events in stream ordering since `from_key`.
 
         Args:
@@ -221,14 +232,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         results = {}
         room_ids = list(room_ids)
-        for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
-            res = yield make_deferred_yieldable(defer.gatherResults([
-                run_in_background(
-                    self.get_room_events_stream_for_room,
-                    room_id, from_key, to_key, limit, order=order,
+        for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
+            res = yield make_deferred_yieldable(
+                defer.gatherResults(
+                    [
+                        run_in_background(
+                            self.get_room_events_stream_for_room,
+                            room_id,
+                            from_key,
+                            to_key,
+                            limit,
+                            order=order,
+                        )
+                        for room_id in rm_ids
+                    ],
+                    consumeErrors=True,
                 )
-                for room_id in rm_ids
-            ], consumeErrors=True))
+            )
             results.update(dict(zip(rm_ids, res)))
 
         defer.returnValue(results)
@@ -243,13 +263,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         from_key = RoomStreamToken.parse_stream_token(from_key).stream
         return set(
-            room_id for room_id in room_ids
+            room_id
+            for room_id in room_ids
             if self._events_stream_cache.has_entity_changed(room_id, from_key)
         )
 
     @defer.inlineCallbacks
-    def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
-                                        order='DESC'):
+    def get_room_events_stream_for_room(
+        self, room_id, from_key, to_key, limit=0, order='DESC'
+    ):
 
         """Get new room events in stream ordering since `from_key`.
 
@@ -297,10 +319,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows = yield self.runInteraction("get_room_events_stream_for_room", f)
 
-        ret = yield self._get_events(
-            [r.event_id for r in rows],
-            get_prev_content=True
-        )
+        ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
 
         self._set_before_and_after(ret, rows, topo_order=from_id is None)
 
@@ -340,7 +359,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
                 " ORDER BY e.stream_ordering ASC"
             )
-            txn.execute(sql, (user_id, from_id, to_id,))
+            txn.execute(sql, (user_id, from_id, to_id))
 
             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
 
@@ -348,10 +367,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows = yield self.runInteraction("get_membership_changes_for_user", f)
 
-        ret = yield self._get_events(
-            [r.event_id for r in rows],
-            get_prev_content=True
-        )
+        ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
 
         self._set_before_and_after(ret, rows, topo_order=False)
 
@@ -374,13 +390,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
 
         rows, token = yield self.get_recent_event_ids_for_room(
-            room_id, limit, end_token,
+            room_id, limit, end_token
         )
 
         logger.debug("stream before")
         events = yield self._get_events(
-            [r.event_id for r in rows],
-            get_prev_content=True
+            [r.event_id for r in rows], get_prev_content=True
         )
         logger.debug("stream after")
 
@@ -410,8 +425,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         end_token = RoomStreamToken.parse(end_token)
 
         rows, token = yield self.runInteraction(
-            "get_recent_event_ids_for_room", self._paginate_room_events_txn,
-            room_id, from_token=end_token, limit=limit,
+            "get_recent_event_ids_for_room",
+            self._paginate_room_events_txn,
+            room_id,
+            from_token=end_token,
+            limit=limit,
         )
 
         # We want to return the results in ascending order.
@@ -430,6 +448,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             Deferred[(int, int, str)]:
                 (stream ordering, topological ordering, event_id)
         """
+
         def _f(txn):
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
@@ -439,12 +458,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 " ORDER BY stream_ordering"
                 " LIMIT 1"
             )
-            txn.execute(sql, (room_id, stream_ordering, ))
+            txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()
 
-        return self.runInteraction(
-            "get_room_event_after_stream_ordering", _f,
-        )
+        return self.runInteraction("get_room_event_after_stream_ordering", _f)
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, room_id=None):
@@ -459,8 +476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             defer.returnValue("s%d" % (token,))
         else:
             topo = yield self.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn,
-                room_id,
+                "_get_max_topological_txn", self._get_max_topological_txn, room_id
             )
             defer.returnValue("t%d-%d" % (topo, token))
 
@@ -474,9 +490,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             A deferred "s%d" stream token.
         """
         return self._simple_select_one_onecol(
-            table="events",
-            keyvalues={"event_id": event_id},
-            retcol="stream_ordering",
+            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
         ).addCallback(lambda row: "s%d" % (row,))
 
     def get_topological_token_for_event(self, event_id):
@@ -493,8 +507,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             keyvalues={"event_id": event_id},
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
-        ).addCallback(lambda row: "t%d-%d" % (
-            row["topological_ordering"], row["stream_ordering"],)
+        ).addCallback(
+            lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
         )
 
     def get_max_topological_token(self, room_id, stream_key):
@@ -503,17 +517,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             " WHERE room_id = ? AND stream_ordering < ?"
         )
         return self._execute(
-            "get_max_topological_token", None,
-            sql, room_id, stream_key,
-        ).addCallback(
-            lambda r: r[0][0] if r else 0
-        )
+            "get_max_topological_token", None, sql, room_id, stream_key
+        ).addCallback(lambda r: r[0][0] if r else 0)
 
     def _get_max_topological_txn(self, txn, room_id):
         txn.execute(
-            "SELECT MAX(topological_ordering) FROM events"
-            " WHERE room_id = ?",
-            (room_id,)
+            "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+            (room_id,),
         )
 
         rows = txn.fetchall()
@@ -540,14 +550,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             internal = event.internal_metadata
             internal.before = str(RoomStreamToken(topo, stream - 1))
             internal.after = str(RoomStreamToken(topo, stream))
-            internal.order = (
-                int(topo) if topo else 0,
-                int(stream),
-            )
+            internal.order = (int(topo) if topo else 0, int(stream))
 
     @defer.inlineCallbacks
     def get_events_around(
-        self, room_id, event_id, before_limit, after_limit, event_filter=None,
+        self, room_id, event_id, before_limit, after_limit, event_filter=None
     ):
         """Retrieve events and pagination tokens around a given event in a
         room.
@@ -564,29 +571,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
 
         results = yield self.runInteraction(
-            "get_events_around", self._get_events_around_txn,
-            room_id, event_id, before_limit, after_limit, event_filter,
+            "get_events_around",
+            self._get_events_around_txn,
+            room_id,
+            event_id,
+            before_limit,
+            after_limit,
+            event_filter,
         )
 
         events_before = yield self._get_events(
-            [e for e in results["before"]["event_ids"]],
-            get_prev_content=True
+            [e for e in results["before"]["event_ids"]], get_prev_content=True
         )
 
         events_after = yield self._get_events(
-            [e for e in results["after"]["event_ids"]],
-            get_prev_content=True
+            [e for e in results["after"]["event_ids"]], get_prev_content=True
         )
 
-        defer.returnValue({
-            "events_before": events_before,
-            "events_after": events_after,
-            "start": results["before"]["token"],
-            "end": results["after"]["token"],
-        })
+        defer.returnValue(
+            {
+                "events_before": events_before,
+                "events_after": events_after,
+                "start": results["before"]["token"],
+                "end": results["after"]["token"],
+            }
+        )
 
     def _get_events_around_txn(
-        self, txn, room_id, event_id, before_limit, after_limit, event_filter,
+        self, txn, room_id, event_id, before_limit, after_limit, event_filter
     ):
         """Retrieves event_ids and pagination tokens around a given event in a
         room.
@@ -605,46 +617,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         results = self._simple_select_one_txn(
             txn,
             "events",
-            keyvalues={
-                "event_id": event_id,
-                "room_id": room_id,
-            },
+            keyvalues={"event_id": event_id, "room_id": room_id},
             retcols=["stream_ordering", "topological_ordering"],
         )
 
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
-            results["topological_ordering"] - 1,
-            results["stream_ordering"],
+            results["topological_ordering"] - 1, results["stream_ordering"]
         )
 
         after_token = RoomStreamToken(
-            results["topological_ordering"],
-            results["stream_ordering"],
+            results["topological_ordering"], results["stream_ordering"]
         )
 
         rows, start_token = self._paginate_room_events_txn(
-            txn, room_id, before_token, direction='b', limit=before_limit,
+            txn,
+            room_id,
+            before_token,
+            direction='b',
+            limit=before_limit,
             event_filter=event_filter,
         )
         events_before = [r.event_id for r in rows]
 
         rows, end_token = self._paginate_room_events_txn(
-            txn, room_id, after_token, direction='f', limit=after_limit,
+            txn,
+            room_id,
+            after_token,
+            direction='f',
+            limit=after_limit,
             event_filter=event_filter,
         )
         events_after = [r.event_id for r in rows]
 
         return {
-            "before": {
-                "event_ids": events_before,
-                "token": start_token,
-            },
-            "after": {
-                "event_ids": events_after,
-                "token": end_token,
-            },
+            "before": {"event_ids": events_before, "token": start_token},
+            "after": {"event_ids": events_after, "token": end_token},
         }
 
     @defer.inlineCallbacks
@@ -685,7 +694,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             return upper_bound, [row[1] for row in rows]
 
         upper_bound, event_ids = yield self.runInteraction(
-            "get_all_new_events_stream", get_all_new_events_stream_txn,
+            "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
         events = yield self._get_events(event_ids)
@@ -697,7 +706,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             table="federation_stream_position",
             retcol="stream_id",
             keyvalues={"type": typ},
-            desc="get_federation_out_pos"
+            desc="get_federation_out_pos",
         )
 
     def update_federation_out_pos(self, typ, stream_id):
@@ -711,8 +720,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     def has_room_changed_since(self, room_id, stream_id):
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
 
-    def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
-                                  direction='b', limit=-1, event_filter=None):
+    def _paginate_room_events_txn(
+        self,
+        txn,
+        room_id,
+        from_token,
+        to_token=None,
+        direction='b',
+        limit=-1,
+        event_filter=None,
+    ):
         """Returns list of events before or after a given token.
 
         Args:
@@ -741,22 +758,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         args = [False, room_id]
         if direction == 'b':
             order = "DESC"
-            bounds = upper_bound(
-                from_token, self.database_engine
-            )
+            bounds = upper_bound(from_token, self.database_engine)
             if to_token:
-                bounds = "%s AND %s" % (bounds, lower_bound(
-                    to_token, self.database_engine
-                ))
+                bounds = "%s AND %s" % (
+                    bounds,
+                    lower_bound(to_token, self.database_engine),
+                )
         else:
             order = "ASC"
-            bounds = lower_bound(
-                from_token, self.database_engine
-            )
+            bounds = lower_bound(from_token, self.database_engine)
             if to_token:
-                bounds = "%s AND %s" % (bounds, upper_bound(
-                    to_token, self.database_engine
-                ))
+                bounds = "%s AND %s" % (
+                    bounds,
+                    upper_bound(to_token, self.database_engine),
+                )
 
         filter_clause, filter_args = filter_to_clause(event_filter)
 
@@ -772,10 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
             " ORDER BY topological_ordering %(order)s,"
             " stream_ordering %(order)s LIMIT ?"
-        ) % {
-            "bounds": bounds,
-            "order": order,
-        }
+        ) % {"bounds": bounds, "order": order}
 
         txn.execute(sql, args)
 
@@ -796,11 +808,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
 
-        return rows, str(next_token),
+        return rows, str(next_token)
 
     @defer.inlineCallbacks
-    def paginate_room_events(self, room_id, from_key, to_key=None,
-                             direction='b', limit=-1, event_filter=None):
+    def paginate_room_events(
+        self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
+    ):
         """Returns list of events before or after a given token.
 
         Args:
@@ -826,13 +839,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             to_key = RoomStreamToken.parse(to_key)
 
         rows, token = yield self.runInteraction(
-            "paginate_room_events", self._paginate_room_events_txn,
-            room_id, from_key, to_key, direction, limit, event_filter,
+            "paginate_room_events",
+            self._paginate_room_events_txn,
+            room_id,
+            from_key,
+            to_key,
+            direction,
+            limit,
+            event_filter,
         )
 
         events = yield self._get_events(
-            [r.event_id for r in rows],
-            get_prev_content=True
+            [r.event_id for r in rows], get_prev_content=True
         )
 
         self._set_before_and_after(events, rows)