summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/tcp/handler.py4
-rw-r--r--synapse/replication/tcp/streams/events.py22
-rw-r--r--synapse/server.pyi5
-rw-r--r--synapse/storage/data_stores/main/events_worker.py64
4 files changed, 76 insertions, 19 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 0db5a3a24d..3a8c7c7e2d 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -87,7 +87,9 @@ class ReplicationCommandHandler:
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]
 
-        self._position_linearizer = Linearizer("replication_position")
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index aa50492569..52df81b1bd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -170,22 +170,16 @@ class EventsStream(Stream):
             limited = False
             upper_limit = current_token
 
-        # next up is the state delta table
-
-        state_rows = await self._store.get_all_updated_current_state_deltas(
+        # next up is the state delta table.
+        (
+            state_rows,
+            upper_limit,
+            state_rows_limited,
+        ) = await self._store.get_all_updated_current_state_deltas(
             from_token, upper_limit, target_row_count
-        )  # type: List[Tuple]
-
-        # again, if we've hit the limit there, we'll need to limit the other sources
-        assert len(state_rows) < target_row_count
-        if len(state_rows) == target_row_count:
-            assert state_rows[-1][0] <= upper_limit
-            upper_limit = state_rows[-1][0]
-            limited = True
+        )
 
-            # FIXME: is it a given that there is only one row per stream_id in the
-            # state_deltas table (so that we can be sure that we have got all of the
-            # rows for upper_limit)?
+        limited = limited or state_rows_limited
 
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
diff --git a/synapse/server.pyi b/synapse/server.pyi
index f1a5717028..fc5886f762 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
 import synapse.state
 import synapse.storage
+from synapse.events.builder import EventBuilderFactory
 
 class HomeServer(object):
     @property
@@ -121,3 +122,7 @@ class HomeServer(object):
         pass
     def get_instance_id(self) -> str:
         pass
+    def get_event_builder_factory(self) -> EventBuilderFactory:
+        pass
+    def get_storage(self) -> synapse.storage.Storage:
+        pass
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ce8be72bfe..73df6b33ba 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,7 +19,7 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 from canonicaljson import json
 from constantly import NamedConstant, Names
@@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore):
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )
 
-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+    async def get_all_updated_current_state_deltas(
+        self, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple], int, bool]:
+        """Fetch updates from current_state_delta_stream
+
+        Args:
+            from_token: The previous stream token. Updates from this stream id will
+                be excluded.
+
+            to_token: The current stream token (ie the upper limit). Updates up to this
+                stream id will be included (modulo the 'limit' param)
+
+            target_row_count: The number of rows to try to return. If more rows are
+                available, we will set 'limited' in the result. In the event of a large
+                batch, we may return more rows than this.
+        Returns:
+            A triplet `(updates, new_last_token, limited)`, where:
+               * `updates` is a list of database tuples.
+               * `new_last_token` is the new position in stream.
+               * `limited` is whether there are more updates to fetch.
+        """
+
         def get_all_updated_current_state_deltas_txn(txn):
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
@@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, limit))
+            txn.execute(sql, (from_token, to_token, target_row_count))
             return txn.fetchall()
 
-        return self.db.runInteraction(
+        def get_deltas_for_stream_id_txn(txn, stream_id):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE stream_id = ?
+            """
+            txn.execute(sql, [stream_id])
+            return txn.fetchall()
+
+        # we need to make sure that, for every stream id in the results, we get *all*
+        # the rows with that stream id.
+
+        rows = await self.db.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
+        )  # type: List[Tuple]
+
+        # if we've got fewer rows than the limit, we're good
+        if len(rows) < target_row_count:
+            return rows, to_token, False
+
+        # we hit the limit, so reduce the upper limit so that we exclude the stream id
+        # of the last row in the result.
+        assert rows[-1][0] <= to_token
+        to_token = rows[-1][0] - 1
+
+        # search backwards through the list for the point to truncate
+        for idx in range(len(rows) - 1, 0, -1):
+            if rows[idx - 1][0] <= to_token:
+                return rows[:idx], to_token, True
+
+        # bother. We didn't get a full set of changes for even a single
+        # stream id. let's run the query again, without a row limit, but for
+        # just one stream id.
+        to_token += 1
+        rows = await self.db.runInteraction(
+            "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
         )
+        return rows, to_token, True