diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..7e7ad0f798 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
+ self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
- self.client_name, self.server_name, self._clock, self.handler
+ self.hs, self.client_name, self.server_name, self._clock, self.handler,
)
def clientConnectionLost(self, connector, reason):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..5a6b734094 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -136,8 +136,8 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
- Sent to the client after all missing updates for a stream have been sent
- to the client and they're now up to date.
+ On receipt of a POSITION command clients should check if they have missed
+ any updates, and if so then fetch them out of band.
"""
NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):
class ReplicateCommand(Command):
- """Sent by the client to subscribe to the stream.
+ """Sent by the client to subscribe to streams.
Format::
- REPLICATE <stream_name> <token>
-
- Where <token> may be either:
- * a numeric stream_id to stream updates from
- * "NOW" to stream all subsequent updates.
-
- The <stream_name> can be "ALL" to subscribe to all known streams, in which
- case the <token> must be set to "NOW", i.e.::
-
- REPLICATE ALL NOW
+ REPLICATE
"""
NAME = "REPLICATE"
- def __init__(self, stream_name, token):
- self.stream_name = stream_name
- self.token = token
+ def __init__(self):
+ pass
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- if token in ("NOW", "now"):
- token = "NOW"
- else:
- token = int(token)
- return cls(stream_name, token)
+ return cls()
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
-
- def get_logcontext_id(self):
- return "REPLICATE-" + self.stream_name
+ return ""
class UserSyncCommand(Command):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..f81d2e2442 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
- < REPLICATE events 1
- < REPLICATE backfill 1
- < REPLICATE caches 1
+ < REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
import logging
import struct
from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set
-from six import iteritems, iterkeys
+from six import iteritems
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
+MYPY = False
+if MYPY:
+ from synapse.server import HomeServer
+
+
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.streamer = streamer
- # The streams the client has subscribed to and is up to date with
- self.replication_streams = set() # type: Set[str]
-
- # The streams the client is currently subscribing to.
- self.connecting_streams = set() # type: Set[str]
-
- # Map from stream name to list of updates to send once we've finished
- # subscribing the client to the stream.
- self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
-
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
@@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
async def on_REPLICATE(self, cmd):
- stream_name = cmd.stream_name
- token = cmd.token
-
- if stream_name == "ALL":
- # Subscribe to all streams we're publishing to.
- deferreds = [
- run_in_background(self.subscribe_to_stream, stream, token)
- for stream in iterkeys(self.streamer.streams_by_name)
- ]
-
- await make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
- else:
- await self.subscribe_to_stream(stream_name, token)
+ # Subscribe to all streams we're publishing to.
+ for stream_name in self.streamer.streams_by_name:
+ current_token = self.streamer.get_stream_token(stream_name)
+ self.send_command(PositionCommand(stream_name, current_token))
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
@@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
- async def subscribe_to_stream(self, stream_name, token):
- """Subscribe the remote to a stream.
-
- This invloves checking if they've missed anything and sending those
- updates down if they have. During that time new updates for the stream
- are queued and sent once we've sent down any missed updates.
- """
- self.replication_streams.discard(stream_name)
- self.connecting_streams.add(stream_name)
-
- try:
- # Get missing updates
- updates, current_token = await self.streamer.get_stream_updates(
- stream_name, token
- )
-
- # Send all the missing updates
- for update in updates:
- token, row = update[0], update[1]
- self.send_command(RdataCommand(stream_name, token, row))
-
- # We send a POSITION command to ensure that they have an up to
- # date token (especially useful if we didn't send any updates
- # above)
- self.send_command(PositionCommand(stream_name, current_token))
-
- # Now we can send any updates that came in while we were subscribing
- pending_rdata = self.pending_rdata.pop(stream_name, [])
- updates = []
- for token, update in pending_rdata:
- # If the token is null, it is part of a batch update. Batches
- # are multiple updates that share a single token. To denote
- # this, the token is set to None for all tokens in the batch
- # except for the last. If we find a None token, we keep looking
- # through tokens until we find one that is not None and then
- # process all previous updates in the batch as if they had the
- # final token.
- if token is None:
- # Store this update as part of a batch
- updates.append(update)
- continue
-
- if token <= current_token:
- # This update or batch of updates is older than
- # current_token, dismiss it
- updates = []
- continue
-
- updates.append(update)
-
- # Send all updates that are part of this batch with the
- # found token
- for update in updates:
- self.send_command(RdataCommand(stream_name, token, update))
-
- # Clear stored updates
- updates = []
-
- # They're now fully subscribed
- self.replication_streams.add(stream_name)
- except Exception as e:
- logger.exception("[%s] Failed to handle REPLICATE command", self.id())
- self.send_error("failed to handle replicate: %r", e)
- finally:
- self.connecting_streams.discard(stream_name)
-
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
- if stream_name in self.replication_streams:
- # The client is subscribed to the stream
- self.send_command(RdataCommand(stream_name, token, data))
- elif stream_name in self.connecting_streams:
- # The client is being subscribed to the stream
- logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
- self.pending_rdata.setdefault(stream_name, []).append((token, data))
- else:
- # The client isn't subscribed
- logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+ self.send_command(RdataCommand(stream_name, token, data))
def send_sync(self, data):
self.send_command(SyncCommand(data))
@@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
+ hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
@@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.handler = handler
+ self.streams = {
+ stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+ } # type: Dict[str, Stream]
+
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set() # type: Set[str]
+ self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {} # type: Dict[str, Any]
+ self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
- for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
- self.replicate(stream_name, token)
+ self.replicate()
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
@@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
- # This will happen if we don't actually subscribe to any streams
- if not self.streams_connecting:
- self.handler.finished_connecting()
-
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
raise
- if cmd.token is None:
+ if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
- async def on_POSITION(self, cmd):
- # When we get a `POSITION` command it means we've finished getting
- # missing updates for the given stream, and are now up to date.
+ async def on_POSITION(self, cmd: PositionCommand):
+ stream = self.streams.get(cmd.stream_name)
+ if not stream:
+ logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+ return
+
+ # Find where we previously streamed up to.
+ current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+ if current_token is None:
+ logger.warning(
+ "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+ )
+ return
+
+ # Fetch all updates between then and now.
+ limited = True
+ while limited:
+ updates, current_token, limited = await stream.get_updates_since(
+ current_token, cmd.token
+ )
+
+ # Check if the connection was closed underneath us, if so we bail
+ # rather than risk having concurrent catch ups going on.
+ if self.state == ConnectionStates.CLOSED:
+ return
+
+ if updates:
+ await self.handler.on_rdata(
+ cmd.stream_name,
+ current_token,
+ [stream.parse_row(update[1]) for update in updates],
+ )
+
+ # We've now caught up to position sent to us, notify handler.
+ await self.handler.on_position(cmd.stream_name, cmd.token)
+
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- await self.handler.on_position(cmd.stream_name, cmd.token)
+ # Check if the connection was closed underneath us, if so we bail
+ # rather than risk having concurrent catch ups going on.
+ if self.state == ConnectionStates.CLOSED:
+ return
+
+ # Handle any RDATA that came in while we were catching up.
+ rows = self.pending_batches.pop(cmd.stream_name, [])
+ if rows:
+ await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
@@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
- def replicate(self, stream_name, token):
+ def replicate(self):
"""Send the subscription request to the server
"""
- if stream_name not in STREAMS_MAP:
- raise Exception("Invalid stream name %r" % (stream_name,))
-
- logger.info(
- "[%s] Subscribing to replication stream: %r from %r",
- self.id(),
- stream_name,
- token,
- )
-
- self.streams_connecting.add(stream_name)
+ logger.info("[%s] Subscribing to replication streams", self.id())
- self.send_command(ReplicateCommand(stream_name, token))
+ self.send_command(ReplicateCommand())
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 6e2ebaf614..4374e99e32 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
import logging
import random
-from typing import Any, List
+from typing import Any, Dict, List
from six import itervalues
@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream
stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
- self.streamer = ReplicationStreamer(hs)
+ self.streamer = hs.get_replication_streamer()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
@@ -133,6 +133,11 @@ class ReplicationStreamer(object):
for conn in self.connections:
conn.send_error("server shutting down")
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a mapp from stream name to stream instance.
+ """
+ return self.streams_by_name
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -190,7 +195,8 @@ class ReplicationStreamer(object):
stream.current_token(),
)
try:
- updates, current_token = await stream.get_updates()
+ updates, current_token, limited = await stream.get_updates()
+ self.pending_updates |= limited
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
@@ -226,8 +232,7 @@ class ReplicationStreamer(object):
self.pending_updates = False
self.is_looping = False
- @measure_func("repl.get_stream_updates")
- async def get_stream_updates(self, stream_name, token):
+ def get_stream_token(self, stream_name):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -235,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return await stream.get_updates_since(token)
+ return stream.current_token()
@measure_func("repl.federation_ack")
def federation_ack(self, token):
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 29199f5b46..37bcd3de66 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,6 +24,9 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
+
+from typing import Dict, Type
+
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
PushersStream,
PushRulesStream,
ReceiptsStream,
+ Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
@@ -63,10 +67,12 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
-}
+} # type: Dict[str, Type[Stream]]
+
__all__ = [
"STREAMS_MAP",
+ "Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 32d9514883..c14dff6c64 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
+StreamRow = Tuple[Token, tuple]
+
+
class Stream(object):
"""Base class for the streams.
@@ -56,6 +65,7 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
+
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -65,61 +75,46 @@ class Stream(object):
"""
self.last_token = self.current_token()
- async def get_updates(self):
+ async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- updates, current_token = await self.get_updates_since(self.last_token)
+ current_token = self.current_token()
+ updates, current_token, limited = await self.get_updates_since(
+ self.last_token, current_token
+ )
self.last_token = current_token
- return updates, current_token
+ return updates, current_token, limited
async def get_updates_since(
- self, from_token: int
- ) -> Tuple[List[Tuple[int, JsonDict]], int]:
+ self, from_token: Token, upto_token: Token, limit: int = 100
+ ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Resolves to a pair `(updates, new_last_token)`, where `updates` is
- a list of `(token, row)` entries and `new_last_token` is the new
- position in stream.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- if from_token in ("NOW", "now"):
- return [], self.current_token()
-
- current_token = self.current_token()
-
from_token = int(from_token)
- if from_token == current_token:
- return [], current_token
+ if from_token == upto_token:
+ return [], upto_token, False
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+ updates, upto_token, limited = await self.update_function(
+ from_token, upto_token, limit=limit,
)
-
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-
- updates = [(row[0], row[1:]) for row in rows]
-
- # check we didn't get more rows than the limit.
- # doing it like this allows the update_function to be a generator.
- if len(updates) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behind" % (self.NAME))
-
- # The update function didn't hit the limit, so we must have got all
- # the updates to `current_token`, and can return that as our new
- # stream position.
- return updates, current_token
+ return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -141,6 +136,48 @@ class Stream(object):
raise NotImplementedError()
+def db_query_to_update_function(
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Wraps a db query function which returns a list of rows to make it
+ suitable for use as an `update_function` for the Stream class
+ """
+
+ async def update_function(from_token, upto_token, limit):
+ rows = await query_function(from_token, upto_token, limit)
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) == limit:
+ upto_token = rows[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return update_function
+
+
+def make_http_update_function(
+ hs, stream_name: str
+) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Makes a suitable function for use as an `update_function` that queries
+ the master process for updates.
+ """
+
+ client = ReplicationGetStreamUpdates.make_client(hs)
+
+ async def update_function(
+ from_token: int, upto_token: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ return await client(
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
+ limit=limit,
+ )
+
+ return update_function
+
+
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -164,7 +201,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = store.get_all_new_backfill_event_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -190,8 +227,15 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
+ self._is_worker = hs.config.worker_app is not None
+
self.current_token = store.get_current_presence_token # type: ignore
- self.update_function = presence_handler.get_all_presence_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -208,7 +252,12 @@ class TypingStream(Stream):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
- self.update_function = typing_handler.get_all_typing_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs)
@@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = store.get_all_updated_receipts # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -256,7 +305,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
- return [(row[0], row[2]) for row in rows]
+
+ limited = False
+ if len(rows) == limit:
+ to_token = rows[-1][0]
+ limited = True
+
+ return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
@@ -275,7 +330,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = store.get_all_updated_pushers_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -307,7 +362,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = store.get_all_updated_caches # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = store.get_all_new_public_rooms # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = store.get_all_new_device_messages # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = store.get_all_updated_tags # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -412,10 +467,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -442,7 +498,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = store.get_all_groups_changes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
- async def update_function(self, from_token, current_token, limit=None):
+ async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index f5f9336430..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,9 @@
# limitations under the License.
from collections import namedtuple
-from ._base import Stream
+from twisted.internet import defer
+
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream):
@@ -33,11 +35,18 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
+ _QUERY_MASTER = True
def __init__(self, hs):
- federation_sender = hs.get_federation_sender()
-
- self.current_token = federation_sender.get_current_token # type: ignore
- self.update_function = federation_sender.get_replication_rows # type: ignore
+ # Not all synapse instances will have a federation sender instance,
+ # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+ # so we stub the stream out when that is the case.
+ if hs.config.worker_app is None or hs.should_send_federation():
+ federation_sender = hs.get_federation_sender()
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
+ else:
+ self.current_token = lambda: 0 # type: ignore
+ self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
super(FederationStream, self).__init__(hs)
|