summary refs log tree commit diff
path: root/synapse/storage/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/stream.py')
-rw-r--r--synapse/storage/stream.py198
1 files changed, 173 insertions, 25 deletions
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 02b1913e26..7f4a827528 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -39,7 +39,7 @@ from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
-from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn
 
 import logging
 
@@ -77,7 +77,6 @@ def upper_bound(token):
 
 
 class StreamStore(SQLBaseStore):
-
     @defer.inlineCallbacks
     def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
         # NB this lives here instead of appservice.py so we can reuse the
@@ -157,7 +156,153 @@ class StreamStore(SQLBaseStore):
         results = yield self.runInteraction("get_appservice_room_stream", f)
         defer.returnValue(results)
 
-    @log_function
+    @defer.inlineCallbacks
+    def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
+                                         order='DESC'):
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+
+        room_ids = yield self._events_stream_cache.get_entities_changed(
+            room_ids, from_id
+        )
+
+        if not room_ids:
+            defer.returnValue({})
+
+        results = {}
+        room_ids = list(room_ids)
+        for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
+            res = yield defer.gatherResults([
+                preserve_fn(self.get_room_events_stream_for_room)(
+                    room_id, from_key, to_key, limit, order=order,
+                )
+                for room_id in rm_ids
+            ])
+            results.update(dict(zip(rm_ids, res)))
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
+    def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
+                                        order='DESC'):
+        # Note: If from_key is None then we return in topological order. This
+        # is because in that case we're using this as a "get the last few messages
+        # in a room" function, rather than "get new messages since last sync"
+        if from_key is not None:
+            from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        else:
+            from_id = None
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+        if from_key == to_key:
+            defer.returnValue(([], from_key))
+
+        if from_id:
+            has_changed = yield self._events_stream_cache.has_entity_changed(
+                room_id, from_id
+            )
+
+            if not has_changed:
+                defer.returnValue(([], from_key))
+
+        def f(txn):
+            if from_id is not None:
+                sql = (
+                    "SELECT event_id, stream_ordering FROM events WHERE"
+                    " room_id = ?"
+                    " AND not outlier"
+                    " AND stream_ordering > ? AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering %s LIMIT ?"
+                ) % (order,)
+                txn.execute(sql, (room_id, from_id, to_id, limit))
+            else:
+                sql = (
+                    "SELECT event_id, stream_ordering FROM events WHERE"
+                    " room_id = ?"
+                    " AND not outlier"
+                    " AND stream_ordering <= ?"
+                    " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?"
+                ) % (order, order,)
+                txn.execute(sql, (room_id, to_id, limit))
+
+            rows = self.cursor_to_dict(txn)
+
+            return rows
+
+        rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+
+        ret = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+
+        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+
+        if order.lower() == "desc":
+            ret.reverse()
+
+        if rows:
+            key = "s%d" % min(r["stream_ordering"] for r in rows)
+        else:
+            # Assume we didn't get anything because there was nothing to
+            # get.
+            key = from_key
+
+        defer.returnValue((ret, key))
+
+    @defer.inlineCallbacks
+    def get_membership_changes_for_user(self, user_id, from_key, to_key):
+        if from_key is not None:
+            from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        else:
+            from_id = None
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+        if from_key == to_key:
+            defer.returnValue([])
+
+        if from_id:
+            has_changed = self._membership_stream_cache.has_entity_changed(
+                user_id, int(from_id)
+            )
+            if not has_changed:
+                defer.returnValue([])
+
+        def f(txn):
+            if from_id is not None:
+                sql = (
+                    "SELECT m.event_id, stream_ordering FROM events AS e,"
+                    " room_memberships AS m"
+                    " WHERE e.event_id = m.event_id"
+                    " AND m.user_id = ?"
+                    " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+                    " ORDER BY e.stream_ordering ASC"
+                )
+                txn.execute(sql, (user_id, from_id, to_id,))
+            else:
+                sql = (
+                    "SELECT m.event_id, stream_ordering FROM events AS e,"
+                    " room_memberships AS m"
+                    " WHERE e.event_id = m.event_id"
+                    " AND m.user_id = ?"
+                    " AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering ASC"
+                )
+                txn.execute(sql, (user_id, to_id,))
+            rows = self.cursor_to_dict(txn)
+
+            return rows
+
+        rows = yield self.runInteraction("get_membership_changes_for_user", f)
+
+        ret = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+
+        self._set_before_and_after(ret, rows, topo_order=False)
+
+        defer.returnValue(ret)
+
     def get_room_events_stream(
         self,
         user_id,
@@ -174,7 +319,8 @@ class StreamStore(SQLBaseStore):
                 "SELECT c.room_id FROM history_visibility AS h"
                 " INNER JOIN current_state_events AS c"
                 " ON h.event_id = c.event_id"
-                " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+                " WHERE c.room_id IN (%s)"
+                " AND h.history_visibility = 'world_readable'" % (
                     ",".join(map(lambda _: "?", room_ids))
                 )
             )
@@ -187,11 +333,6 @@ class StreamStore(SQLBaseStore):
                 " WHERE m.user_id = ? AND m.membership = 'join'"
             )
             current_room_membership_args = [user_id]
-            if room_ids:
-                current_room_membership_sql += " AND m.room_id in (%s)" % (
-                    ",".join(map(lambda _: "?", room_ids))
-                )
-                current_room_membership_args = [user_id] + room_ids
 
         # We also want to get any membership events about that user, e.g.
         # invites or leave notifications.
@@ -393,7 +534,7 @@ class StreamStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, direction='f'):
-        token = yield self._stream_id_gen.get_max_token(self)
+        token = yield self._stream_id_gen.get_max_token()
         if direction != 'b':
             defer.returnValue("s%d" % (token,))
         else:
@@ -430,10 +571,23 @@ class StreamStore(SQLBaseStore):
             table="events",
             keyvalues={"event_id": event_id},
             retcols=("stream_ordering", "topological_ordering"),
+            desc="get_topological_token_for_event",
         ).addCallback(lambda row: "t%d-%d" % (
             row["topological_ordering"], row["stream_ordering"],)
         )
 
+    def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
+        sql = (
+            "SELECT max(topological_ordering) FROM events"
+            " WHERE room_id = ? AND stream_ordering < ?"
+        )
+        return self._execute(
+            "get_max_topological_token_for_stream_and_room", None,
+            sql, room_id, stream_key,
+        ).addCallback(
+            lambda r: r[0][0] if r else 0
+        )
+
     def _get_max_topological_txn(self, txn):
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events"
@@ -444,27 +598,21 @@ class StreamStore(SQLBaseStore):
         rows = txn.fetchall()
         return rows[0][0] if rows else 0
 
-    @defer.inlineCallbacks
-    def _get_min_token(self):
-        row = yield self._execute(
-            "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
-        )
-
-        self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
-        self.min_token = min(self.min_token, -1)
-
-        logger.debug("min_token is: %s", self.min_token)
-
-        defer.returnValue(self.min_token)
-
     @staticmethod
-    def _set_before_and_after(events, rows):
+    def _set_before_and_after(events, rows, topo_order=True):
         for event, row in zip(events, rows):
             stream = row["stream_ordering"]
-            topo = event.depth
+            if topo_order:
+                topo = event.depth
+            else:
+                topo = None
             internal = event.internal_metadata
             internal.before = str(RoomStreamToken(topo, stream - 1))
             internal.after = str(RoomStreamToken(topo, stream))
+            internal.order = (
+                int(topo) if topo else 0,
+                int(stream),
+            )
 
     @defer.inlineCallbacks
     def get_events_around(self, room_id, event_id, before_limit, after_limit):