diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 68 |
1 files changed, 56 insertions, 12 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b48a6a3e91..d42aaff055 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,14 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import heapq import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) # the number of rows to request from an update_function. @@ -37,7 +50,7 @@ Token = int # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's # just a row from a database query, though this is dependent on the stream in question. # -StreamRow = Tuple +StreamRow = TypeVar("StreamRow", bound=Tuple) # The type returned by the update_function of a stream, as well as get_updates(), # get_updates_since, etc. @@ -533,32 +546,63 @@ class AccountDataStream(Stream): """ AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str ) NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), - db_query_to_update_function(self._update_function), + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit ) - async def _update_function(self, from_token, to_token, limit): - global_results, room_results = await self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True + + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type) + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token + ) + + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results ) - return results + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(room_rows, global_rows)) + return updates, to_token, limited class GroupServerStream(Stream): |