diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 1be1ccbdf3..f88c80ae84 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,7 @@
import abc
import logging
import re
+from inspect import signature
from typing import Dict, List, Tuple
from six import raise_from
@@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
+ Requests are sent to the master process by default, but can be sent to other
+ named processes by specifying an `instance_name` keyword argument.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
+ # We reserve `instance_name` as a parameter to sending requests, so we
+ # assert here that subclasses don't try to use the name.
+ assert (
+ "instance_name" not in self.PATH_ARGS
), "`instance_name` is a reserved parameter name"
+ assert (
+ "instance_name"
+ not in signature(self.__class__._serialize_payload).parameters
), "`instance_name` is a reserved parameter name"
+
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
@@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
- def send_request(**kwargs):
+ def send_request(instance_name="master", **kwargs):
+ # Currently we only support sending requests to the master process.
+ if instance_name != "master":
+ raise Exception("Unknown instance")
+
data = yield cls._serialize_payload(**kwargs)
url_args = [
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index f35cebc710..0459f582bf 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)
+ self._instance_name = hs.get_instance_name()
+
# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
upto_token = parse_integer(request, "upto_token", required=True)
updates, upto_token, limited = await stream.get_updates_since(
- from_token, upto_token
+ self._instance_name, from_token, upto_token
)
return (
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 5c28fd4ac3..3bbf3c3569 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -86,17 +86,19 @@ class ReplicationDataHandler:
def __init__(self, store: BaseSlavedStore):
self.store = store
- async def on_rdata(self, stream_name: str, token: int, rows: list):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
+ stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
+ token: stream token for this batch of rows
+ rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index d72f3d0cf9..2d1d119c7c 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -278,19 +278,24 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
- await self.on_rdata(stream_name, cmd.token, rows)
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
- async def on_rdata(self, stream_name: str, token: int, rows: list):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- await self._replication_data_handler.on_rdata(stream_name, token, rows)
+ await self._replication_data_handler.on_rdata(
+ stream_name, instance_name, token, rows
+ )
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
@@ -325,7 +330,9 @@ class ReplicationCommandHandler:
updates,
current_token,
missing_updates,
- ) = await stream.get_updates_since(current_token, cmd.token)
+ ) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
+ )
# TODO: add some tests for this
@@ -334,7 +341,10 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates):
await self.on_rdata(
- cmd.stream_name, token, [stream.parse_row(row) for row in rows],
+ cmd.stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
)
# We've now caught up to position sent to us, notify handler.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4af1afd119..b0f87c365b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
@@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
#
# The arguments are:
#
+# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
@@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
-UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object):
@@ -93,6 +94,7 @@ class Stream(object):
def __init__(
self,
+ local_instance_name: str,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
@@ -108,9 +110,11 @@ class Stream(object):
stream tokens. See the UpdateFunction type definition for more info.
Args:
+ local_instance_name: The instance name of the current process
current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above
"""
+ self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function
@@ -135,14 +139,14 @@ class Stream(object):
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
- self.last_token, current_token
+ self.local_instance_name, self.last_token, current_token
)
self.last_token = current_token
return updates, current_token, limited
async def get_updates_since(
- self, from_token: Token, upto_token: Token
+ self, instance_name: str, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
@@ -160,19 +164,19 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
- from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+ instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited
def db_query_to_update_function(
- query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
- async def update_function(from_token, upto_token, limit):
+ async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
@@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
- from_token: int, upto_token: int, limit: int
+ instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
result = await client(
- stream_name=stream_name, from_token=from_token, upto_token=upto_token,
+ instance_name=instance_name,
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]
@@ -226,6 +233,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -261,7 +269,9 @@ class PresenceStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
- super().__init__(store.get_current_presence_token, update_function)
+ super().__init__(
+ hs.get_instance_name(), store.get_current_presence_token, update_function
+ )
class TypingStream(Stream):
@@ -284,7 +294,9 @@ class TypingStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
- super().__init__(typing_handler.get_current_token, update_function)
+ super().__init__(
+ hs.get_instance_name(), typing_handler.get_current_token, update_function
+ )
class ReceiptsStream(Stream):
@@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -322,14 +335,16 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(
- self._current_token, self._update_function
+ hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- async def _update_function(self, from_token: Token, to_token: Token, limit: int):
+ async def _update_function(
+ self, instance_name: str, from_token: Token, to_token: Token, limit: int
+ ):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
@@ -356,6 +371,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -387,6 +403,7 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
@@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -487,6 +508,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
)
@@ -517,6 +539,7 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 52df81b1bd..890e75d827 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -118,11 +118,17 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
super().__init__(
- self._store.get_current_events_token, self._update_function,
+ hs.get_instance_name(),
+ self._store.get_current_events_token,
+ self._update_function,
)
async def _update_function(
- self, from_token: Token, current_token: Token, target_row_count: int
+ self,
+ instance_name: str,
+ from_token: Token,
+ current_token: Token,
+ target_row_count: int,
) -> StreamUpdateResult:
# the events stream merges together three separate sources:
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 75133d7e40..e8bd52e389 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -48,8 +48,8 @@ class FederationStream(Stream):
current_token = lambda: 0
update_function = self._stub_update_function
- super().__init__(current_token, update_function)
+ super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
- async def _stub_update_function(from_token, upto_token, limit):
+ async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
|