summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py280
1 files changed, 243 insertions, 37 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index a94bec1ac5..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
     )
 
 
+def _filter_results(
+    lower_token: Optional[RoomStreamToken],
+    upper_token: Optional[RoomStreamToken],
+    instance_name: str,
+    topological_ordering: int,
+    stream_ordering: int,
+) -> bool:
+    """Returns True if the event persisted by the given instance at the given
+    topological/stream_ordering falls between the two tokens (taking a None
+    token to mean unbounded).
+
+    Used to filter results from fetching events in the DB against the given
+    tokens. This is necessary to handle the case where the tokens include
+    position maps, which we handle by fetching more than necessary from the DB
+    and then filtering (rather than attempting to construct a complicated SQL
+    query).
+    """
+
+    event_historical_tuple = (
+        topological_ordering,
+        stream_ordering,
+    )
+
+    if lower_token:
+        if lower_token.topological is not None:
+            # If these are historical tokens we compare the `(topological, stream)`
+            # tuples.
+            if event_historical_tuple <= lower_token.as_historical_tuple():
+                return False
+
+        else:
+            # If these are live tokens we compare the stream ordering against the
+            # writers stream position.
+            if stream_ordering <= lower_token.get_stream_pos_for_instance(
+                instance_name
+            ):
+                return False
+
+    if upper_token:
+        if upper_token.topological is not None:
+            if upper_token.as_historical_tuple() < event_historical_tuple:
+                return False
+        else:
+            if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+                return False
+
+    return True
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     def get_room_max_token(self) -> RoomStreamToken:
-        return RoomStreamToken(None, self.get_room_max_stream_ordering())
+        """Get a `RoomStreamToken` that marks the current maximum persisted
+        position of the events stream. Useful to get a token that represents
+        "now".
+
+        The token returned is a "live" token that may have an instance_map
+        component.
+        """
+
+        min_pos = self._stream_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._stream_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return RoomStreamToken(None, min_pos, positions)
 
     async def get_room_events_stream_for_rooms(
         self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if from_key == to_key:
             return [], from_key
 
-        from_id = from_key.stream
-        to_id = to_key.stream
-
-        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+        has_changed = self._events_stream_cache.has_entity_changed(
+            room_id, from_key.stream
+        )
 
         if not has_changed:
             return [], from_key
 
         def f(txn):
-            sql = (
-                "SELECT event_id, stream_ordering FROM events WHERE"
-                " room_id = ?"
-                " AND not outlier"
-                " AND stream_ordering > ? AND stream_ordering <= ?"
-                " ORDER BY stream_ordering %s LIMIT ?"
-            ) % (order,)
-            txn.execute(sql, (room_id, from_id, to_id, limit))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT event_id, instance_name, topological_ordering, stream_ordering
+                FROM events
+                WHERE
+                    room_id = ?
+                    AND not outlier
+                    AND stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering %s LIMIT ?
+            """ % (
+                order,
+            )
+            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ][:limit]
             return rows
 
         rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._set_before_and_after(ret, rows, topo_order=False)
 
         if order.lower() == "desc":
             ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = from_key.stream
-        to_id = to_key.stream
-
         if from_key == to_key:
             return []
 
-        if from_id:
+        if from_key:
             has_changed = self._membership_stream_cache.has_entity_changed(
-                user_id, int(from_id)
+                user_id, int(from_key.stream)
             )
             if not has_changed:
                 return []
 
         def f(txn):
-            sql = (
-                "SELECT m.event_id, stream_ordering FROM events AS e,"
-                " room_memberships AS m"
-                " WHERE e.event_id = m.event_id"
-                " AND m.user_id = ?"
-                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-            )
-            txn.execute(sql, (user_id, from_id, to_id))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+                FROM events AS e, room_memberships AS m
+                WHERE e.event_id = m.event_id
+                    AND m.user_id = ?
+                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                ORDER BY e.stream_ordering ASC
+            """
+            txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ]
 
             return rows
 
@@ -966,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         else:
             order = "ASC"
 
+        # The bounds for the stream tokens are complicated by the fact
+        # that we need to handle the instance_map part of the tokens. We do this
+        # by fetching all events between the min stream token and the maximum
+        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+        # then filtering the results.
+        if from_token.topological is not None:
+            from_bound = (
+                from_token.as_historical_tuple()
+            )  # type: Tuple[Optional[int], int]
+        elif direction == "b":
+            from_bound = (
+                None,
+                from_token.get_max_stream_pos(),
+            )
+        else:
+            from_bound = (
+                None,
+                from_token.stream,
+            )
+
+        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        if to_token:
+            if to_token.topological is not None:
+                to_bound = to_token.as_historical_tuple()
+            elif direction == "b":
+                to_bound = (
+                    None,
+                    to_token.stream,
+                )
+            else:
+                to_bound = (
+                    None,
+                    to_token.get_max_stream_pos(),
+                )
+
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.as_tuple(),
-            to_token=to_token.as_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
@@ -980,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
-        args.append(int(limit))
+        # We fetch more events as we'll filter the result set
+        args.append(int(limit) * 2)
 
         select_keywords = "SELECT"
         join_clause = ""
@@ -1002,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 select_keywords += "DISTINCT"
 
         sql = """
-            %(select_keywords)s event_id, topological_ordering, stream_ordering
+            %(select_keywords)s
+                event_id, instance_name,
+                topological_ordering, stream_ordering
             FROM events
             %(join_clause)s
             WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1017,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         txn.execute(sql, args)
 
-        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+        # Filter the result set.
+        rows = [
+            _EventDictReturn(event_id, topological_ordering, stream_ordering)
+            for event_id, instance_name, topological_ordering, stream_ordering in txn
+            if _filter_results(
+                lower_token=to_token if direction == "b" else from_token,
+                upper_token=from_token if direction == "b" else to_token,
+                instance_name=instance_name,
+                topological_ordering=topological_ordering,
+                stream_ordering=stream_ordering,
+            )
+        ][:limit]
 
         if rows:
             topo = rows[-1].topological_ordering
@@ -1082,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         return (events, token)
 
+    @cached()
+    async def get_id_for_instance(self, instance_name: str) -> int:
+        """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+        """
+
+        def _get_id_for_instance_txn(txn):
+            instance_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+                allow_none=True,
+            )
+            if instance_id is not None:
+                return instance_id
+
+            # If we don't have an entry upsert one.
+            #
+            # We could do this before the first check, and rely on the cache for
+            # efficiency, but each UPSERT causes the next ID to increment which
+            # can quickly bloat the size of the generated IDs for new instances.
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                values={},
+            )
+
+            return self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+            )
+
+        return await self.db_pool.runInteraction(
+            "get_id_for_instance", _get_id_for_instance_txn
+        )
+
+    @cached()
+    async def get_name_from_instance_id(self, instance_id: int) -> str:
+        """Get the instance name from an ID previously returned by
+        `get_id_for_instance`.
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="instance_map",
+            keyvalues={"instance_id": instance_id},
+            retcol="instance_name",
+            desc="get_name_from_instance_id",
+        )
+
 
 class StreamStore(StreamWorkerStore):
     def get_room_max_stream_ordering(self) -> int: