diff --git a/changelog.d/7024.misc b/changelog.d/7024.misc
new file mode 100644
index 0000000000..676f285377
--- /dev/null
+++ b/changelog.d/7024.misc
@@ -0,0 +1 @@
+Move the logic for catching up on replication streams to the worker process.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index e3a4634b14..d4f7d9ec18 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
'<' worker to master flows):
> SERVER example.com
- < REPLICATE events 53
+ < REPLICATE
+ > POSITION events 53
> RDATA events 54 ["$foo1:bar.com", ...]
> RDATA events 55 ["$foo4:bar.com", ...]
-The example shows the server accepting a new connection and sending its
-identity with the `SERVER` command, followed by the client asking to
-subscribe to the `events` stream from the token `53`. The server then
-periodically sends `RDATA` commands which have the format
-`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
-defined by the individual streams.
+The example shows the server accepting a new connection and sending its identity
+with the `SERVER` command, followed by the client asking the server to respond
+with the position of all streams. The server then periodically sends `RDATA`
+commands which have the format `RDATA <stream_name> <token> <row>`, where the
+format of `<row>` is defined by the individual streams.
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
connect to the server using a tool like netcat. A few things should be
noted when manually using the protocol:
-- When subscribing to a stream using `REPLICATE`, the special token
- `NOW` can be used to get all future updates. The special stream name
- `ALL` can be used with `NOW` to subscribe to all available streams.
- The federation stream is only available if federation sending has
been disabled on the main process.
- The server will only time connections out that have sent a `PING`
@@ -91,9 +88,7 @@ The client:
- Sends a `NAME` command, allowing the server to associate a human
friendly name with the connection. This is optional.
- Sends a `PING` as above
-- For each stream the client wishes to subscribe to it sends a
- `REPLICATE` with the `stream_name` and token it wants to subscribe
- from.
+- Sends a `REPLICATE` to get the current position of all streams.
- On receipt of a `SERVER` command, checks that the server name
matches the expected server name.
@@ -140,9 +135,7 @@ the wire:
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
- < REPLICATE events 1
- < REPLICATE backfill 1
- < REPLICATE caches 1
+ < REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -181,9 +174,9 @@ client (C):
#### POSITION (S)
- The position of the stream has been updated. Sent to the client
- after all missing updates for a stream have been sent to the client
- and they're now up to date.
+ On receipt of a POSITION command, clients should check whether they have
+ missed any updates, and if so fetch them out of band. Sent in response to a
+ REPLICATE command (but can happen at any time).
#### ERROR (S, C)
@@ -199,20 +192,7 @@ client (C):
#### REPLICATE (C)
-Asks the server to replicate a given stream. The syntax is:
-
-```
- REPLICATE <stream_name> <token>
-```
-
-Where `<token>` may be either:
- * a numeric stream_id to stream updates since (exclusive)
- * `NOW` to stream all subsequent updates.
-
-The `<stream_name>` is the name of a replication stream to subscribe
-to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
-of streams). It can also be `ALL` to subscribe to all known streams,
-in which case the `<token>` must be set to `NOW`.
+Asks the server for the current position of all streams.
#### USER_SYNC (C)
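To make the catch-up flow above concrete, here is a minimal, self-contained sketch of the client-side behaviour the document now describes; every name in it is hypothetical (this is not Synapse's implementation): send a bare `REPLICATE`, and on each `POSITION` compare the advertised token against the last token seen locally, fetching anything missed out of band.

```python
# Hypothetical sketch of POSITION-driven catch-up; not Synapse code.
from typing import Callable, Dict, List, Tuple

Update = Tuple[int, list]  # (token, row)
Fetcher = Callable[[str, int, int], List[Update]]

class SketchClient:
    def __init__(self, fetch_updates_since: Fetcher):
        # fetch_updates_since(stream_name, from_token, upto_token) stands in
        # for the out-of-band fetch (an HTTP call to the master, in Synapse).
        self.fetch_updates_since = fetch_updates_since
        self.last_seen: Dict[str, int] = {}

    def on_position(self, stream_name: str, token: int) -> List[Update]:
        """Handle POSITION: fetch anything missed, then record the new token."""
        last = self.last_seen.get(stream_name, token)
        missed = self.fetch_updates_since(stream_name, last, token) if last < token else []
        self.last_seen[stream_name] = token
        return missed

    def on_rdata(self, stream_name: str, token: int, row: list) -> None:
        # Normal streaming path once caught up.
        self.last_seen[stream_name] = token
```

A real client additionally has to queue any `RDATA` that arrives while it is still catching up, which is what `pending_batches` handles in `ClientReplicationStreamProtocol` later in this patch.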
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index bd1733573b..fba7ad9551 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -401,6 +401,9 @@ class GenericWorkerTyping(object):
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
+ def get_current_token(self) -> int:
+ return self._latest_room_serial
+
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 233cb33daf..a477578e44 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -499,4 +499,13 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
+ # Dummy implementation for the case where the federation sender isn't
+ # offloaded to a worker.
return 0
+
+ async def get_replication_rows(
+ self, from_token, to_token, limit, federation_ack=None
+ ):
+ # Dummy implementation for the case where the federation sender isn't
+ # offloaded to a worker.
+ return []
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 28dbc6fcba..4613b2538c 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -21,6 +21,7 @@ from synapse.replication.http import (
membership,
register,
send_event,
+ streams,
)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
+ streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
new file mode 100644
index 0000000000..ffd4c61993
--- /dev/null
+++ b/synapse/replication/http/streams.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+ """Fetches stream updates from a server. Used for streams not persisted to
+ the database, e.g. typing notifications.
+
+ The API looks like:
+
+ GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100
+
+ 200 OK
+
+ {
+ updates: [ ... ],
+ upto_token: 10,
+ limited: False,
+ }
+
+ """
+
+ NAME = "get_repl_stream_updates"
+ PATH_ARGS = ("stream_name",)
+ METHOD = "GET"
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ # We pull the streams from the replication streamer (if we try to make
+ # them ourselves we end up in an import loop).
+ self.streams = hs.get_replication_streamer().get_streams()
+
+ @staticmethod
+ def _serialize_payload(stream_name, from_token, upto_token, limit):
+ return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
+
+ async def _handle_request(self, request, stream_name):
+ stream = self.streams.get(stream_name)
+ if stream is None:
+ raise SynapseError(400, "Unknown stream")
+
+ from_token = parse_integer(request, "from_token", required=True)
+ upto_token = parse_integer(request, "upto_token", required=True)
+ limit = parse_integer(request, "limit", required=True)
+
+ updates, upto_token, limited = await stream.get_updates_since(
+ from_token, upto_token, limit
+ )
+
+ return (
+ 200,
+ {"updates": updates, "upto_token": upto_token, "limited": limited},
+ )
+
+
+def register_servlets(hs, http_server):
+ ReplicationGetStreamUpdates(hs).register(http_server)
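As a usage sketch (hypothetical caller, not part of this change), a worker could drive this endpoint through the client generated by `ReplicationEndpoint.make_client`; the keyword arguments mirror `_serialize_payload` above:

```python
# Hedged sketch: invoking the new endpoint from a worker. Assumes `hs` is a
# worker HomeServer; "typing" is one of the streams not persisted to the DB.
from synapse.replication.http.streams import ReplicationGetStreamUpdates

async def fetch_stream_updates(hs, stream_name: str, from_token: int, upto_token: int):
    client = ReplicationGetStreamUpdates.make_client(hs)
    result = await client(
        stream_name=stream_name,
        from_token=from_token,
        upto_token=upto_token,
        limit=100,
    )
    # The response body matches the docstring above.
    return result["updates"], result["upto_token"], result["limited"]
```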
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..751c799d94 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,8 +18,10 @@ from typing import Dict, Optional
import six
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+ CURRENT_STATE_CACHE_NAME,
+ CacheInvalidationWorkerStore,
+)
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@@ -35,7 +37,7 @@ def __func__(inp):
return inp.__func__
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
+ def get_cache_stream_token(self):
+ if self._cache_id_gen:
+ return self._cache_id_gen.get_current_token()
+ else:
+ return 0
+
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..bce8a3d115 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_current_token()
+
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..7e7ad0f798 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
+ self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
- self.client_name, self.server_name, self._clock, self.handler
+ self.hs, self.client_name, self.server_name, self._clock, self.handler,
)
def clientConnectionLost(self, connector, reason):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..5a6b734094 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -136,8 +136,8 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
- Sent to the client after all missing updates for a stream have been sent
- to the client and they're now up to date.
+ On receipt of a POSITION command, clients should check whether they have
+ missed any updates, and if so fetch them out of band.
"""
NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):
class ReplicateCommand(Command):
- """Sent by the client to subscribe to the stream.
+ """Sent by the client to subscribe to streams.
Format::
- REPLICATE <stream_name> <token>
-
- Where <token> may be either:
- * a numeric stream_id to stream updates from
- * "NOW" to stream all subsequent updates.
-
- The <stream_name> can be "ALL" to subscribe to all known streams, in which
- case the <token> must be set to "NOW", i.e.::
-
- REPLICATE ALL NOW
+ REPLICATE
"""
NAME = "REPLICATE"
- def __init__(self, stream_name, token):
- self.stream_name = stream_name
- self.token = token
+ def __init__(self):
+ pass
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- if token in ("NOW", "now"):
- token = "NOW"
- else:
- token = int(token)
- return cls(stream_name, token)
+ return cls()
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
-
- def get_logcontext_id(self):
- return "REPLICATE-" + self.stream_name
+ return ""
class UserSyncCommand(Command):
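A quick round-trip check of the simplified command, consistent with the `from_line`/`to_line` definitions above: the wire form is now the bare keyword with no arguments.

```python
# REPLICATE no longer carries a stream name or token; the full wire line is
# just "REPLICATE".
from synapse.replication.tcp.commands import ReplicateCommand

cmd = ReplicateCommand.from_line("")  # the part after the command name is empty
assert cmd.to_line() == ""
```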
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..f81d2e2442 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
- < REPLICATE events 1
- < REPLICATE backfill 1
- < REPLICATE caches 1
+ < REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
import logging
import struct
from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set
-from six import iteritems, iterkeys
+from six import iteritems
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
+MYPY = False
+if MYPY:
+ from synapse.server import HomeServer
+
+
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.streamer = streamer
- # The streams the client has subscribed to and is up to date with
- self.replication_streams = set() # type: Set[str]
-
- # The streams the client is currently subscribing to.
- self.connecting_streams = set() # type: Set[str]
-
- # Map from stream name to list of updates to send once we've finished
- # subscribing the client to the stream.
- self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
-
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
@@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
async def on_REPLICATE(self, cmd):
- stream_name = cmd.stream_name
- token = cmd.token
-
- if stream_name == "ALL":
- # Subscribe to all streams we're publishing to.
- deferreds = [
- run_in_background(self.subscribe_to_stream, stream, token)
- for stream in iterkeys(self.streamer.streams_by_name)
- ]
-
- await make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
- else:
- await self.subscribe_to_stream(stream_name, token)
+ # Subscribe to all streams we're publishing to.
+ for stream_name in self.streamer.streams_by_name:
+ current_token = self.streamer.get_stream_token(stream_name)
+ self.send_command(PositionCommand(stream_name, current_token))
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
@@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
- async def subscribe_to_stream(self, stream_name, token):
- """Subscribe the remote to a stream.
-
- This invloves checking if they've missed anything and sending those
- updates down if they have. During that time new updates for the stream
- are queued and sent once we've sent down any missed updates.
- """
- self.replication_streams.discard(stream_name)
- self.connecting_streams.add(stream_name)
-
- try:
- # Get missing updates
- updates, current_token = await self.streamer.get_stream_updates(
- stream_name, token
- )
-
- # Send all the missing updates
- for update in updates:
- token, row = update[0], update[1]
- self.send_command(RdataCommand(stream_name, token, row))
-
- # We send a POSITION command to ensure that they have an up to
- # date token (especially useful if we didn't send any updates
- # above)
- self.send_command(PositionCommand(stream_name, current_token))
-
- # Now we can send any updates that came in while we were subscribing
- pending_rdata = self.pending_rdata.pop(stream_name, [])
- updates = []
- for token, update in pending_rdata:
- # If the token is null, it is part of a batch update. Batches
- # are multiple updates that share a single token. To denote
- # this, the token is set to None for all tokens in the batch
- # except for the last. If we find a None token, we keep looking
- # through tokens until we find one that is not None and then
- # process all previous updates in the batch as if they had the
- # final token.
- if token is None:
- # Store this update as part of a batch
- updates.append(update)
- continue
-
- if token <= current_token:
- # This update or batch of updates is older than
- # current_token, dismiss it
- updates = []
- continue
-
- updates.append(update)
-
- # Send all updates that are part of this batch with the
- # found token
- for update in updates:
- self.send_command(RdataCommand(stream_name, token, update))
-
- # Clear stored updates
- updates = []
-
- # They're now fully subscribed
- self.replication_streams.add(stream_name)
- except Exception as e:
- logger.exception("[%s] Failed to handle REPLICATE command", self.id())
- self.send_error("failed to handle replicate: %r", e)
- finally:
- self.connecting_streams.discard(stream_name)
-
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
- if stream_name in self.replication_streams:
- # The client is subscribed to the stream
- self.send_command(RdataCommand(stream_name, token, data))
- elif stream_name in self.connecting_streams:
- # The client is being subscribed to the stream
- logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
- self.pending_rdata.setdefault(stream_name, []).append((token, data))
- else:
- # The client isn't subscribed
- logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+ self.send_command(RdataCommand(stream_name, token, data))
def send_sync(self, data):
self.send_command(SyncCommand(data))
@@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
+ hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
@@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.handler = handler
+ self.streams = {
+ stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+ } # type: Dict[str, Stream]
+
# Set of stream names that have been subscribed to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set() # type: Set[str]
+ self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {} # type: Dict[str, Any]
+ self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
- for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
- self.replicate(stream_name, token)
+ self.replicate()
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
@@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting, so inform the client handler
self.handler.update_connection(self)
- # This will happen if we don't actually subscribe to any streams
- if not self.streams_connecting:
- self.handler.finished_connecting()
-
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
raise
- if cmd.token is None:
+ if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream (batch until
# we get an update with a non-None token), or we're still catching up on
# this stream, so queue the row until the catch-up completes
self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
- async def on_POSITION(self, cmd):
- # When we get a `POSITION` command it means we've finished getting
- # missing updates for the given stream, and are now up to date.
+ async def on_POSITION(self, cmd: PositionCommand):
+ stream = self.streams.get(cmd.stream_name)
+ if not stream:
+ logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+ return
+
+ # Find where we previously streamed up to.
+ current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+ if current_token is None:
+ logger.warning(
+ "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+ )
+ return
+
+ # Fetch all updates between then and now.
+ limited = True
+ while limited:
+ updates, current_token, limited = await stream.get_updates_since(
+ current_token, cmd.token
+ )
+
+ # Check if the connection was closed underneath us, if so we bail
+ # rather than risk having concurrent catch ups going on.
+ if self.state == ConnectionStates.CLOSED:
+ return
+
+ if updates:
+ await self.handler.on_rdata(
+ cmd.stream_name,
+ current_token,
+ [stream.parse_row(update[1]) for update in updates],
+ )
+
+ # We've now caught up to the position sent to us; notify the handler.
+ await self.handler.on_position(cmd.stream_name, cmd.token)
+
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- await self.handler.on_position(cmd.stream_name, cmd.token)
+ # Check if the connection was closed underneath us, if so we bail
+ # rather than risk having concurrent catch ups going on.
+ if self.state == ConnectionStates.CLOSED:
+ return
+
+ # Handle any RDATA that came in while we were catching up.
+ rows = self.pending_batches.pop(cmd.stream_name, [])
+ if rows:
+ await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
@@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
- def replicate(self, stream_name, token):
+ def replicate(self):
"""Send the subscription request to the server
"""
- if stream_name not in STREAMS_MAP:
- raise Exception("Invalid stream name %r" % (stream_name,))
-
- logger.info(
- "[%s] Subscribing to replication stream: %r from %r",
- self.id(),
- stream_name,
- token,
- )
-
- self.streams_connecting.add(stream_name)
+ logger.info("[%s] Subscribing to replication streams", self.id())
- self.send_command(ReplicateCommand(stream_name, token))
+ self.send_command(ReplicateCommand())
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
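The `limited` flag is what drives the loop in `on_POSITION`: keep paging until the source reports that no rows were held back. A standalone sketch of that pattern, with a hypothetical `get_updates_since` callable standing in for the `Stream` method:

```python
# Standalone sketch of the pagination loop used during POSITION catch-up:
# page until limited is False, threading the returned token back in.
from typing import Awaitable, Callable, List, Tuple

Update = Tuple[int, dict]
GetUpdates = Callable[[int, int], Awaitable[Tuple[List[Update], int, bool]]]

async def catch_up(get_updates_since: GetUpdates, from_token: int, target: int):
    updates: List[Update] = []
    current = from_token
    limited = True
    while limited:
        batch, current, limited = await get_updates_since(current, target)
        updates.extend(batch)
    return updates, current
```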
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 6e2ebaf614..4374e99e32 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
import logging
import random
-from typing import Any, List
+from typing import Any, Dict, List
from six import itervalues
@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream
stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
- self.streamer = ReplicationStreamer(hs)
+ self.streamer = hs.get_replication_streamer()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
@@ -133,6 +133,11 @@ class ReplicationStreamer(object):
for conn in self.connections:
conn.send_error("server shutting down")
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a mapp from stream name to stream instance.
+ """
+ return self.streams_by_name
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -190,7 +195,8 @@ class ReplicationStreamer(object):
stream.current_token(),
)
try:
- updates, current_token = await stream.get_updates()
+ updates, current_token, limited = await stream.get_updates()
+ self.pending_updates |= limited
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
@@ -226,8 +232,7 @@ class ReplicationStreamer(object):
self.pending_updates = False
self.is_looping = False
- @measure_func("repl.get_stream_updates")
- async def get_stream_updates(self, stream_name, token):
+ def get_stream_token(self, stream_name):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -235,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return await stream.get_updates_since(token)
+ return stream.current_token()
@measure_func("repl.federation_ack")
def federation_ack(self, token):
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 29199f5b46..37bcd3de66 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,6 +24,9 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
+
+from typing import Dict, Type
+
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
PushersStream,
PushRulesStream,
ReceiptsStream,
+ Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
@@ -63,10 +67,12 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
-}
+} # type: Dict[str, Type[Stream]]
+
__all__ = [
"STREAMS_MAP",
+ "Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 32d9514883..c14dff6c64 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
+StreamRow = Tuple[Token, tuple]
+
+
class Stream(object):
"""Base class for the streams.
@@ -56,6 +65,7 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -65,61 +75,46 @@ class Stream(object):
"""
self.last_token = self.current_token()
- async def get_updates(self):
+ async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- updates, current_token = await self.get_updates_since(self.last_token)
+ current_token = self.current_token()
+ updates, current_token, limited = await self.get_updates_since(
+ self.last_token, current_token
+ )
self.last_token = current_token
- return updates, current_token
+ return updates, current_token, limited
async def get_updates_since(
- self, from_token: int
- ) -> Tuple[List[Tuple[int, JsonDict]], int]:
+ self, from_token: Token, upto_token: Token, limit: int = 100
+ ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Resolves to a pair `(updates, new_last_token)`, where `updates` is
- a list of `(token, row)` entries and `new_last_token` is the new
- position in stream.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- if from_token in ("NOW", "now"):
- return [], self.current_token()
-
- current_token = self.current_token()
-
from_token = int(from_token)
- if from_token == current_token:
- return [], current_token
+ if from_token == upto_token:
+ return [], upto_token, False
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+ updates, upto_token, limited = await self.update_function(
+ from_token, upto_token, limit=limit,
)
-
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-
- updates = [(row[0], row[1:]) for row in rows]
-
- # check we didn't get more rows than the limit.
- # doing it like this allows the update_function to be a generator.
- if len(updates) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behind" % (self.NAME))
-
- # The update function didn't hit the limit, so we must have got all
- # the updates to `current_token`, and can return that as our new
- # stream position.
- return updates, current_token
+ return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -141,6 +136,48 @@ class Stream(object):
raise NotImplementedError()
+def db_query_to_update_function(
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Wraps a db query function which returns a list of rows to make it
+ suitable for use as an `update_function` for the Stream class
+ """
+
+ async def update_function(from_token, upto_token, limit):
+ rows = await query_function(from_token, upto_token, limit)
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) == limit:
+ upto_token = rows[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return update_function
+
+
+def make_http_update_function(
+ hs, stream_name: str
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+ """Makes a suitable function for use as an `update_function` that queries
+ the master process for updates.
+ """
+
+ client = ReplicationGetStreamUpdates.make_client(hs)
+
+ async def update_function(
+ from_token: int, upto_token: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ return await client(
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
+ limit=limit,
+ )
+
+ return update_function
+
+
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -164,7 +201,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = store.get_all_new_backfill_event_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -190,8 +227,15 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
+ self._is_worker = hs.config.worker_app is not None
+
self.current_token = store.get_current_presence_token # type: ignore
- self.update_function = presence_handler.get_all_presence_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -208,7 +252,12 @@ class TypingStream(Stream):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
- self.update_function = typing_handler.get_all_typing_updates # type: ignore
+
+ if hs.config.worker_app is None:
+ self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
+ else:
+ # Query master process
+ self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs)
@@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = store.get_all_updated_receipts # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -256,7 +305,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
- return [(row[0], row[2]) for row in rows]
+
+ limited = False
+ if len(rows) == limit:
+ to_token = rows[-1][0]
+ limited = True
+
+ return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
@@ -275,7 +330,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = store.get_all_updated_pushers_rows # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -307,7 +362,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = store.get_all_updated_caches # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = store.get_all_new_public_rooms # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = store.get_all_new_device_messages # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = store.get_all_updated_tags # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -412,10 +467,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -442,7 +498,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = store.get_all_groups_changes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
+ self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)
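To illustrate the `(updates, upto_token, limited)` contract, here is a self-contained demo: the wrapper body is copied from `db_query_to_update_function` above, while the query function and its rows are invented.

```python
# Self-contained demo of the wrapper's contract, using invented rows.
import asyncio

def db_query_to_update_function(query_function):
    async def update_function(from_token, upto_token, limit):
        rows = await query_function(from_token, upto_token, limit)
        updates = [(row[0], row[1:]) for row in rows]
        limited = False
        if len(updates) == limit:
            upto_token = rows[-1][0]
            limited = True
        return updates, upto_token, limited
    return update_function

async def toy_query(from_token, upto_token, limit):
    # Invented rows of the form (stream_id, some_field).
    return [(i, "row%d" % i) for i in range(from_token + 1, upto_token + 1)][:limit]

async def main():
    fn = db_query_to_update_function(toy_query)
    # Hit the limit: upto_token is clamped to the last row and limited=True.
    print(await fn(0, 10, 3))  # ([(1, ('row1',)), (2, ('row2',)), (3, ('row3',))], 3, True)
    # Under the limit: upto_token is returned unchanged and limited=False.
    print(await fn(0, 2, 3))   # ([(1, ('row1',)), (2, ('row2',))], 2, False)

asyncio.run(main())
```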
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
+ self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
- async def update_function(self, from_token, current_token, limit=None):
+ async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index f5f9336430..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,9 @@
# limitations under the License.
from collections import namedtuple
-from ._base import Stream
+from twisted.internet import defer
+
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream):
@@ -33,11 +35,18 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
+ _QUERY_MASTER = True
def __init__(self, hs):
- federation_sender = hs.get_federation_sender()
-
- self.current_token = federation_sender.get_current_token # type: ignore
- self.update_function = federation_sender.get_replication_rows # type: ignore
+ # Not all synapse instances will have a federation sender instance,
+ # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+ # so we stub the stream out when that is the case.
+ if hs.config.worker_app is None or hs.should_send_federation():
+ federation_sender = hs.get_federation_sender()
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
+ else:
+ self.current_token = lambda: 0 # type: ignore
+ self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False)) # type: ignore
super(FederationStream, self).__init__(hs)
diff --git a/synapse/server.py b/synapse/server.py
index 1b980371de..9426eb1672 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
+from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@@ -199,6 +200,7 @@ class HomeServer(object):
"saml_handler",
"event_client_serializer",
"storage",
+ "replication_streamer",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -536,6 +538,9 @@ class HomeServer(object):
def build_storage(self) -> Storage:
return Storage(self, self.datastores)
+ def build_replication_streamer(self) -> ReplicationStreamer:
+ return ReplicationStreamer(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
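For context, listing `"replication_streamer"` in `DEPENDENCIES` is what makes `hs.get_replication_streamer()` exist and return a cached singleton built by `build_replication_streamer`. A stripped-down sketch of that convention (illustrative only, not Synapse's actual metaprogramming):

```python
# Minimal sketch of the cached-builder convention behind
# hs.get_replication_streamer(); names other than the builder are invented.
class MiniHomeServer:
    DEPENDENCIES = ["replication_streamer"]

    def __init__(self):
        self._cached = {}

    def _get_dependency(self, name):
        if name not in self._cached:
            self._cached[name] = getattr(self, "build_%s" % name)()
        return self._cached[name]

    def get_replication_streamer(self):
        return self._get_dependency("replication_streamer")

    def build_replication_streamer(self):
        return object()  # stand-in for ReplicationStreamer(self)

hs = MiniHomeServer()
assert hs.get_replication_streamer() is hs.get_replication_streamer()
```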
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..4dc5da3fe8 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
},
)
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_caches", get_all_updated_caches_txn
- )
-
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 0613b49f4a..9a1178fb39 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
+ def get_all_new_device_messages(self, last_pos, current_pos, limit):
+ """
+ Args:
+ last_pos(int):
+ current_pos(int):
+ limit(int):
+ Returns:
+ A deferred list of rows from the device inbox
+ """
+ if last_pos == current_pos:
+ return defer.succeed([])
+
+ def get_all_new_device_messages_txn(txn):
+ # We limit like this as we might have multiple rows per stream_id, and
+ # we want to make sure we always get all entries for any stream_id
+ # we return.
+ upper_pos = min(current_pos, last_pos + limit)
+ sql = (
+ "SELECT max(stream_id), user_id"
+ " FROM device_inbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY user_id"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows = txn.fetchall()
+
+ sql = (
+ "SELECT max(stream_id), destination"
+ " FROM device_federation_outbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY destination"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows.extend(txn)
+
+ # Order by ascending stream ordering
+ rows.sort()
+
+ return rows
+
+ return self.db.runInteraction(
+ "get_all_new_device_messages", get_all_new_device_messages_txn
+ )
+
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
-
- def get_all_new_device_messages(self, last_pos, current_pos, limit):
- """
- Args:
- last_pos(int):
- current_pos(int):
- limit(int):
- Returns:
- A deferred list of rows from the device inbox
- """
- if last_pos == current_pos:
- return defer.succeed([])
-
- def get_all_new_device_messages_txn(txn):
- # We limit like this as we might have multiple rows per stream_id, and
- # we want to make sure we always get all entries for any stream_id
- # we return.
- upper_pos = min(current_pos, last_pos + limit)
- sql = (
- "SELECT max(stream_id), user_id"
- " FROM device_inbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY user_id"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows = txn.fetchall()
-
- sql = (
- "SELECT max(stream_id), destination"
- " FROM device_federation_outbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY destination"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn)
-
- # Order by ascending stream ordering
- rows.sort()
-
- return rows
-
- return self.db.runInteraction(
- "get_all_new_device_messages", get_all_new_device_messages_txn
- )
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d593ef47b8..e71c23541d 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1267,104 +1267,6 @@ class EventsStore(
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret
- def get_current_backfill_token(self):
- """The current minimum token that backfilled events have reached"""
- return -self._backfill_id_gen.get_current_token()
-
- def get_current_events_token(self):
- """The current maximum token that events have reached"""
- return self._stream_id_gen.get_current_token()
-
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_forward_event_rows(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < event_stream_ordering"
- " AND event_stream_ordering <= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_id, upper_bound))
- new_event_updates.extend(txn)
-
- return new_event_updates
-
- return self.db.runInteraction(
- "get_all_new_forward_event_rows", get_all_new_forward_event_rows
- )
-
- def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_backfill_event_rows(txn):
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (-last_id, -current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_id, -upper_bound))
- new_event_updates.extend(txn.fetchall())
-
- return new_event_updates
-
- return self.db.runInteraction(
- "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
- )
-
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
@@ -1850,22 +1752,6 @@ class EventsStore(
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
- def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
- def get_all_updated_current_state_deltas_txn(txn):
- sql = """
- SELECT stream_id, room_id, type, state_key, event_id
- FROM current_state_delta_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC LIMIT ?
- """
- txn.execute(sql, (from_token, to_token, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_current_state_deltas",
- get_all_updated_current_state_deltas_txn,
- )
-
def insert_labels_for_event_txn(
self, txn, event_id, labels, room_id, topological_ordering
):
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 3013f49d32..16ea8948b1 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1}
+
+ def get_current_backfill_token(self):
+ """The current minimum token that backfilled events have reached"""
+ return -self._backfill_id_gen.get_current_token()
+
+ def get_current_events_token(self):
+ """The current maximum token that events have reached"""
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_forward_event_rows(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < event_stream_ordering"
+ " AND event_stream_ordering <= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (last_id, upper_bound))
+ new_event_updates.extend(txn)
+
+ return new_event_updates
+
+ return self.db.runInteraction(
+ "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+ )
+
+ def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_backfill_event_rows(txn):
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (-last_id, -current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_id, -upper_bound))
+ new_event_updates.extend(txn.fetchall())
+
+ return new_event_updates
+
+ return self.db.runInteraction(
+ "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+ )
+
+ def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+ def get_all_updated_current_state_deltas_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_current_state_deltas",
+ get_all_updated_current_state_deltas_txn,
+ )
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index e6c10c6316..aaebe427d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
+ def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ def get_all_new_public_rooms(txn):
+ sql = """
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
+ WHERE stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (prev_id, current_id, limit))
+ return txn.fetchall()
+
+ if prev_id == current_id:
+ return defer.succeed([])
+
+ return self.db.runInteraction(
+ "get_all_new_public_rooms", get_all_new_public_rooms
+ )
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
- def get_all_new_public_rooms(txn):
- sql = """
- SELECT stream_id, room_id, visibility, appservice_id, network_id
- FROM public_room_list_stream
- WHERE stream_id > ? AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
- """
-
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
-
- if prev_id == current_id:
- return defer.succeed([])
-
- return self.db.runInteraction(
- "get_all_new_public_rooms", get_all_new_public_rooms
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times.
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index e96ad4ca4e..a755fe2879 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from mock import Mock
 
 from synapse.replication.tcp.commands import ReplicateCommand
@@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
- server = server_factory.buildProtocol(None)
+ self.server = server_factory.buildProtocol(None)
 
-        # build a replication client, with a dummy handler
-        handler_factory = Mock()
-        self.test_handler = TestReplicationClientHandler()
-        self.test_handler.factory = handler_factory
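+        # wrap the test handler in a Mock so the tests can assert on calls,
+        # while still delegating to the real TestReplicationClientHandler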
+        self.test_handler = Mock(wraps=TestReplicationClientHandler())
self.client = ClientReplicationStreamProtocol(
- "client", "test", clock, self.test_handler
+ hs, "client", "test", clock, self.test_handler,
         )
 
- # wire them together
- self.client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(self.client, reactor))
+ self._client_transport = None
+ self._server_transport = None
+
+ def reconnect(self):
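+        """(Re)connect the client and server protocols to each other over
+        fresh fake transports, closing any existing connection first.
+        """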
+ if self._client_transport:
+ self.client.close()
+
+ if self._server_transport:
+ self.server.close()
+
+ self._client_transport = FakeTransport(self.server, self.reactor)
+ self.client.makeConnection(self._client_transport)
+
+ self._server_transport = FakeTransport(self.client, self.reactor)
+ self.server.makeConnection(self._server_transport)
+
+ def disconnect(self):
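+        """Drop both fake transports, simulating the connection being lost."""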
+ if self._client_transport:
+ self._client_transport = None
+ self.client.close()
+
+ if self._server_transport:
+ self._server_transport = None
+ self.server.close()
 
     def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
         self.pump(0.1)
 
- def replicate_stream(self, stream, token="NOW"):
+ def replicate_stream(self):
"""Make the client end a REPLICATE command to set up a subscription to a stream"""
- self.client.send_command(ReplicateCommand(stream, token))
+ self.client.send_command(ReplicateCommand())
 
 
 class TestReplicationClientHandler(object):
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
 
     def __init__(self):
- self.received_rdata_rows = []
+ self.streams = set()
+ self._received_rdata_rows = []
 
     def get_streams_to_replicate(self):
- return {}
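+        # Report the most recent token we have seen for each stream of
+        # interest, so the server can replay anything we missed while
+        # disconnected.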
+ positions = {s: 0 for s in self.streams}
+ for stream, token, _ in self._received_rdata_rows:
+ if stream in self.streams:
+ positions[stream] = max(token, positions.get(stream, 0))
+ return positions
 
     def get_currently_syncing_users(self):
return []
@@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
def finished_connecting(self):
         pass
 
+ async def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+
async def on_rdata(self, stream_name, token, rows):
for r in rows:
- self.received_rdata_rows.append((stream_name, token, r))
+ self._received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index fa2493cad6..0ec0825a0e 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
 
 USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
 
 
 class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
+ self.reconnect()
+
# make the client subscribe to the receipts stream
- self.replicate_stream("receipts", "NOW")
+ self.replicate_stream()
+ self.test_handler.streams.add("receipts")
 
         # tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+ "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
 
         # there should be one RDATA command
- rdata_rows = self.test_handler.received_rdata_rows
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- self.assertEqual(rdata_rows[0][0], "receipts")
- row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow
- self.assertEqual(ROOM_ID, row.room_id)
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
- self.assertEqual(EVENT_ID, row.event_id)
+ self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ self.get_success(
+ self.hs.get_datastore().insert_receipt(
+ "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ )
+ )
+ self.replicate()
+
+ # Nothing should have happened as we are disconnected
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now have caught up and get the missing data
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
+ self.assertEqual(token, 3)
+ self.assertEqual(1, len(rdata_rows))
+
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room2:blue", row.room_id)
+ self.assertEqual("m.read", row.receipt_type)
+ self.assertEqual(USER_ID, row.user_id)
+ self.assertEqual("$event2:foo", row.event_id)
+ self.assertEqual({"a": 2}, row.data)