diff options
Diffstat (limited to 'synapse/replication/http/_base.py')
-rw-r--r-- | synapse/replication/http/_base.py | 47 |
1 files changed, 41 insertions, 6 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2b3972cb14..1492ac922c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): assert self.METHOD in ("PUT", "POST", "GET") + self._replication_secret = None + if hs.config.worker.worker_replication_secret: + self._replication_secret = hs.config.worker.worker_replication_secret + + def _check_auth(self, request) -> None: + # Get the authorization header. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + + if len(auth_headers) > 1: + raise RuntimeError("Too many Authorization headers.") + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + received_secret = parts[1].decode("ascii") + if self._replication_secret == received_secret: + # Success! + return + + raise RuntimeError("Invalid Authorization header.") + @abc.abstractmethod async def _serialize_payload(**kwargs): """Static method that is called when creating a request. @@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) + replication_secret = None + if hs.config.worker.worker_replication_secret: + replication_secret = hs.config.worker.worker_replication_secret.encode( + "ascii" + ) + @trace(opname="outgoing_replication_request") @outgoing_gauge.track_inprogress() async def send_request(instance_name="master", **kwargs): @@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # the master, and so whether we should clean up or not. while True: headers = {} # type: Dict[bytes, List[bytes]] + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [b"Bearer " + replication_secret] inject_active_span_byte_dict(headers, None, check_destination=False) try: result = await request_func(uri, data, headers=headers) @@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): """ url_args = list(self.PATH_ARGS) - handler = self._handle_request method = self.METHOD if self.CACHE: - handler = self._cached_handler # type: ignore url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths( - method, [pattern], handler, self.__class__.__name__, + method, [pattern], self._check_auth_and_handle, self.__class__.__name__, ) - def _cached_handler(self, request, txn_id, **kwargs): + def _check_auth_and_handle(self, request, **kwargs): """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, otherwise calls `_handle_request` and caches its response. @@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # We just use the txn_id here, but we probably also want to use the # other PATH_ARGS as well. - assert self.CACHE + # Check the authorization headers before handling the request. + if self._replication_secret: + self._check_auth(request) + + if self.CACHE: + txn_id = kwargs.pop("txn_id") + + return self.response_cache.wrap( + txn_id, self._handle_request, request, **kwargs + ) - return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) + return self._handle_request(request, **kwargs) |