-rw-r--r--  changelog.d/7384.bugfix                              |   1
-rw-r--r--  synapse/replication/tcp/streams/_base.py             |  68
-rw-r--r--  synapse/storage/data_stores/main/account_data.py     |  62
-rw-r--r--  tests/replication/tcp/streams/test_account_data.py   | 117
4 files changed, 217 insertions(+), 31 deletions(-)
diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix
new file mode 100644
index 0000000000..f49c600173
--- /dev/null
+++ b/changelog.d/7384.bugfix
@@ -0,0 +1 @@
+Fix a bug where account data updates might not be sent over replication to worker processes after the stream falls behind.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b48a6a3e91..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import heapq
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 # the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
 # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
 # just a row from a database query, though this is dependent on the stream in question.
 #
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
 
 # The type returned by the update_function of a stream, as well as get_updates(),
 # get_updates_since, etc.
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
     """
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
 
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(self.store.get_max_account_data_stream_id),
-            db_query_to_update_function(self._update_function),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
         )
 
-    async def _update_function(self, from_token, to_token, limit):
-        global_results, room_results = await self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
+
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
         )
 
-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type)
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
             for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
+        )
+
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
         )
 
-        return results
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(room_rows, global_rows))
+        return updates, to_token, limited
 
 
 class GroupServerStream(Stream):
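
The heart of the fix in _update_function above is merging two independently-limited, stream-ordered result sets without silently dropping rows. What follows is a minimal standalone sketch of that technique; unlike the patch, which clamps to_token before issuing the second query, it assumes both queries have already run, and every name in it is illustrative rather than part of Synapse's API.

import heapq
from typing import List, Tuple

def merge_limited_streams(
    global_results: List[Tuple[int, str]],
    room_results: List[Tuple[int, str]],
    to_token: int,
    limit: int,
) -> Tuple[List[Tuple[int, str]], int, bool]:
    """Merge two stream-ordered result sets, clamping the token if either
    side may have been truncated by its LIMIT clause."""
    limited = False

    # a full batch means there may be further rows hiding behind the last
    # stream_id we saw, so we must not advance the token past that point
    if len(global_results) >= limit:
        to_token = min(to_token, global_results[-1][0])
        limited = True
    if len(room_results) >= limit:
        to_token = min(to_token, room_results[-1][0])
        limited = True

    # discard rows beyond the clamped token, then merge the two sorted lists
    global_rows = ((sid, data) for sid, data in global_results if sid <= to_token)
    room_rows = ((sid, data) for sid, data in room_results if sid <= to_token)
    return list(heapq.merge(room_rows, global_rows)), to_token, limited

Clamping to the last stream_id of a full batch is what makes the returned token safe: a caller that resumes from it can never skip over rows that were cut off by the limit.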
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..f9eef1b78e 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
         Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
         Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, and type string.
+            A list of tuples of stream_id int, user_id string,
+            and type string.
         """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+        if last_id == current_id:
+            return []
 
-        def get_updated_account_data_txn(txn):
+        def get_updated_global_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
 
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
-        return self.db.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
     def get_updated_account_data_for_user(self, user_id, stream_id):
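
Both new storage methods share the same keyset-pagination query shape. A self-contained sqlite3 sketch (schema simplified, names hypothetical) shows why ORDER BY stream_id ASC with a LIMIT lets the stream code treat a full batch as possibly truncated:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE account_data (stream_id INTEGER, user_id TEXT, account_data_type TEXT)"
)
conn.executemany(
    "INSERT INTO account_data VALUES (?, ?, ?)",
    [(i, "@user:test", "m.type.%i" % (i,)) for i in range(1, 11)],
)

def get_updated(last_id, current_id, limit):
    # rows strictly after last_id, up to current_id, in stream order; if
    # exactly `limit` rows come back, the caller must assume more rows exist
    # below current_id and clamp its token to rows[-1][0]
    return conn.execute(
        "SELECT stream_id, user_id, account_data_type FROM account_data"
        " WHERE ? < stream_id AND stream_id <= ?"
        " ORDER BY stream_id ASC LIMIT ?",
        (last_id, current_id, limit),
    ).fetchall()

rows = get_updated(0, 10, 3)
assert [r[0] for r in rows] == [1, 2, 3]  # full batch: only advance to 3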
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
new file mode 100644
index 0000000000..6a5116dd2a
--- /dev/null
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+    _STREAM_UPDATE_TARGET_ROW_COUNT,
+    AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+    def test_update_function_room_account_data_limit(self):
+        """Test replication with many room account data updates
+        """
+        store = self.hs.get_datastore()
+
+        # generate lots of account data updates
+        updates = []
+        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+            update = "m.test_type.%i" % (i,)
+            self.get_success(
+                store.add_account_data_to_room("test_user", "test_room", update, {})
+            )
+            updates.append(update)
+
+        # also one global update
+        self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+        # tell the notifier to catch up to avoid duplicate rows.
+        # workaround for https://github.com/matrix-org/synapse/issues/7360
+        # FIXME remove this when the above is fixed
+        self.replicate()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # we should have received all the expected rows in the right order
+        received_rows = self.test_handler.received_rdata_rows
+
+        for t in updates:
+            (stream_name, token, row) = received_rows.pop(0)
+            self.assertEqual(stream_name, AccountDataStream.NAME)
+            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+            self.assertEqual(row.data_type, t)
+            self.assertEqual(row.room_id, "test_room")
+
+        (stream_name, token, row) = received_rows.pop(0)
+        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+        self.assertEqual(row.data_type, "m.global")
+        self.assertIsNone(row.room_id)
+
+        self.assertEqual([], received_rows)
+
+    def test_update_function_global_account_data_limit(self):
+        """Test replication with many global account data updates
+        """
+        store = self.hs.get_datastore()
+
+        # generate lots of account data updates
+        updates = []
+        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+            update = "m.test_type.%i" % (i,)
+            self.get_success(store.add_account_data_for_user("test_user", update, {}))
+            updates.append(update)
+
+        # also one per-room update
+        self.get_success(
+            store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+        )
+
+        # tell the notifier to catch up to avoid duplicate rows.
+        # workaround for https://github.com/matrix-org/synapse/issues/7360
+        # FIXME remove this when the above is fixed
+        self.replicate()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # we should have received all the expected rows in the right order
+        received_rows = self.test_handler.received_rdata_rows
+
+        for t in updates:
+            (stream_name, token, row) = received_rows.pop(0)
+            self.assertEqual(stream_name, AccountDataStream.NAME)
+            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+            self.assertEqual(row.data_type, t)
+            self.assertIsNone(row.room_id)
+
+        (stream_name, token, row) = received_rows.pop(0)
+        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+        self.assertEqual(row.data_type, "m.per_room")
+        self.assertEqual(row.room_id, "test_room")
+
+        self.assertEqual([], received_rows)
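
For context, the limited flag exercised by these tests drives a catch-up loop in the consumer. A hedged sketch of that pattern, assuming a stream object exposing the get_updates_since method shown in _base.py; the per-row dispatch is left as a stub:

from typing import Any

async def catch_up(
    stream: Any, instance_name: str, from_token: int, upto_token: int
) -> int:
    # keep fetching batches until the stream reports it is no longer limited;
    # each call returns (updates, new_token, limited), and when limited is
    # True the batch stopped short of upto_token
    while True:
        updates, from_token, limited = await stream.get_updates_since(
            instance_name, from_token, upto_token
        )
        for token, row in updates:
            pass  # dispatch (token, row) to interested handlers here
        if not limited:
            return from_token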