Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 68
1 file changed, 56 insertions(+), 12 deletions(-)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b48a6a3e91..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import heapq
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 # the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
 # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
 # just a row from a database query, though this is dependent on the stream in question.
 #
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
 
 # The type returned by the update_function of a stream, as well as get_updates(),
 # get_updates_since, etc.
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
     """
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
 
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(self.store.get_max_account_data_stream_id),
-            db_query_to_update_function(self._update_function),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
         )
 
-    async def _update_function(self, from_token, to_token, limit):
-        global_results, room_results = await self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
+
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
         )
 
-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type)
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
             for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
+        )
+
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
         )
 
-        return results
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(room_rows, global_rows))
+        return updates, to_token, limited
 
 
 class GroupServerStream(Stream):
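
On the StreamRow change above: replacing the plain `Tuple` alias with a bounded TypeVar keeps the "a row is a tuple" constraint while letting annotations carry a stream's concrete row type through generic code. A simplified, hypothetical illustration of the typing pattern (not the actual Stream API):

from typing import List, Tuple, TypeVar

# A bounded TypeVar stands for "some particular tuple subtype" rather than
# "any tuple", so the concrete row class propagates through annotations.
StreamRow = TypeVar("StreamRow", bound=Tuple)

def first_update(updates: List[Tuple[int, StreamRow]]) -> StreamRow:
    # Given e.g. List[Tuple[int, AccountDataStreamRow]], a type checker now
    # knows the result is an AccountDataStreamRow, not just a bare Tuple.
    return updates[0][1]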
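
A minimal, standalone sketch of the limit-and-merge logic in the new `_update_function`, using hypothetical hard-coded rows in place of the `get_updated_global_account_data` and `get_updated_room_account_data` results:

import heapq

# Hypothetical sample data standing in for the real datastore results; both
# lists are ordered by stream_id, which the merge below relies on.
global_results = [  # (stream_id, user_id, account_data_type)
    (1, "@alice:test", "m.push_rules"),
    (4, "@bob:test", "m.direct"),
]
room_results = [  # (stream_id, user_id, room_id, account_data_type)
    (2, "@alice:test", "!room:test", "m.fully_read"),
    (3, "@bob:test", "!room:test", "m.tag"),
]
to_token = 4  # neither query hit its limit here, so no capping is needed

# Reshape each result set into (stream_id, row) pairs; global rows carry None
# for room_id so both generators match AccountDataStreamRow's field order.
global_rows = (
    (stream_id, (user_id, None, account_data_type))
    for stream_id, user_id, account_data_type in global_results
    if stream_id <= to_token
)
room_rows = (
    (stream_id, (user_id, room_id, account_data_type))
    for stream_id, user_id, room_id, account_data_type in room_results
)

# heapq.merge lazily interleaves already-sorted iterables, so the combined
# list stays ordered by stream_id without building and re-sorting one big list.
updates = list(heapq.merge(room_rows, global_rows))
assert [stream_id for stream_id, _ in updates] == [1, 2, 3, 4]

Note that the merge is only correct because both result sets are already ordered by stream id; the same ordering is what lets the function cap `to_token` to `global_results[-1][0]` or `room_results[-1][0]` and set `limited` when either query returns `limit` rows.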