diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d5b9c2831b..2699e466bc 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import attr
@@ -40,10 +40,6 @@ class Stream(object):
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
- # Whether the update function is only available on master. If True then
- # calls to get updates are proxied to the master via a HTTP call.
- _QUERY_MASTER = False
-
@classmethod
def parse_row(cls, row):
"""Parse a row received over replication
@@ -60,10 +56,6 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
- self._is_worker = hs.config.worker_app is not None
-
- if self._QUERY_MASTER and self._is_worker:
- self._replication_client = ReplicationGetStreamUpdates.make_client(hs)
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -105,31 +97,15 @@ class Stream(object):
to fetch.
"""
- if from_token in ("NOW", "now"):
- return [], upto_token, False
-
from_token = int(from_token)
if from_token == upto_token:
return [], upto_token, False
- if self._is_worker and self._QUERY_MASTER:
- result = await self._replication_client(
- stream_name=self.NAME,
- from_token=from_token,
- upto_token=upto_token,
- limit=limit,
- )
- return result["updates"], result["upto_token"], result["limited"]
- else:
- limited = False
- rows = await self.update_function(from_token, upto_token, limit=limit)
- updates = [(row[0], row[1:]) for row in rows]
- if len(updates) == limit:
- upto_token = rows[-1][0]
- limited = True
-
- return updates, upto_token, limited
+ updates, upto_token, limited = await self.update_function(
+ from_token, upto_token, limit=limit,
+ )
+ return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -151,6 +127,26 @@ class Stream(object):
raise NotImplementedError()
+def db_query_to_update_function(
+ query_function: Callable[[int, int, int], Awaitable[List[tuple]]]
+) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
+ """Wraps a db query function which returns a list of rows to make it
+ suitable for use as an `update_function` for the Stream class
+ """
+
+ async def update_function(from_token, upto_token, limit):
+ rows = await query_function(from_token, upto_token, limit)
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) == limit:
+ upto_token = rows[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return update_function
+
+
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -174,7 +170,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = store.get_all_new_backfill_event_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -195,16 +191,20 @@ class PresenceStream(Stream):
NAME = "presence"
ROW_TYPE = PresenceStreamRow
- _QUERY_MASTER = True
def __init__(self, hs):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
+ self._is_worker = hs.config.worker_app is not None
+
self.current_token = store.get_current_presence_token # type: ignore
if hs.config.worker_app is None:
- self.update_function = presence_handler.get_all_presence_updates # type: ignore
+ self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -216,7 +216,6 @@ class TypingStream(Stream):
NAME = "typing"
ROW_TYPE = TypingStreamRow
- _QUERY_MASTER = True
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
@@ -224,7 +223,10 @@ class TypingStream(Stream):
self.current_token = typing_handler.get_current_token # type: ignore
if hs.config.worker_app is None:
- self.update_function = typing_handler.get_all_typing_updates # type: ignore
+ self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore
super(TypingStream, self).__init__(hs)
@@ -248,7 +250,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = store.get_all_updated_receipts # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -272,7 +274,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
- return [(row[0], row[2]) for row in rows]
+
+ limited = False
+ if len(rows) == limit:
+ to_token = rows[-1][0]
+ limited = True
+
+ return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
@@ -291,7 +299,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = store.get_all_updated_pushers_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -323,7 +331,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = store.get_all_updated_caches # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -349,7 +357,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = store.get_all_new_public_rooms # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -370,7 +378,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -388,7 +396,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = store.get_all_new_device_messages # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -408,7 +416,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = store.get_all_updated_tags # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -428,10 +436,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -458,7 +467,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = store.get_all_groups_changes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -476,6 +485,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)
|