diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 64edadb624..2b3972cb14 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -92,7 +92,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self.CACHE:
self.response_cache = ResponseCache(
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
- )
+ ) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5393b9a9e7..7a0dbb5b1a 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.clock = hs.get_clock()
- self.federation_handler = hs.get_handlers().federation_handler
+ self.federation_handler = hs.get_federation_handler()
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
@@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return 200, {}
-class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Request format:
- POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+ POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id
{
"room_version": "1",
}
"""
- NAME = "store_room_on_invite"
+ NAME = "store_room_on_outlier_membership"
PATH_ARGS = ("room_id",)
def __init__(self, hs):
@@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
async def _handle_request(self, request, room_id):
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
- await self.store.maybe_store_room_on_invite(room_id, room_version)
+ await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
@@ -291,4 +291,4 @@ def register_servlets(hs, http_server):
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
- ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
+ ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 30680baee8..84e002f934 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.http import Request
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -47,21 +48,28 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)
- self.federation_handler = hs.get_handlers().federation_handler
+ self.federation_handler = hs.get_federation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@staticmethod
- async def _serialize_payload(
- requester, room_id, user_id, remote_room_hosts, content
- ):
+ async def _serialize_payload( # type: ignore
+ requester: Requester,
+ room_id: str,
+ user_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ ) -> JsonDict:
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and join via
- content(dict): The event content to use for the join event
+ requester: The user making the request according to the access token
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ remote_room_hosts: Servers to try and join via
+ content: The event content to use for the join event
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"requester": requester.serialize(),
@@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -77,8 +87,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -119,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
- ):
+ ) -> JsonDict:
"""
Args:
- invite_event_id: ID of the invite to be rejected
- txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ invite_event_id: The ID of the invite to be rejected.
+ txn_id: Optional transaction ID supplied by the client
+ requester: User making the rejection request, according to the access token
+ content: Additional content to include in the rejection event.
Normally an empty dict.
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"txn_id": txn_id,
@@ -134,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, invite_event_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, invite_event_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
@@ -142,8 +156,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite(
@@ -176,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
- async def _serialize_payload(room_id, user_id, change):
+ async def _serialize_payload( # type: ignore
+ room_id: str, user_id: str, change: str
+ ) -> JsonDict:
"""
Args:
- room_id (str)
- user_id (str)
- change (str): "left"
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ change: "left"
+
+ Returns:
+ A dict representing the payload of the request.
"""
assert change == "left"
return {}
- def _handle_request(self, request, room_id, user_id, change):
+ def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str, change: str
+ ) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 9a3a694d5d..8fa104c8d3 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -46,6 +46,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"ratelimit": true,
"extra_users": [],
}
+
+ 200 OK
+
+ { "stream_id": 12345, "event_id": "$abcdef..." }
+
+ The returned event ID may not match the sent event if it was deduplicated.
"""
NAME = "send_event"
@@ -109,18 +115,23 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
- stream_id = await self.event_creation_handler.persist_and_notify_client_event(
+ event = await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- return 200, {"stream_id": stream_id}
+ return (
+ 200,
+ {
+ "stream_id": event.internal_metadata.stream_ordering,
+ "event_id": event.event_id,
+ },
+ )
def register_servlets(hs, http_server):
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1f8dafe7ea..0f5b7adef7 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.lrucache import LruCache
from ._base import BaseSlavedStore
@@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000
- )
+ self.client_ip_last_seen = LruCache(
+ cache_name="client_ip_last_seen", keylen=4, max_size=50000
+ ) # type: LruCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
@@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
- self.client_ip_last_seen.prefill(key, now)
+ self.client_ip_last_seen.set(key, now)
self.hs.get_tcp_replication().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..2618eb1e53 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -141,21 +141,25 @@ class ReplicationDataHandler:
if row.type != EventsStreamEventRow.TypeId:
continue
assert isinstance(row, EventsStreamRow)
+ assert isinstance(row.data, EventsStreamEventRow)
- event = await self.store.get_event(
- row.data.event_id, allow_rejected=True
- )
- if event.rejected_reason:
+ if row.data.rejected:
continue
extra_users = () # type: Tuple[UserID, ...]
- if event.type == EventTypes.Member:
- extra_users = (UserID.from_string(event.state_key),)
+ if row.data.type == EventTypes.Member and row.data.state_key:
+ extra_users = (UserID.from_string(row.data.state_key),)
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
- self.notifier.on_new_room_event(
- event, event_pos, max_token, extra_users
+ self.notifier.on_new_room_event_args(
+ event_pos=event_pos,
+ max_room_stream_token=max_token,
+ extra_users=extra_users,
+ room_id=row.data.room_id,
+ event_type=row.data.type,
+ state_key=row.data.state_key,
+ membership=row.data.membership,
)
# Notify any waiting deferreds. The list is ordered by position so we
@@ -191,6 +195,10 @@ class ReplicationDataHandler:
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
+ # We poke the generic "replication" notifier to wake anything up that
+ # may be streaming.
+ self.notifier.notify_replication()
+
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -141,15 +141,23 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream position without
- needing to send an RDATA.
+ """Sent by an instance to tell others the stream position without needing to
+ send an RDATA.
+
+ Two tokens are sent, the new position and the last position sent by the
+ instance (in an RDATA or other POSITION). The tokens are chosen so that *no*
+ rows were written by the instance between the `prev_token` and `new_token`.
+ (If an instance hasn't sent a position before then the new position can be
+ used for both.)
Format::
- POSITION <stream_name> <instance_name> <token>
+ POSITION <stream_name> <instance_name> <prev_token> <new_token>
- On receipt of a POSITION command clients should check if they have missed
- any updates, and if so then fetch them out of band.
+ On receipt of a POSITION command instances should check if they have missed
+ any updates, and if so then fetch them out of band. Instances can check this
+ by comparing their view of the current token for the sending instance with
+ the included `prev_token`.
The `<instance_name>` is the process that sent the command and is the source
of the stream.
@@ -157,18 +165,26 @@ class PositionCommand(Command):
NAME = "POSITION"
- def __init__(self, stream_name, instance_name, token):
+ def __init__(self, stream_name, instance_name, prev_token, new_token):
self.stream_name = stream_name
self.instance_name = instance_name
- self.token = token
+ self.prev_token = prev_token
+ self.new_token = new_token
@classmethod
def from_line(cls, line):
- stream_name, instance_name, token = line.split(" ", 2)
- return cls(stream_name, instance_name, int(token))
+ stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
+ return cls(stream_name, instance_name, int(prev_token), int(new_token))
def to_line(self):
- return " ".join((self.stream_name, self.instance_name, str(self.token)))
+ return " ".join(
+ (
+ self.stream_name,
+ self.instance_name,
+ str(self.prev_token),
+ str(self.new_token),
+ )
+ )
class ErrorCommand(_SimpleCommand):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
self._streams_to_replicate = [] # type: List[Stream]
for stream in self._streams.values():
- if stream.NAME == CachesStream.NAME:
- # All workers can write to the cache invalidation stream.
+ if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream when
+ # using redis.
self._streams_to_replicate.append(stream)
continue
@@ -251,10 +252,9 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
- import txredisapi
-
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
+ lazyConnection,
)
logger.info(
@@ -271,7 +271,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = txredisapi.lazyConnection(
+ outbound_redis_connection = lazyConnection(
+ reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
# We respond with current position of all streams this instance
# replicates.
for stream in self.get_streams_to_replicate():
+ # Note that we use the current token as the prev token here (rather
+ # than stream.last_token), as we can't be sure that there have been
+ # no rows written between last token and the current token (since we
+ # might be racing with the replication sending bg process).
+ current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(
- stream.NAME,
- self._instance_name,
- stream.current_token(self._instance_name),
+ stream.NAME, self._instance_name, current_token, current_token,
)
)
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
- missing_updates = cmd.token != current_token
+ missing_updates = cmd.prev_token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
- cmd.token,
+ cmd.new_token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
+ cmd.instance_name, current_token, cmd.new_token
)
# TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
[stream.parse_row(row) for row in rows],
)
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ cmd.stream_name, cmd.instance_name, cmd.new_token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
import logging
import struct
from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
- self.time_we_closed = None # When we requested the connection be closed
+ # When we requested the connection be closed
+ self.time_we_closed = None # type: Optional[int]
- self.received_ping = False # Have we reecived a ping from the other side
+ self.received_ping = False # Have we received a ping from the other side
self.state = ConnectionStates.CONNECTING
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
- self._send_ping_loop = None
+ self._send_ping_loop = None # type: Optional[task.LoopingCall]
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..bc6ba709a7 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import txredisapi
@@ -166,7 +166,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Args:
cmd (Command)
"""
- run_as_background_process("send-cmd", self._async_send_command, cmd)
+ run_as_background_process(
+ "send-cmd", self._async_send_command, cmd, bg_start_span=False
+ )
async def _async_send_command(self, cmd: Command):
"""Encode a replication command and send it over our outbound connection"""
@@ -228,3 +230,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password
return p
+
+
+def lazyConnection(
+ reactor,
+ host: str = "localhost",
+ port: int = 6379,
+ dbid: Optional[int] = None,
+ reconnect: bool = True,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ connectTimeout: Optional[int] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+ """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+ reactor.
+ """
+
+ isLazy = True
+ poolsize = 1
+
+ uuid = "%s:%d" % (host, port)
+ factory = txredisapi.RedisFactory(
+ uuid,
+ dbid,
+ poolsize,
+ isLazy,
+ txredisapi.ConnectionHandler,
+ charset,
+ password,
+ replyTimeout,
+ convertNumbers,
+ )
+ factory.continueTrying = reconnect
+ for x in range(poolsize):
+ reactor.connectTCP(host, port, factory, connectTimeout)
+
+ return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..1d4ceac0f1 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,9 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import PositionCommand
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -84,6 +86,23 @@ class ReplicationStreamer:
# Set of streams to replicate.
self.streams = self.command_handler.get_streams_to_replicate()
+ # If we have streams then we must have redis enabled or be on the master
+ assert (
+ not self.streams
+ or hs.config.redis.redis_enabled
+ or not hs.config.worker.worker_app
+ )
+
+ # If we are replicating an event stream we want to periodically check if
+ # we should send updated POSITIONs. We do this as a looping call rather
+ # than explicitly poking when the position advances (without new data to
+ # replicate) to reduce replication traffic (otherwise each writer would
+ # likely send a POSITION for each new event received over replication).
+ #
+ # Note that if the position hasn't advanced then we won't send anything.
+ if any(EventsStream.NAME == s.NAME for s in self.streams):
+ self.clock.looping_call(self.on_notifier_poke, 1000)
+
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -91,13 +110,23 @@ class ReplicationStreamer:
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
- if not self.command_handler.connected():
+ if not self.command_handler.connected() or not self.streams:
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
stream.discard_updates_and_advance()
return
+ # We check up front to see if anything has actually changed, as we get
+ # poked because of changes that happened on other instances.
+ if all(
+ stream.last_token == stream.current_token(self._instance_name)
+ for stream in self.streams
+ ):
+ return
+
+ # If there are updates then we need to set this even if we're already
+ # looping, as the loop needs to know that it might need to loop again.
self.pending_updates = True
if self.is_looping:
@@ -136,6 +165,8 @@ class ReplicationStreamer:
self._replication_torture_level / 1000.0
)
+ last_token = stream.last_token
+
logger.debug(
"Getting stream: %s: %s -> %s",
stream.NAME,
@@ -159,6 +190,30 @@ class ReplicationStreamer:
)
stream_updates_counter.labels(stream.NAME).inc(len(updates))
+ else:
+ # The token has advanced but there is no data to
+ # send, so we send a `POSITION` to inform other
+ # workers of the updated position.
+ if stream.NAME == EventsStream.NAME:
+ # XXX: We only do this for the EventsStream as it
+ # turns out that e.g. account data streams share
+ # their "current token" with each other, meaning
+ # that it is *not* safe to send a POSITION.
+ logger.info(
+ "Sending position: %s -> %s",
+ stream.NAME,
+ current_token,
+ )
+ self.command_handler.send_command(
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ last_token,
+ current_token,
+ )
+ )
+ continue
+
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
# this by setting the current token to all but the last of
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..61b282ab2d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -240,13 +240,18 @@ class BackfillStream(Stream):
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_current_backfill_token),
- store.get_all_new_backfill_event_rows,
+ self._current_token,
+ self.store.get_all_new_backfill_event_rows,
)
+ def _current_token(self, instance_name: str) -> int:
+ # The backfill stream over replication operates on *positive* numbers,
+ # which means we need to negate it.
+ return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..86a62b71eb 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,12 +15,15 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
-from typing import List, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -81,12 +84,14 @@ class BaseEventsStreamRow:
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
- relates_to = attr.ib() # str, optional
+ event_id = attr.ib(type=str)
+ room_id = attr.ib(type=str)
+ type = attr.ib(type=str)
+ state_key = attr.ib(type=Optional[str])
+ redacts = attr.ib(type=Optional[str])
+ relates_to = attr.ib(type=Optional[str])
+ membership = attr.ib(type=Optional[str])
+ rejected = attr.ib(type=bool)
@attr.s(slots=True, frozen=True)
@@ -113,7 +118,7 @@ class EventsStream(Stream):
NAME = "events"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -155,7 +160,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
event_rows = await self._store.get_all_new_forward_event_rows(
- from_token, current_token, target_row_count
+ instance_name, from_token, current_token, target_row_count
) # type: List[Tuple]
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
@@ -180,7 +185,7 @@ class EventsStream(Stream):
upper_limit,
state_rows_limited,
) = await self._store.get_all_updated_current_state_deltas(
- from_token, upper_limit, target_row_count
+ instance_name, from_token, upper_limit, target_row_count
)
limited = limited or state_rows_limited
@@ -189,7 +194,7 @@ class EventsStream(Stream):
# not to bother with the limit.
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
- from_token, upper_limit
+ instance_name, from_token, upper_limit
) # type: List[Tuple]
# we now need to turn the raw database rows returned into tuples suitable
|