diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 29199f5b46..37bcd3de66 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,6 +24,9 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
+
+from typing import Dict, Type
+
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
PushersStream,
PushRulesStream,
ReceiptsStream,
+ Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
@@ -63,10 +67,12 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
-}
+} # type: Dict[str, Type[Stream]]
+
__all__ = [
"STREAMS_MAP",
+ "Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 32d9514883..c14dff6c64 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# A pair of (stream position, row args); the args are used to create an instance of `ROW_TYPE`.
+StreamRow = Tuple[Token, tuple]
+
+
class Stream(object):
"""Base class for the streams.
@@ -56,6 +65,7 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
+
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -65,61 +75,46 @@ class Stream(object):
"""
self.last_token = self.current_token()
- async def get_updates(self):
+ async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            stream position, and `limited` indicates whether there may be more
+            updates to fetch.
"""
- updates, current_token = await self.get_updates_since(self.last_token)
+ current_token = self.current_token()
+ updates, current_token, limited = await self.get_updates_since(
+ self.last_token, current_token
+ )
self.last_token = current_token
- return updates, current_token
+ return updates, current_token, limited
async def get_updates_since(
- self, from_token: int
- ) -> Tuple[List[Tuple[int, JsonDict]], int]:
+ self, from_token: Token, upto_token: Token, limit: int = 100
+ ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Resolves to a pair `(updates, new_last_token)`, where `updates` is
- a list of `(token, row)` entries and `new_last_token` is the new
- position in stream.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            stream position, and `limited` indicates whether there may be more
+            updates to fetch.
"""
- if from_token in ("NOW", "now"):
- return [], self.current_token()
-
- current_token = self.current_token()
-
from_token = int(from_token)
- if from_token == current_token:
- return [], current_token
+ if from_token == upto_token:
+ return [], upto_token, False
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+ updates, upto_token, limited = await self.update_function(
+ from_token, upto_token, limit=limit,
)
-
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-
- updates = [(row[0], row[1:]) for row in rows]
-
- # check we didn't get more rows than the limit.
- # doing it like this allows the update_function to be a generator.
- if len(updates) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behind" % (self.NAME))
-
- # The update function didn't hit the limit, so we must have got all
- # the updates to `current_token`, and can return that as our new
- # stream position.
- return updates, current_token
+ return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -141,6 +136,48 @@ class Stream(object):
raise NotImplementedError()
+def db_query_to_update_function(
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Wraps a db query function which returns a list of rows to make it
+ suitable for use as an `update_function` for the Stream class
+ """
+
+ async def update_function(from_token, upto_token, limit):
+ rows = await query_function(from_token, upto_token, limit)
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) == limit:
+ upto_token = rows[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return update_function
+
+
+def make_http_update_function(
+ hs, stream_name: str
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Makes a suitable function for use as an `update_function` that queries
+ the master process for updates.
+ """
+
+ client = ReplicationGetStreamUpdates.make_client(hs)
+
+    async def update_function(
+        from_token: Token, upto_token: Token, limit: int
+    ) -> Tuple[List[StreamRow], Token, bool]:
+ return await client(
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
+ limit=limit,
+ )
+
+ return update_function
+
+
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -164,7 +201,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = store.get_all_new_backfill_event_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -190,8 +227,15 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
+ self._is_worker = hs.config.worker_app is not None
+
self.current_token = store.get_current_presence_token # type: ignore
- self.update_function = presence_handler.get_all_presence_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -208,7 +252,12 @@ class TypingStream(Stream):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
- self.update_function = typing_handler.get_all_typing_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs)
@@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = store.get_all_updated_receipts # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -256,7 +305,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
- return [(row[0], row[2]) for row in rows]
+
+ limited = False
+ if len(rows) == limit:
+ to_token = rows[-1][0]
+ limited = True
+
+ return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
@@ -275,7 +330,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = store.get_all_updated_pushers_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -307,7 +362,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = store.get_all_updated_caches # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = store.get_all_new_public_rooms # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = store.get_all_new_device_messages # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = store.get_all_updated_tags # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -412,10 +467,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -442,7 +498,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = store.get_all_groups_changes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)
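
For readers of this patch: the core of the new pagination contract lives in `db_query_to_update_function` above. Below is a minimal, self-contained sketch of that contract, using a hypothetical `example_query` in place of a real store method such as `get_all_updated_receipts`; it is illustrative only and not part of the patch.

    import asyncio
    from typing import List

    Token = int

    async def example_query(from_token: Token, upto_token: Token, limit: int) -> List[tuple]:
        # Hypothetical stand-in for a store method: rows are (stream_id, *row_args)
        # with stream ids in the half-open range (from_token, upto_token].
        rows = [(i, "room_id", "user_id") for i in range(from_token + 1, upto_token + 1)]
        return rows[:limit]

    def db_query_to_update_function(query_function):
        # Same wrapping logic as in the patch: turn rows into (token, row_args)
        # updates and, if we hit the limit, wind upto_token back to the last row
        # actually returned and flag the result as limited.
        async def update_function(from_token, upto_token, limit):
            rows = await query_function(from_token, upto_token, limit)
            updates = [(row[0], row[1:]) for row in rows]
            limited = False
            if len(updates) == limit:
                upto_token = rows[-1][0]
                limited = True
            return updates, upto_token, limited
        return update_function

    update_function = db_query_to_update_function(example_query)
    print(asyncio.run(update_function(0, 10, 3)))
    # ([(1, ('room_id', 'user_id')), (2, ('room_id', 'user_id')), (3, ('room_id', 'user_id'))], 3, True)
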
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
- async def update_function(self, from_token, current_token, limit=None):
+ async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index f5f9336430..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,9 @@
# limitations under the License.
from collections import namedtuple
-from ._base import Stream
+from twisted.internet import defer
+
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream):
@@ -33,11 +35,18 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
+ _QUERY_MASTER = True
def __init__(self, hs):
- federation_sender = hs.get_federation_sender()
-
- self.current_token = federation_sender.get_current_token # type: ignore
- self.update_function = federation_sender.get_replication_rows # type: ignore
+ # Not all synapse instances will have a federation sender instance,
+ # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+ # so we stub the stream out when that is the case.
+ if hs.config.worker_app is None or hs.should_send_federation():
+ federation_sender = hs.get_federation_sender()
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
+ else:
+ self.current_token = lambda: 0 # type: ignore
+            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False))  # type: ignore
super(FederationStream, self).__init__(hs)
|