summary refs log tree commit diff
path: root/synapse/replication/tcp/streams
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams')
-rw-r--r--synapse/replication/tcp/streams/_base.py19
-rw-r--r--synapse/replication/tcp/streams/events.py113
2 files changed, 105 insertions, 27 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index a860072ccf..4ae3cffb1e 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -24,8 +24,8 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
 logger = logging.getLogger(__name__)
 
-
-MAX_EVENTS_BEHIND = 500000
+# the number of rows to request from an update_function.
+_STREAM_UPDATE_TARGET_ROW_COUNT = 100
 
 
 # Some type aliases to make things a bit easier.
@@ -56,7 +56,11 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 #  * from_token: the previous stream token: the starting point for fetching the
 #    updates
 #  * to_token: the new stream token: the point to get updates up to
-#  * limit: the maximum number of rows to return
+#  * target_row_count: a target for the number of rows to be returned.
+#
+# The update_function is expected to return up to _approximately_ target_row_count rows.
+# If there are more updates available, it should set `limited` in the result, and
+# it will be called again to get the next batch.
 #
 UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
 
@@ -138,7 +142,7 @@ class Stream(object):
         return updates, current_token, limited
 
     async def get_updates_since(
-        self, from_token: Token, upto_token: Token, limit: int = 100
+        self, from_token: Token, upto_token: Token
     ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
@@ -156,7 +160,7 @@ class Stream(object):
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            from_token, upto_token, limit,
+            from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
         )
         return updates, upto_token, limited
 
@@ -193,10 +197,7 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
         from_token: int, upto_token: int, limit: int
     ) -> StreamUpdateResult:
         result = await client(
-            stream_name=stream_name,
-            from_token=from_token,
-            upto_token=upto_token,
-            limit=limit,
+            stream_name=stream_name, from_token=from_token, upto_token=upto_token,
         )
         return result["updates"], result["upto_token"], result["limited"]
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 051114596b..aa50492569 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import heapq
-from typing import Iterable, Tuple, Type
+from collections import Iterable
+from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, Token, db_query_to_update_function
+from ._base import Stream, StreamUpdateResult, Token
 
 
 """Handling of the 'events' replication stream
@@ -117,30 +118,106 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         super().__init__(
-            self._store.get_current_events_token,
-            db_query_to_update_function(self._update_function),
+            self._store.get_current_events_token, self._update_function,
         )
 
     async def _update_function(
-        self, from_token: Token, current_token: Token, limit: int
-    ) -> Iterable[tuple]:
+        self, from_token: Token, current_token: Token, target_row_count: int
+    ) -> StreamUpdateResult:
+
+        # the events stream merges together three separate sources:
+        #  * new events
+        #  * current_state changes
+        #  * events which were previously outliers, but have now been de-outliered.
+        #
+        # The merge operation is complicated by the fact that we only have a single
+        # "stream token" which is supposed to indicate how far we have got through
+        # all three streams. It's therefore no good to return rows 1-1000 from the
+        # "new events" table if the state_deltas are limited to rows 1-100 by the
+        # target_row_count.
+        #
+        # In other words: we must pick a new upper limit, and must return *all* rows
+        # up to that point for each of the three sources.
+        #
+        # Start by trying to split the target_row_count up. We expect to have a
+        # negligible number of ex-outliers, and a rough approximation based on recent
+        # traffic on sw1v.org shows that there are approximately the same number of
+        # event rows between a given pair of stream ids as there are state
+        # updates, so let's split our target_row_count among those two types. The target
+        # is only an approximation - it doesn't matter if we end up going a bit over it.
+
+        target_row_count //= 2
+
+        # now we fetch up to that many rows from the events table
+
         event_rows = await self._store.get_all_new_forward_event_rows(
-            from_token, current_token, limit
-        )
-        event_updates = (
-            (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
-        )
+            from_token, current_token, target_row_count
+        )  # type: List[Tuple]
+
+        # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
+        # that we know it is safe to just take upper_limit = event_rows[-1][0].
+        assert (
+            len(event_rows) <= target_row_count
+        ), "get_all_new_forward_event_rows did not honour row limit"
+
+        # if we hit the limit on event_updates, there's no point in going beyond the
+        # last stream_id in the batch for the other sources.
+
+        if len(event_rows) == target_row_count:
+            limited = True
+            upper_limit = event_rows[-1][0]  # type: int
+        else:
+            limited = False
+            upper_limit = current_token
+
+        # next up is the state delta table
 
         state_rows = await self._store.get_all_updated_current_state_deltas(
-            from_token, current_token, limit
-        )
-        state_updates = (
-            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
-        )
+            from_token, upper_limit, target_row_count
+        )  # type: List[Tuple]
+
+        # again, if we've hit the limit there, we'll need to limit the other sources
+        assert len(state_rows) < target_row_count
+        if len(state_rows) == target_row_count:
+            assert state_rows[-1][0] <= upper_limit
+            upper_limit = state_rows[-1][0]
+            limited = True
+
+            # FIXME: is it a given that there is only one row per stream_id in the
+            # state_deltas table (so that we can be sure that we have got all of the
+            # rows for upper_limit)?
+
+        # finally, fetch the ex-outliers rows. We assume there are few enough of these
+        # not to bother with the limit.
 
-        all_updates = heapq.merge(event_updates, state_updates)
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+            from_token, upper_limit
+        )  # type: List[Tuple]
 
-        return all_updates
+        # we now need to turn the raw database rows returned into tuples suitable
+        # for the replication protocol (basically, we add an identifier to
+        # distinguish the row type). At the same time, we can limit the event_rows
+        # to the max stream_id from state_rows.
+
+        event_updates = (
+            (stream_id, (EventsStreamEventRow.TypeId, rest))
+            for (stream_id, *rest) in event_rows
+            if stream_id <= upper_limit
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        state_updates = (
+            (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
+            for (stream_id, *rest) in state_rows
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        ex_outliers_updates = (
+            (stream_id, (EventsStreamEventRow.TypeId, rest))
+            for (stream_id, *rest) in ex_outliers_rows
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
+        return updates, upper_limit, limited
 
     @classmethod
     def parse_row(cls, row):