diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix
new file mode 100644
index 0000000000..f49c600173
--- /dev/null
+++ b/changelog.d/7384.bugfix
@@ -0,0 +1 @@
+Fix a bug where account data updates might not be sent over replication to worker processes after the stream falls behind.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b48a6a3e91..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
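
The switch from a plain alias to a bound TypeVar lets callers stay generic over the concrete row type. A minimal sketch of what this buys (illustrative only; `first_update` is a made-up name, not Synapse code):

    from typing import List, Tuple, TypeVar

    StreamRow = TypeVar("StreamRow", bound=Tuple)

    # With the old alias (StreamRow = Tuple) every row collapsed to a bare Tuple;
    # with a bound TypeVar, mypy propagates the concrete row type through.
    def first_update(updates: List[StreamRow]) -> StreamRow:
        return updates[0]

    rows: List[Tuple[int, str]] = [(1, "a"), (2, "b")]
    first: Tuple[int, str] = first_update(rows)  # type inferred, no cast needed
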
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
- db_query_to_update_function(self._update_function),
+ self._update_function,
+ )
+
+ async def _update_function(
+ self, instance_name: str, from_token: int, to_token: int, limit: int
+ ) -> StreamUpdateResult:
+ limited = False
+ global_results = await self.store.get_updated_global_account_data(
+ from_token, to_token, limit
)
- async def _update_function(self, from_token, to_token, limit):
- global_results, room_results = await self.store.get_all_updated_account_data(
- from_token, from_token, to_token, limit
+ # if the global results hit the limit, we'll need to limit the room results to
+ # the same stream token.
+ if len(global_results) >= limit:
+ to_token = global_results[-1][0]
+ limited = True
+
+ room_results = await self.store.get_updated_room_account_data(
+ from_token, to_token, limit
)
- results = list(room_results)
- results.extend(
- (stream_id, user_id, None, account_data_type)
+ # likewise, if the room results hit the limit, limit the global results to
+ # the same stream token.
+ if len(room_results) >= limit:
+ to_token = room_results[-1][0]
+ limited = True
+
+ # convert the global results to the right format, and limit them to the to_token
+ # at the same time
+ global_rows = (
+ (stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
+ if stream_id <= to_token
+ )
+
+ # we know that the room_results are already limited to `to_token` so no need
+ # for a check on `stream_id` here.
+ room_rows = (
+ (stream_id, (user_id, room_id, account_data_type))
+ for stream_id, user_id, room_id, account_data_type in room_results
)
- return results
+ # we need to return a sorted list, so merge them together.
+ updates = list(heapq.merge(room_rows, global_rows))
+ return updates, to_token, limited
class GroupServerStream(Stream):
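
The rewritten _update_function leans on heapq.merge to interleave the two individually-sorted result sets into a single list ordered by stream_id. A self-contained sketch of that merge step, using made-up rows rather than real query results:

    import heapq

    # toy result sets, each already sorted by stream_id, shaped like the values
    # returned by the two storage queries
    global_results = [(1, "@u:x", "m.tag"), (4, "@u:x", "m.direct")]
    room_results = [
        (2, "@u:x", "!a:x", "m.fully_read"),
        (3, "@u:x", "!b:x", "m.tag"),
    ]

    global_rows = (
        (stream_id, (user_id, None, data_type))
        for stream_id, user_id, data_type in global_results
    )
    room_rows = (
        (stream_id, (user_id, room_id, data_type))
        for stream_id, user_id, room_id, data_type in room_results
    )

    # heapq.merge consumes both sorted iterators lazily and yields their union
    # in ascending order, so the combined update list stays sorted by stream_id
    print(list(heapq.merge(room_rows, global_rows)))
    # -> [(1, ...), (2, ...), (3, ...), (4, ...)]
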
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..f9eef1b78e 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
import abc
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
- def get_all_updated_account_data(
- self, last_global_id, last_room_id, current_id, limit
- ):
- """Get all the client account_data that has changed on the server
+ async def get_updated_global_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+
Args:
- last_global_id(int): The position to fetch from for top level data
- last_room_id(int): The position to fetch from for per room data
- current_id(int): The position to fetch up to.
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return, inclusive
+ limit: the maximum number of rows to return
+
Returns:
- A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, and type string.
+ A list of tuples of stream_id int, user_id string,
+ and type string.
"""
- if last_room_id == current_id and last_global_id == current_id:
- return defer.succeed(([], []))
+ if last_id == current_id:
+ return []
- def get_updated_account_data_txn(txn):
+ def get_updated_global_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_global_id, current_id, limit))
- global_results = txn.fetchall()
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return await self.db.runInteraction(
+ "get_updated_global_account_data", get_updated_global_account_data_txn
+ )
+
+ async def get_updated_room_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+ Args:
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return, inclusive
+ limit: the maximum number of rows to return
+
+ Returns:
+ A list of tuples of stream_id int, user_id string,
+ room_id string and type string.
+ """
+ if last_id == current_id:
+ return []
+
+ def get_updated_room_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_room_id, current_id, limit))
- room_results = txn.fetchall()
- return global_results, room_results
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
- return self.db.runInteraction(
- "get_all_updated_account_data_txn", get_updated_account_data_txn
+ return await self.db.runInteraction(
+ "get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
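
Both queries share the same paging contract: an exclusive lower bound (`? < stream_id`), an inclusive upper bound (`stream_id <= ?`), ascending order and a LIMIT, so the last stream_id of one batch can be fed straight back in as `last_id` for the next without skipping or repeating rows. A toy sqlite3 sketch of that contract (the in-memory schema is invented for illustration):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE account_data"
        " (stream_id INTEGER, user_id TEXT, account_data_type TEXT)"
    )
    conn.executemany(
        "INSERT INTO account_data VALUES (?, ?, ?)",
        [(i, "@u:x", "m.test_type.%i" % (i,)) for i in range(1, 8)],
    )

    def batch(last_id, current_id, limit):
        # exclusive lower bound: rows from the previous batch are never repeated
        return conn.execute(
            "SELECT stream_id, user_id, account_data_type"
            " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
            " ORDER BY stream_id ASC LIMIT ?",
            (last_id, current_id, limit),
        ).fetchall()

    first = batch(0, 7, 3)              # stream_ids 1, 2, 3
    second = batch(first[-1][0], 7, 3)  # resumes at 4: no gap, no overlap
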
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
new file mode 100644
index 0000000000..6a5116dd2a
--- /dev/null
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+ _STREAM_UPDATE_TARGET_ROW_COUNT,
+ AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+ def test_update_function_room_account_data_limit(self):
+ """Test replication with many room account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", update, {})
+ )
+ updates.append(update)
+
+ # also one global update
+ self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertEqual(row.room_id, "test_room")
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.global")
+ self.assertIsNone(row.room_id)
+
+ self.assertEqual([], received_rows)
+
+ def test_update_function_global_account_data_limit(self):
+ """Test replication with many global account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(store.add_account_data_for_user("test_user", update, {}))
+ updates.append(update)
+
+ # also one per-room update
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+ )
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertIsNone(row.room_id)
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.per_room")
+ self.assertEqual(row.room_id, "test_room")
+
+ self.assertEqual([], received_rows)
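
Assuming a standard Synapse development environment, the new test case can be run on its own with Twisted's trial runner:

    trial tests.replication.tcp.streams.test_account_data
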