diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..b0f87c365b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,117 +14,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
+
logger = logging.getLogger(__name__)
+# the number of rows to request from an update_function.
+_STREAM_UPDATE_TARGET_ROW_COUNT = 100
-MAX_EVENTS_BEHIND = 500000
-
-BackfillStreamRow = namedtuple(
- "BackfillStreamRow",
- (
- "event_id", # str
- "room_id", # str
- "type", # str
- "state_key", # str, optional
- "redacts", # str, optional
- "relates_to", # str, optional
- ),
-)
-PresenceStreamRow = namedtuple(
- "PresenceStreamRow",
- (
- "user_id", # str
- "state", # str
- "last_active_ts", # int
- "last_federation_update_ts", # int
- "last_user_sync_ts", # int
- "status_msg", # str
- "currently_active", # bool
- ),
-)
-TypingStreamRow = namedtuple(
- "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
-)
-ReceiptsStreamRow = namedtuple(
- "ReceiptsStreamRow",
- (
- "room_id", # str
- "receipt_type", # str
- "user_id", # str
- "event_id", # str
- "data", # dict
- ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
-PushersStreamRow = namedtuple(
- "PushersStreamRow",
- ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
-)
-
-
-@attr.s
-class CachesStreamRow:
- """Stream to inform workers they should invalidate their cache.
-
- Attributes:
- cache_func: Name of the cached function.
- keys: The entry in the cache to invalidate. If None then will
- invalidate all.
- invalidation_ts: Timestamp of when the invalidation took place.
- """
- cache_func = attr.ib(type=str)
- keys = attr.ib(type=Optional[List[Any]])
- invalidation_ts = attr.ib(type=int)
-
-
-PublicRoomsStreamRow = namedtuple(
- "PublicRoomsStreamRow",
- (
- "room_id", # str
- "visibility", # str
- "appservice_id", # str, optional
- "network_id", # str, optional
- ),
-)
-DeviceListsStreamRow = namedtuple(
- "DeviceListsStreamRow", ("user_id", "destination") # str # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
-TagAccountDataStreamRow = namedtuple(
- "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
-)
-AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
-)
-GroupsStreamRow = namedtuple(
- "GroupsStreamRow",
- ("group_id", "user_id", "type", "content"), # str # str # str # dict
-)
-UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = Tuple
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+# * `updates` is a list of `(token, row)` entries.
+# * `new_last_token` is the new position in the stream.
+# * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+# * instance_name: the writer of the stream
+# * from_token: the previous stream token: the starting point for fetching the
+# updates
+# * to_token: the new stream token: the point to get updates up to
+# * target_row_count: a target for the number of rows to be returned.
+#
+# The update_function is expected to return up to _approximately_ target_row_count rows.
+# If there are more updates available, it should set `limited` in the result, and
+# it will be called again to get the next batch.
+#
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
- time it was called up until the point `advance_current_token` was called.
+ time it was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
- _LIMITED = True # Whether the update function takes a limit
@classmethod
- def parse_row(cls, row):
+ def parse_row(cls, row: StreamRow):
"""Parse a row received over replication
By default, assumes that the row data is an array object and passes its contents
@@ -138,101 +92,122 @@ class Stream(object):
"""
return cls.ROW_TYPE(*row)
- def __init__(self, hs):
- # The token from which we last asked for updates
- self.last_token = self.current_token()
+ def __init__(
+ self,
+ local_instance_name: str,
+ current_token_function: Callable[[], Token],
+ update_function: UpdateFunction,
+ ):
+ """Instantiate a Stream
+
+ current_token_function and update_function are callbacks which should be
+ implemented by subclasses.
+
+ current_token_function is called to get the current token of the underlying
+ stream.
- # The token that we will get updates up to
- self.upto_token = self.current_token()
+ update_function is called to get updates for this stream between a pair of
+ stream tokens. See the UpdateFunction type definition for more info.
- def advance_current_token(self):
- """Updates `upto_token` to "now", which updates up until which point
- get_updates[_since] will fetch rows till.
+ Args:
+ local_instance_name: The instance name of the current process
+ current_token_function: callback to get the current token, as above
+ update_function: callback to get stream updates, as above
"""
- self.upto_token = self.current_token()
+ self.local_instance_name = local_instance_name
+ self.current_token = current_token_function
+ self.update_function = update_function
+
+ # The token from which we last asked for updates
+ self.last_token = self.current_token()
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.upto_token = self.current_token()
- self.last_token = self.upto_token
+ self.last_token = self.current_token()
- async def get_updates(self):
+ async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
- since the stream was constructed if it hadn't been called before),
- until the `upto_token`
+ since the stream was constructed if it hadn't been called before).
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in the stream, and `limited` is whether there are more updates
+ to fetch.
"""
- updates, current_token = await self.get_updates_since(self.last_token)
+ current_token = self.current_token()
+ updates, current_token, limited = await self.get_updates_since(
+ self.local_instance_name, self.last_token, current_token
+ )
self.last_token = current_token
- return updates, current_token
+ return updates, current_token, limited
- async def get_updates_since(self, from_token):
+ async def get_updates_since(
+ self, instance_name: str, from_token: Token, upto_token: Token
+ ) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in the stream, and `limited` is whether there are more updates
+ to fetch.
"""
- if from_token in ("NOW", "now"):
- return [], self.upto_token
-
- current_token = self.upto_token
from_token = int(from_token)
- if from_token == current_token:
- return [], current_token
+ if from_token == upto_token:
+ return [], upto_token, False
- logger.info("get_updates_since: %s", self.__class__)
- if self._LIMITED:
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
- )
+ updates, upto_token, limited = await self.update_function(
+ instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+ )
+ return updates, upto_token, limited
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
- else:
- rows = await self.update_function(from_token, current_token)
+def db_query_to_update_function(
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> UpdateFunction:
+ """Wraps a db query function which returns a list of rows to make it
+ suitable for use as an `update_function` for the Stream class
+ """
+
+ async def update_function(instance_name, from_token, upto_token, limit):
+ rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- # check we didn't get more rows than the limit.
- # doing it like this allows the update_function to be a generator.
- if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behind" % (self.NAME))
+ return updates, upto_token, limited
- return updates, current_token
+ return update_function
- def current_token(self):
- """Gets the current token of the underlying streams. Should be provided
- by the sub classes
- Returns:
- int
- """
- raise NotImplementedError()
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+ """Makes a suitable function for use as an `update_function` that queries
+ the master process for updates.
+ """
- def update_function(self, from_token, current_token, limit=None):
- """Get updates between from_token and to_token. If Stream._LIMITED is
- True then limit is provided, otherwise it's not.
+ client = ReplicationGetStreamUpdates.make_client(hs)
- Returns:
- Deferred(list(tuple)): the first entry in the tuple is the token for
- that update, and the rest of the tuple gets used to construct
- a ``ROW_TYPE`` instance
- """
- raise NotImplementedError()
+ async def update_function(
+ instance_name: str, from_token: int, upto_token: int, limit: int
+ ) -> StreamUpdateResult:
+ result = await client(
+ instance_name=instance_name,
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
+ )
+ return result["updates"], result["upto_token"], result["limited"]
+
+ return update_function
class BackfillStream(Stream):
@@ -240,93 +215,166 @@ class BackfillStream(Stream):
or it went from being an outlier to not.
"""
+ BackfillStreamRow = namedtuple(
+ "BackfillStreamRow",
+ (
+ "event_id", # str
+ "room_id", # str
+ "type", # str
+ "state_key", # str, optional
+ "redacts", # str, optional
+ "relates_to", # str, optional
+ ),
+ )
+
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = store.get_all_new_backfill_event_rows # type: ignore
-
- super(BackfillStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_current_backfill_token,
+ db_query_to_update_function(store.get_all_new_backfill_event_rows),
+ )
class PresenceStream(Stream):
+ PresenceStreamRow = namedtuple(
+ "PresenceStreamRow",
+ (
+ "user_id", # str
+ "state", # str
+ "last_active_ts", # int
+ "last_federation_update_ts", # int
+ "last_user_sync_ts", # int
+ "status_msg", # str
+ "currently_active", # bool
+ ),
+ )
+
NAME = "presence"
- _LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- presence_handler = hs.get_presence_handler()
- self.current_token = store.get_current_presence_token # type: ignore
- self.update_function = presence_handler.get_all_presence_updates # type: ignore
+ if hs.config.worker_app is None:
+ # on the master, query the presence handler
+ presence_handler = hs.get_presence_handler()
+ update_function = db_query_to_update_function(
+ presence_handler.get_all_presence_updates
+ )
+ else:
+ # Query master process
+ update_function = make_http_update_function(hs, self.NAME)
- super(PresenceStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(), store.get_current_presence_token, update_function
+ )
class TypingStream(Stream):
+ TypingStreamRow = namedtuple(
+ "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
+ )
+
NAME = "typing"
- _LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token # type: ignore
- self.update_function = typing_handler.get_all_typing_updates # type: ignore
+ if hs.config.worker_app is None:
+ # on the master, query the typing handler
+ update_function = db_query_to_update_function(
+ typing_handler.get_all_typing_updates
+ )
+ else:
+ # Query master process
+ update_function = make_http_update_function(hs, self.NAME)
- super(TypingStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(), typing_handler.get_current_token, update_function
+ )
class ReceiptsStream(Stream):
+ ReceiptsStreamRow = namedtuple(
+ "ReceiptsStreamRow",
+ (
+ "room_id", # str
+ "receipt_type", # str
+ "user_id", # str
+ "event_id", # str
+ "data", # dict
+ ),
+ )
+
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = store.get_all_updated_receipts # type: ignore
-
- super(ReceiptsStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_max_receipt_stream_id,
+ db_query_to_update_function(store.get_all_updated_receipts),
+ )
class PushRulesStream(Stream):
"""A user has changed their push rules
"""
+ PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
+
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
def __init__(self, hs):
self.store = hs.get_datastore()
- super(PushRulesStream, self).__init__(hs)
+ super(PushRulesStream, self).__init__(
+ hs.get_instance_name(), self._current_token, self._update_function
+ )
- def current_token(self):
+ def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(
+ self, instance_name: str, from_token: Token, to_token: Token, limit: int
+ ):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
- return [(row[0], row[2]) for row in rows]
+
+ limited = False
+ if len(rows) == limit:
+ to_token = rows[-1][0]
+ limited = True
+
+ return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
+ PushersStreamRow = namedtuple(
+ "PushersStreamRow",
+ ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
+ )
+
NAME = "pushers"
ROW_TYPE = PushersStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = store.get_all_updated_pushers_rows # type: ignore
-
- super(PushersStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_pushers_stream_token,
+ db_query_to_update_function(store.get_all_updated_pushers_rows),
+ )
class CachesStream(Stream):
@@ -334,98 +382,138 @@ class CachesStream(Stream):
the cache on the workers
"""
+ @attr.s
+ class CachesStreamRow:
+ """Stream to inform workers they should invalidate their cache.
+
+ Attributes:
+ cache_func: Name of the cached function.
+ keys: The entry in the cache to invalidate. If None then will
+ invalidate all.
+ invalidation_ts: Timestamp of when the invalidation took place.
+ """
+
+ cache_func = attr.ib(type=str)
+ keys = attr.ib(type=Optional[List[Any]])
+ invalidation_ts = attr.ib(type=int)
+
NAME = "caches"
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = store.get_all_updated_caches # type: ignore
-
- super(CachesStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_cache_stream_token,
+ db_query_to_update_function(store.get_all_updated_caches),
+ )
class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
+ PublicRoomsStreamRow = namedtuple(
+ "PublicRoomsStreamRow",
+ (
+ "room_id", # str
+ "visibility", # str
+ "appservice_id", # str, optional
+ "network_id", # str, optional
+ ),
+ )
+
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = store.get_all_new_public_rooms # type: ignore
-
- super(PublicRoomsStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_current_public_room_stream_id,
+ db_query_to_update_function(store.get_all_new_public_rooms),
+ )
class DeviceListsStream(Stream):
- """Someone added/changed/removed a device
+ """Either a user has updated their devices or a remote server needs to be
+ told about a device update.
"""
+ @attr.s
+ class DeviceListsStreamRow:
+ entity = attr.ib(type=str)
+
NAME = "device_lists"
- _LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
-
- super(DeviceListsStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_device_stream_token,
+ db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+ )
class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
+ ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
+
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = store.get_all_new_device_messages # type: ignore
-
- super(ToDeviceStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_to_device_stream_token,
+ db_query_to_update_function(store.get_all_new_device_messages),
+ )
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
+ TagAccountDataStreamRow = namedtuple(
+ "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
+ )
+
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = store.get_all_updated_tags # type: ignore
-
- super(TagAccountDataStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_max_account_data_stream_id,
+ db_query_to_update_function(store.get_all_updated_tags),
+ )
class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
+ AccountDataStreamRow = namedtuple(
+ "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+ )
+
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
def __init__(self, hs):
self.store = hs.get_datastore()
+ super().__init__(
+ hs.get_instance_name(),
+ self.store.get_max_account_data_stream_id,
+ db_query_to_update_function(self._update_function),
+ )
- self.current_token = self.store.get_max_account_data_stream_id # type: ignore
-
- super(AccountDataStream, self).__init__(hs)
-
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -440,30 +528,38 @@ class AccountDataStream(Stream):
class GroupServerStream(Stream):
+ GroupsStreamRow = namedtuple(
+ "GroupsStreamRow",
+ ("group_id", "user_id", "type", "content"), # str # str # str # dict
+ )
+
NAME = "groups"
ROW_TYPE = GroupsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = store.get_all_groups_changes # type: ignore
-
- super(GroupServerStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_group_stream_token,
+ db_query_to_update_function(store.get_all_groups_changes),
+ )
class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key
"""
+ UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
+
NAME = "user_signature"
- _LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
-
- super(UserSignatureStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ store.get_device_stream_token,
+ db_query_to_update_function(
+ store.get_all_user_signature_changes_for_remotes
+ ),
+ )
|