diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/config/server.py | 13 | ||||
-rw-r--r-- | synapse/http/site.py | 14 |
2 files changed, 23 insertions, 4 deletions
diff --git a/synapse/config/server.py b/synapse/config/server.py index c91df636d9..f2353ce5fb 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -206,6 +206,7 @@ class HttpListenerConfig: resources: List[HttpResourceConfig] = attr.Factory(list) additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None + request_id_header: Optional[str] = None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -520,9 +521,11 @@ class ServerConfig(Config): ): raise ConfigError("allowed_avatar_mimetypes must be a list") - self.listeners = [ - parse_listener_def(i, x) for i, x in enumerate(config.get("listeners", [])) - ] + listeners = config.get("listeners", []) + if not isinstance(listeners, list): + raise ConfigError("Expected a list", ("listeners",)) + + self.listeners = [parse_listener_def(i, x) for i, x in enumerate(listeners)] # no_tls is not really supported any more, but let's grandfather it in # here. @@ -889,6 +892,9 @@ def read_gc_thresholds( def parse_listener_def(num: int, listener: Any) -> ListenerConfig: """parse a listener config from the config file""" + if not isinstance(listener, dict): + raise ConfigError("Expected a dictionary", ("listeners", str(num))) + listener_type = listener["type"] # Raise a helpful error if direct TCP replication is still configured. if listener_type == "replication": @@ -928,6 +934,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: resources=resources, additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), + request_id_header=listener.get("request_id_header"), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/http/site.py b/synapse/http/site.py index 1155f3f610..55a6afce35 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -72,10 +72,12 @@ class SynapseRequest(Request): site: "SynapseSite", *args: Any, max_request_body_size: int = 1024, + request_id_header: Optional[str] = None, **kw: Any, ): super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size + self.request_id_header = request_id_header self.synapse_site = site self.reactor = site.reactor self._channel = channel # this is used by the tests @@ -172,7 +174,14 @@ class SynapseRequest(Request): self._opentracing_span = span def get_request_id(self) -> str: - return "%s-%i" % (self.get_method(), self.request_seq) + request_id_value = None + if self.request_id_header: + request_id_value = self.getHeader(self.request_id_header) + + if request_id_value is None: + request_id_value = str(self.request_seq) + + return "%s-%s" % (self.get_method(), request_id_value) def get_redacted_uri(self) -> str: """Gets the redacted URI associated with the request (or placeholder if the URI @@ -611,12 +620,15 @@ class SynapseSite(Site): proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest + request_id_header = config.http_options.request_id_header + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, self, max_request_body_size=max_request_body_size, queued=queued, + request_id_header=request_id_header, ) self.requestFactory = request_factory # type: ignore |