diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb14..1492ac922c 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
assert self.METHOD in ("PUT", "POST", "GET")
+ self._replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ self._replication_secret = hs.config.worker.worker_replication_secret
+
+ def _check_auth(self, request) -> None:
+ # Get the authorization header.
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+ if len(auth_headers) > 1:
+ raise RuntimeError("Too many Authorization headers.")
+ parts = auth_headers[0].split(b" ")
+ if parts[0] == b"Bearer" and len(parts) == 2:
+ received_secret = parts[1].decode("ascii")
+ if self._replication_secret == received_secret:
+ # Success!
+ return
+
+ raise RuntimeError("Invalid Authorization header.")
+
@abc.abstractmethod
async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
@@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
+ replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ replication_secret = hs.config.worker.worker_replication_secret.encode(
+ "ascii"
+ )
+
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs):
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# the master, and so whether we should clean up or not.
while True:
headers = {} # type: Dict[bytes, List[bytes]]
+ # Add an authorization header, if configured.
+ if replication_secret:
+ headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""
url_args = list(self.PATH_ARGS)
- handler = self._handle_request
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__,
+ method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
)
- def _cached_handler(self, request, txn_id, **kwargs):
+ def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well.
- assert self.CACHE
+ # Check the authorization headers before handling the request.
+ if self._replication_secret:
+ self._check_auth(request)
+
+ if self.CACHE:
+ txn_id = kwargs.pop("txn_id")
+
+ return self.response_cache.wrap(
+ txn_id, self._handle_request, request, **kwargs
+ )
- return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+ return self._handle_request(request, **kwargs)
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker:
- def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ def __init__(
+ self,
+ db_conn: Connection,
+ table: str,
+ column: str,
+ extra_tables: Optional[List[Tuple[str, str]]] = None,
+ step: int = 1,
+ ):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
- for table, column in extra_tables:
- self.advance(None, _load_current_id(db_conn, table, column))
+ if extra_tables:
+ for table, column in extra_tables:
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, instance_name, new_id):
+ def advance(self, instance_name: Optional[str], new_id: int):
self._current = (max if self.step > 0 else min)(self._current, new_id)
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""
Returns:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self._pushers_id_gen = SlavedIdTracker(
+ self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token, rows
+ ) -> None:
if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(instance_name, token)
+ self._pushers_id_gen.advance(instance_name, token) # type: ignore
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id
- self._logging_context = BackgroundProcessLoggingContext(ctx_name)
- self._logging_context.request = ctx_name
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
|