From b3bcacf3c1c72bfadeb46fe4d0198ca155a8c615 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 9 Dec 2021 12:23:34 +0100 Subject: Add missing `errcode` to `parse_string` and `parse_boolean` (#11542) --- synapse/http/servlet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/http') diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6dd9b9ad03..1627225f28 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -246,7 +246,7 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) else: if required: message = "Missing boolean query parameter %r" % (name,) @@ -414,7 +414,7 @@ def _parse_string_value( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) else: return value_str -- cgit 1.5.1 From 941ebe49ffc32c6d67b487094a6f8e1c290e93bc Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 9 Dec 2021 12:58:25 +0100 Subject: Use HTTPStatus constants in place of literals in `synapse.http` (#11543) --- changelog.d/11543.misc | 1 + synapse/http/client.py | 15 ++++++++--- synapse/http/matrixfederationclient.py | 3 ++- synapse/http/servlet.py | 47 ++++++++++++++++++++++++---------- 4 files changed, 47 insertions(+), 19 deletions(-) create mode 100644 changelog.d/11543.misc (limited to 'synapse/http') diff --git a/changelog.d/11543.misc b/changelog.d/11543.misc new file mode 100644 index 0000000000..99817d71a4 --- /dev/null +++ b/changelog.d/11543.misc @@ -0,0 +1 @@ +Use HTTPStatus constants in place of literals in `synapse.http`. \ No newline at end of file diff --git a/synapse/http/client.py b/synapse/http/client.py index b5a2d333a6..fbbeceabeb 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,6 +14,7 @@ # limitations under the License. import logging import urllib.parse +from http import HTTPStatus from io import BytesIO from typing import ( TYPE_CHECKING, @@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info("Blocking access to %s due to blacklist" % (ip_address,)) - e = SynapseError(403, "IP address blocked by IP blacklist entry") + e = SynapseError( + HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry" + ) return defer.fail(Failure(e)) return self._agent.request( @@ -719,7 +722,9 @@ class SimpleHttpClient: if response.code > 299: logger.warning("Got %d when downloading %s" % (response.code, url)) - raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) + raise SynapseError( + HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN + ) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it @@ -731,12 +736,14 @@ class SimpleHttpClient: ) except BodyExceededMaxSize: raise SynapseError( - 502, + HTTPStatus.BAD_GATEWAY, "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) except Exception as e: - raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e + raise SynapseError( + HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e) + ) from e return ( length, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 203d723d41..deedde0b5b 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,6 +19,7 @@ import random import sys import typing import urllib.parse +from http import HTTPStatus from io import BytesIO, StringIO from typing import ( TYPE_CHECKING, @@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient: request.destination, msg, ) - raise SynapseError(502, msg, Codes.TOO_LARGE) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1627225f28..e543cc6e01 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,6 +14,7 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Iterable, @@ -137,11 +138,15 @@ def parse_integer_from_args( return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing integer query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -246,11 +251,15 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing boolean query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -313,7 +322,7 @@ def parse_bytes_from_args( return args[name_bytes][0] elif required: message = "Missing string query parameter %s" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM) return default @@ -407,14 +416,16 @@ def _parse_string_value( try: value_str = value.decode(encoding) except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding) + ) if allowed_values is not None and value_str not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM) else: return value_str @@ -510,7 +521,9 @@ def parse_strings_from_args( else: if required: message = "Missing string query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) return default @@ -638,7 +651,7 @@ def parse_json_value_from_request( try: content_bytes = request.content.read() # type: ignore except Exception: - raise SynapseError(400, "Error reading JSON content.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.") if not content_bytes and allow_empty_body: return None @@ -647,7 +660,9 @@ def parse_json_value_from_request( content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes) - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON + ) return content @@ -673,7 +688,7 @@ def parse_json_object_from_request( if not isinstance(content, dict): message = "Content must be a JSON object." - raise SynapseError(400, message, errcode=Codes.BAD_JSON) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) return content @@ -685,7 +700,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent.append(k) if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM + ) class RestServlet: @@ -758,10 +775,12 @@ class ResolveRoomIdMixin: resolved_room_id = room_id.to_string() else: raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) + HTTPStatus.BAD_REQUEST, + "%s was not legal room ID or room alias" % (room_identifier,), ) if not resolved_room_id: raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier + HTTPStatus.BAD_REQUEST, + "Unknown room ID or room alias %s" % room_identifier, ) return resolved_room_id, remote_room_hosts -- cgit 1.5.1 From 33abbc327813e65aaa91e10f98a31622c045004c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Dec 2021 07:00:47 -0500 Subject: Add missing type hints to synapse.http. (#11571) --- changelog.d/11571.misc | 1 + mypy.ini | 3 ++ synapse/http/__init__.py | 6 +-- synapse/http/additional_resource.py | 12 +++-- synapse/http/server.py | 90 ++++++++++++++++++------------- synapse/http/servlet.py | 3 +- synapse/http/site.py | 8 +-- synapse/rest/key/v2/local_key_resource.py | 4 +- 8 files changed, 76 insertions(+), 51 deletions(-) create mode 100644 changelog.d/11571.misc (limited to 'synapse/http') diff --git a/changelog.d/11571.misc b/changelog.d/11571.misc new file mode 100644 index 0000000000..4e396b271e --- /dev/null +++ b/changelog.d/11571.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.http`. diff --git a/mypy.ini b/mypy.ini index a7b1f4eb64..9aeeca2bb2 100644 --- a/mypy.ini +++ b/mypy.ini @@ -161,6 +161,9 @@ disallow_untyped_defs = False [mypy-synapse.handlers.*] disallow_untyped_defs = True +[mypy-synapse.http.server] +disallow_untyped_defs = True + [mypy-synapse.metrics.*] disallow_untyped_defs = True diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 578fc48ef4..efecb089c1 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" - def __init__(self, msg): + def __init__(self, msg: str): super().__init__(504, msg) @@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$") -def redact_uri(uri): +def redact_uri(uri: str) -> str: """Strips sensitive information from the uri replaces with """ uri = ACCESS_TOKEN_RE.sub(r"\1\3", uri) return CLIENT_SECRET_RE.sub(r"\1\3", uri) @@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer): https://twistedmatrix.com/trac/ticket/6528 """ - def stopProducing(self): + def stopProducing(self) -> None: try: FileBodyProducer.stopProducing(self) except task.TaskStopped: diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 9a2684aca4..6a9f6635d2 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple from twisted.web.server import Request @@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource): and exception handling. """ - def __init__(self, hs: "HomeServer", handler): + def __init__( + self, + hs: "HomeServer", + handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]], + ): """Initialise AdditionalResource The ``handler`` should return a deferred which completes when it has @@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource): super().__init__() self._handler = handler - def _async_render(self, request: Request): + async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]: # Cheekily pass the result straight through, so we don't need to worry # if its an awaitable or not. - return self._handler(request) + return await self._handler(request) diff --git a/synapse/http/server.py b/synapse/http/server.py index 91badb0b0a..4fd5660a08 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -30,6 +30,7 @@ from typing import ( Iterable, Iterator, List, + NoReturn, Optional, Pattern, Tuple, @@ -170,7 +171,9 @@ def return_html_error( respond_with_html(request, code, body) -def wrap_async_request_handler(h): +def wrap_async_request_handler( + h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] +) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. This helps ensure that work done by the request handler after the request is completed @@ -183,7 +186,9 @@ def wrap_async_request_handler(h): logged until the deferred completes. """ - async def wrapped_async_request_handler(self, request): + async def wrapped_async_request_handler( + self: "_AsyncResource", request: SynapseRequest + ) -> None: with request.processing(): await h(self, request) @@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): context from the request the servlet is handling. """ - def __init__(self, extract_context=False): + def __init__(self, extract_context: bool = False): super().__init__() self._extract_context = extract_context - def render(self, request): + def render(self, request: SynapseRequest) -> int: """This gets called by twisted every time someone sends us a request.""" defer.ensureDeferred(self._async_render_wrapper(request)) return NOT_DONE_YET @wrap_async_request_handler - async def _async_render_wrapper(self, request: SynapseRequest): + async def _async_render_wrapper(self, request: SynapseRequest) -> None: """This is a wrapper that delegates to `_async_render` and handles exceptions, return values, metrics, etc. """ @@ -271,7 +276,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): f = failure.Failure() self._send_error_response(f, request) - async def _async_render(self, request: Request): + async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: """Delegates to `_async_render_` methods, or returns a 400 if no appropriate method exists. Can be overridden in sub classes for different routing. @@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource): formatting responses and errors as JSON. """ - def __init__(self, canonical_json=False, extract_context=False): + def __init__(self, canonical_json: bool = False, extract_context: bool = False): super().__init__(extract_context) self.canonical_json = canonical_json @@ -327,7 +332,7 @@ class DirectServeJsonResource(_AsyncResource): request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # TODO: Only enable CORS for the requests that need it. respond_with_json( @@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): + def __init__( + self, + hs: "HomeServer", + canonical_json: bool = True, + extract_context: bool = False, + ): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.hs = hs - def register_paths(self, method, path_patterns, callback, servlet_classname): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Registers a request handler against a regular expression. Later request URLs are checked against these regular expressions in order to identify an appropriate handler for that request. Args: - method (str): GET, POST etc + method: GET, POST etc - path_patterns (Iterable[str]): A list of regular expressions to which - the request URLs are compared. + path_patterns: A list of regular expressions to which the request + URLs are compared. - callback (function): The handler for the request. Usually a Servlet + callback: The handler for the request. Usually a Servlet - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ - method = method.encode("utf-8") # method is bytes on py3 + method_bytes = method.encode("utf-8") for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) - self.path_regexs.setdefault(method, []).append( + self.path_regexs.setdefault(method_bytes, []).append( _PathEntry(path_pattern, callback, servlet_classname) ) @@ -427,7 +443,7 @@ class JsonResource(DirectServeJsonResource): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, "unrecognised_request_handler", {} - async def _async_render(self, request): + async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) # Make sure we have an appropriate name for this handler in prometheus @@ -468,7 +484,7 @@ class DirectServeHtmlResource(_AsyncResource): request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # We expect to get bytes for us to write assert isinstance(response_object, bytes) @@ -492,12 +508,12 @@ class StaticResource(File): Differs from the File resource by adding clickjacking protection. """ - def render_GET(self, request: Request): + def render_GET(self, request: Request) -> bytes: set_clickjacking_protection_headers(request) return super().render_GET(request) -def _unrecognised_request_handler(request): +def _unrecognised_request_handler(request: Request) -> NoReturn: """Request handler for unrecognised requests This is a request handler suitable for return from @@ -505,7 +521,7 @@ def _unrecognised_request_handler(request): UnrecognizedRequestError. Args: - request (twisted.web.http.Request): + request: Unused, but passed in to match the signature of ServletCallback. """ raise UnrecognizedRequestError() @@ -513,14 +529,14 @@ def _unrecognised_request_handler(request): class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" - def __init__(self, path): + def __init__(self, path: str): resource.Resource.__init__(self) self.url = path - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: return redirectTo(self.url.encode("ascii"), request) - def getChild(self, name, request): + def getChild(self, name: str, request: Request) -> resource.Resource: if len(name) == 0: return self # select ourselves as the child to render return resource.Resource.getChild(self, name, request) @@ -529,7 +545,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request): + def render_OPTIONS(self, request: Request) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -537,7 +553,7 @@ class OptionsResource(resource.Resource): return b"" - def getChildWithDefault(self, path, request): + def getChildWithDefault(self, path: str, request: Request) -> resource.Resource: if request.method == b"OPTIONS": return self # select ourselves as the child to render return resource.Resource.getChildWithDefault(self, path, request) @@ -649,7 +665,7 @@ def respond_with_json( json_object: Any, send_cors: bool = False, canonical_json: bool = True, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -696,7 +712,7 @@ def respond_with_json_bytes( code: int, json_bytes: bytes, send_cors: bool = False, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -713,7 +729,7 @@ def respond_with_json_bytes( logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread( request: SynapseRequest, json_encoder: Callable[[Any], bytes], json_object: Any, -): +) -> None: """Encodes the given JSON object on a thread and then writes it to the request. @@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request): +def set_cors_headers(request: Request) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -790,14 +806,14 @@ def set_cors_headers(request: Request): ) -def respond_with_html(request: Request, code: int, html: str): +def respond_with_html(request: Request, code: int, html: str) -> None: """ Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. """ respond_with_html_bytes(request, code, html.encode("utf-8")) -def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): +def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None: """ Sends HTML (encoded as UTF-8 bytes) as the response to the given request. @@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") @@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): finish_request(request) -def set_clickjacking_protection_headers(request: Request): +def set_clickjacking_protection_headers(request: Request) -> None: """ Set headers to guard against clickjacking of embedded content. @@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: finish_request(request) -def finish_request(request: Request): +def finish_request(request: Request) -> None: """Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index e543cc6e01..4ff840ca0e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -31,6 +31,7 @@ from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder @@ -726,7 +727,7 @@ class RestServlet: into the appropriate HTTP response. """ - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: """Register this servlet with the given HTTP server.""" patterns = getattr(self, "PATTERNS", None) if patterns: diff --git a/synapse/http/site.py b/synapse/http/site.py index 755ad56637..9f68d7e191 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Generator, Optional, Tuple, Union +from typing import Any, Generator, Optional, Tuple, Union import attr from zope.interface import implementer @@ -66,9 +66,9 @@ class SynapseRequest(Request): self, channel: HTTPChannel, site: "SynapseSite", - *args, + *args: Any, max_request_body_size: int = 1024, - **kw, + **kw: Any, ): super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size @@ -557,7 +557,7 @@ class SynapseSite(Site): proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest - def request_factory(channel, queued: bool) -> Request: + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, self, diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 12b3ae120c..b9bfbea21b 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from canonicaljson import encode_canonical_json from signedjson.sign import sign_json @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> int: + def render_GET(self, request: Request) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: -- cgit 1.5.1 From 0147b3de20f313975226a9a3f319c77b90aa2793 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 14 Dec 2021 17:35:28 +0000 Subject: Add missing type hints to `synapse.logging.context` (#11556) --- changelog.d/11556.misc | 1 + mypy.ini | 3 + stubs/txredisapi.pyi | 9 +- synapse/federation/federation_server.py | 9 +- synapse/handlers/federation.py | 19 +-- synapse/handlers/initial_sync.py | 33 +++-- synapse/handlers/message.py | 13 +- synapse/http/federation/matrix_federation_agent.py | 7 +- synapse/logging/context.py | 149 ++++++++++++++------- synapse/util/async_helpers.py | 57 +++++++- synapse/util/caches/cached_call.py | 1 + synapse/util/file_consumer.py | 1 + tests/util/test_logcontext.py | 35 ----- 13 files changed, 215 insertions(+), 122 deletions(-) create mode 100644 changelog.d/11556.misc (limited to 'synapse/http') diff --git a/changelog.d/11556.misc b/changelog.d/11556.misc new file mode 100644 index 0000000000..53b26aa676 --- /dev/null +++ b/changelog.d/11556.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.logging.context`. diff --git a/mypy.ini b/mypy.ini index 4551302c82..1867322044 100644 --- a/mypy.ini +++ b/mypy.ini @@ -167,6 +167,9 @@ disallow_untyped_defs = True [mypy-synapse.http.server] disallow_untyped_defs = True +[mypy-synapse.logging.context] +disallow_untyped_defs = True + [mypy-synapse.metrics.*] disallow_untyped_defs = True diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 4ff3c6de5f..429234d7ae 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -17,11 +17,12 @@ from typing import Any, List, Optional, Type, Union from twisted.internet import protocol +from twisted.internet.defer import Deferred class RedisProtocol(protocol.Protocol): def publish(self, channel: str, message: bytes): ... - async def ping(self) -> None: ... - async def set( + def ping(self) -> "Deferred[None]": ... + def set( self, key: str, value: Any, @@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol): pexpire: Optional[int] = None, only_if_not_exists: bool = False, only_if_exists: bool = False, - ) -> None: ... - async def get(self, key: str) -> Any: ... + ) -> "Deferred[None]": ... + def get(self, key: str) -> "Deferred[Any]": ... class SubscriberProtocol(RedisProtocol): def __init__(self, *args, **kwargs): ... diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e37e76206..cf067b56c6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -30,7 +30,6 @@ from typing import ( from prometheus_client import Counter, Gauge, Histogram -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -67,7 +66,7 @@ from synapse.replication.http.federation import ( from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name @@ -360,13 +359,13 @@ class FederationServer(FederationBase): # want to block things like to device messages from reaching clients # behind the potentially expensive handling of PDUs. pdu_results, _ = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background( self._handle_pdus_in_txn, origin, transaction, request_time ), run_in_background(self._handle_edus_in_txn, origin, transaction), - ], + ), consumeErrors=True, ).addErrback(unwrapFirstError) ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1ea837d082..26b8e3f43c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -360,31 +360,34 @@ class FederationHandler: logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = await make_deferred_yieldable( + states_list = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) ) - # dict[str, dict[tuple, str]], a map from event_id to state map of - # event_ids. - states = dict(zip(event_ids, [s.state for s in states])) + # A map from event_id to state map of event_ids. + state_ids: Dict[str, StateMap[str]] = dict( + zip(event_ids, [s.state for s in states_list]) + ) state_map = await self.store.get_events( - [e_id for ids in states.values() for e_id in ids.values()], + [e_id for ids in state_ids.values() for e_id in ids.values()], get_prev_content=False, ) - states = { + + # A map from event_id to state map of events. + state_events: Dict[str, StateMap[EventBase]] = { key: { k: state_map[e_id] for k, e_id in state_dict.items() if e_id in state_map } - for key, state_dict in states.items() + for key, state_dict in state_ids.items() } for e_id in event_ids: - likely_extremeties_domains = get_domains_from_state(states[e_id]) + likely_extremeties_domains = get_domains_from_state(state_events[e_id]) success = await try_backfill( [ diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9cd21e7f2b..9ab723ff97 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -13,21 +13,27 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple - -from twisted.internet import defer +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.handlers.receipts import ReceiptEventSource from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util import unwrapFirstError -from synapse.util.async_helpers import concurrently_execute +from synapse.util.async_helpers import concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -190,14 +196,13 @@ class InitialSyncHandler: ) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] - ) - deferred_room_state.addCallback( - lambda states: states[event.event_id] + ).addCallback( + lambda states: cast(StateMap[EventBase], states[event.event_id]) ) (messages, token), current_state = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background( self.store.get_recent_events_for_room, event.room_id, @@ -205,7 +210,7 @@ class InitialSyncHandler: end_token=room_end_token, ), deferred_room_state, - ] + ) ) ).addErrback(unwrapFirstError) @@ -454,8 +459,8 @@ class InitialSyncHandler: return receipts presence, receipts, (messages, token) = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background(get_presence), run_in_background(get_receipts), run_in_background( @@ -464,7 +469,7 @@ class InitialSyncHandler: limit=limit, end_token=now_token.room_key, ), - ], + ), consumeErrors=True, ).addErrback(unwrapFirstError) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 38409fef38..5e3d3886eb 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple from canonicaljson import encode_canonical_json -from twisted.internet import defer from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder, json_encoder, log_failure -from synapse.util.async_helpers import Linearizer, unwrapFirstError +from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -1168,9 +1167,9 @@ class EventCreationHandler: # We now persist the event (and update the cache in parallel, since we # don't want to block on it). - result = await make_deferred_yieldable( - defer.gatherResults( - [ + result, _ = await make_deferred_yieldable( + gather_results( + ( run_in_background( self._persist_event, requester=requester, @@ -1182,12 +1181,12 @@ class EventCreationHandler: run_in_background( self.cache_joined_hosts_for_event, event, context ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), - ], + ), consumeErrors=True, ) ).addErrback(unwrapFirstError) - return result[0] + return result async def _persist_event( self, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 1238bfd287..a8a520f809 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -25,6 +25,7 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import ( + IProtocol, IProtocolFactory, IReactorCore, IStreamClientEndpoint, @@ -309,12 +310,14 @@ class MatrixHostnameEndpoint: self._srv_resolver = srv_resolver - def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: + def connect( + self, protocol_factory: IProtocolFactory + ) -> "defer.Deferred[IProtocol]": """Implements IStreamClientEndpoint interface""" return run_in_background(self._do_connect, protocol_factory) - async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: + async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol: first_exception = None server_list = await self._resolve_server() diff --git a/synapse/logging/context.py b/synapse/logging/context.py index d8ae3188b7..25e78cc82f 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -22,20 +22,33 @@ them. See doc/log_contexts.rst for details on how this works. """ -import inspect import logging import threading import typing import warnings -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) import attr from typing_extensions import Literal from twisted.internet import defer, threads +from twisted.python.threadpool import ThreadPool if TYPE_CHECKING: from synapse.logging.scopecontextmanager import _LogContextScope + from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -66,7 +79,7 @@ except Exception: # a hook which can be set during testing to assert that we aren't abusing logcontexts. -def logcontext_error(msg: str): +def logcontext_error(msg: str) -> None: logger.warning(msg) @@ -223,22 +236,19 @@ class _Sentinel: def __str__(self) -> str: return "sentinel" - def copy_to(self, record): - pass - - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def stop(self, rusage: "Optional[resource.struct_rusage]"): + def stop(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def add_database_transaction(self, duration_sec): + def add_database_transaction(self, duration_sec: float) -> None: pass - def add_database_scheduled(self, sched_sec): + def add_database_scheduled(self, sched_sec: float) -> None: pass - def record_event_fetch(self, event_count): + def record_event_fetch(self, event_count: int) -> None: pass def __bool__(self) -> Literal[False]: @@ -379,7 +389,12 @@ class LoggingContext: ) return self - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: @@ -399,17 +414,6 @@ class LoggingContext: # recorded against the correct metrics. self.finished = True - def copy_to(self, record) -> None: - """Copy logging fields from this context to a log record or - another LoggingContext - """ - - # we track the current request - record.request = self.request - - # we also track the current scope: - record.scope = self.scope - def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """ Record that this logcontext is currently running. @@ -626,7 +630,12 @@ class PreserveLoggingContext: def __enter__(self) -> None: self._old_context = set_current_context(self._new_context) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: context = set_current_context(self._old_context) if context != self._new_context: @@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) -def preserve_fn(f): +R = TypeVar("R") + + +@overload +def preserve_fn( # type: ignore[misc] + f: Callable[..., Awaitable[R]], +) -> Callable[..., "defer.Deferred[R]"]: + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]: + ... + + +def preserve_fn( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ] +) -> Callable[..., "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" - def g(*args, **kwargs): + def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": return run_in_background(f, *args, **kwargs) return g -def run_in_background(f, *args, **kwargs) -> defer.Deferred: +@overload +def run_in_background( # type: ignore[misc] + f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def run_in_background( + f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + ... + + +def run_in_background( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the function completes. @@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: # At this point we should have a Deferred, if not then f was a synchronous # function, wrap it in a Deferred for consistency. if not isinstance(res, defer.Deferred): + # `res` is not a `Deferred` and not a `Coroutine`. + # There are no other types of `Awaitable`s we expect to encounter in Synapse. + assert not isinstance(res, Awaitable) + return defer.succeed(res) if res.called and not res.paused: @@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: return res -def make_deferred_yieldable(deferred): - """Given a deferred (or coroutine), make it follow the Synapse logcontext - rules: +T = TypeVar("T") + - If the deferred has completed (or is not actually a Deferred), essentially - does nothing (just returns another completed deferred with the - result/failure). +def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed, essentially does nothing (just returns another + completed deferred with the result/failure). If the deferred has not yet completed, resets the logcontext before returning a deferred. Then, when the deferred completes, restores the @@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ - if inspect.isawaitable(deferred): - # If we're given a coroutine we convert it to a deferred so that we - # run it and find out if it immediately finishes, it it does then we - # don't need to fiddle with log contexts at all and can return - # immediately. - deferred = defer.ensureDeferred(deferred) - - if not isinstance(deferred, defer.Deferred): - return deferred - if deferred.called and not deferred.paused: # it looks like this deferred is ready to run any callbacks we give it # immediately. We may as well optimise out the logcontext faffery. @@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: return result -def defer_to_thread(reactor, f, *args, **kwargs): +def defer_to_thread( + reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": """ Calls the function `f` using a thread from the reactor's default threadpool and returns the result as a Deferred. @@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs): return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) -def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): +def defer_to_threadpool( + reactor: "ISynapseReactor", + threadpool: ThreadPool, + f: Callable[..., R], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """ A wrapper for twisted.internet.threads.deferToThreadpool, which handles logcontexts correctly. @@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): assert isinstance(curr_context, LoggingContext) parent_context = curr_context - def g(): + def g() -> R: with LoggingContext(str(curr_context), parent_context=parent_context): return f(*args, **kwargs) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 20ce294209..bde99ea878 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -30,9 +30,11 @@ from typing import ( Iterator, Optional, Set, + Tuple, TypeVar, Union, cast, + overload, ) import attr @@ -234,6 +236,59 @@ def yieldable_gather_results( ).addErrback(unwrapFirstError) +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") + + +@overload +def gather_results( + deferredList: Tuple[()], consumeErrors: bool = ... +) -> "defer.Deferred[Tuple[()]]": + ... + + +@overload +def gather_results( + deferredList: Tuple["defer.Deferred[T1]"], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1]]": + ... + + +@overload +def gather_results( + deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1, T2]]": + ... + + +@overload +def gather_results( + deferredList: Tuple[ + "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]" + ], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1, T2, T3]]": + ... + + +def gather_results( # type: ignore[misc] + deferredList: Tuple["defer.Deferred[T1]", ...], + consumeErrors: bool = False, +) -> "defer.Deferred[Tuple[T1, ...]]": + """Combines a tuple of `Deferred`s into a single `Deferred`. + + Wraps `defer.gatherResults` to provide type annotations that support heterogenous + lists of `Deferred`s. + """ + # The `type: ignore[misc]` above suppresses + # "Overloaded function implementation cannot produce return type of signature 1/2/3" + deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors) + return deferred.addCallback(tuple) + + @attr.s(slots=True) class _LinearizerEntry: # The number of things executing. @@ -352,7 +407,7 @@ class Linearizer: logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) - new_defer = make_deferred_yieldable(defer.Deferred()) + new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred()) entry.deferreds[new_defer] = 1 def cb(_r: None) -> "defer.Deferred[None]": diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index 470f4f91a5..e325f44da3 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -76,6 +76,7 @@ class CachedCall(Generic[TV]): # Fire off the callable now if this is our first time if not self._deferred: + assert self._callable is not None self._deferred = run_in_background(self._callable) # we will never need the callable again, so make sure it can be GCed diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index de2adacd70..46771a401b 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -142,6 +142,7 @@ class BackgroundFileConsumer: def wait(self) -> "Deferred[None]": """Returns a deferred that resolves when finished writing to file""" + assert self._finished_deferred is not None return make_deferred_yieldable(self._finished_deferred) def _resume_paused_producer(self) -> None: diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 5d9c4665aa..621b0f9fcd 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase): # now it should be restored self._check_test_key("one") - @defer.inlineCallbacks - def test_make_deferred_yieldable_on_non_deferred(self): - """Check that make_deferred_yieldable does the right thing when its - argument isn't actually a deferred""" - - with LoggingContext("one"): - d1 = make_deferred_yieldable("bum") - self._check_test_key("one") - - r = yield d1 - self.assertEqual(r, "bum") - self._check_test_key("one") - def test_nested_logging_context(self): with LoggingContext("foo"): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.name, "foo-bar") - @defer.inlineCallbacks - def test_make_deferred_yieldable_with_await(self): - # an async function which returns an incomplete coroutine, but doesn't - # follow the synapse rules. - - async def blocking_function(): - d = defer.Deferred() - reactor.callLater(0, d.callback, None) - await d - - sentinel_context = current_context() - - with LoggingContext("one"): - d1 = make_deferred_yieldable(blocking_function()) - # make sure that the context was reset by make_deferred_yieldable - self.assertIs(current_context(), sentinel_context) - - yield d1 - - # now it should be restored - self._check_test_key("one") - # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on -- cgit 1.5.1 From 3e0cfd447e17658a937fe62555db9e968f00b15b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 20 Dec 2021 11:00:13 -0500 Subject: Return JSON errors for unknown resources under /matrix/client. (#11602) Instead of returning 404 errors with HTML bodies when an unknown prefix was requested (e.g. /matrix/client/v1 before Synapse v1.49.0). --- changelog.d/11602.bugfix | 1 + synapse/app/homeserver.py | 9 ++------- synapse/http/server.py | 6 +++--- 3 files changed, 6 insertions(+), 10 deletions(-) create mode 100644 changelog.d/11602.bugfix (limited to 'synapse/http') diff --git a/changelog.d/11602.bugfix b/changelog.d/11602.bugfix new file mode 100644 index 0000000000..e0dfbf1a15 --- /dev/null +++ b/changelog.d/11602.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug that some unknown endpoints would return HTML error pages instead of JSON `M_UNRECOGNIZED` errors. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index dd76e07321..177ce040e8 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -27,6 +27,7 @@ import synapse import synapse.config.logger from synapse import events from synapse.api.urls import ( + CLIENT_API_PREFIX, FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, @@ -192,13 +193,7 @@ class SynapseHomeServer(HomeServer): resources.update( { - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/v1": client_resource, - "/_matrix/client/v3": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, + CLIENT_API_PREFIX: client_resource, "/.well-known": well_known_resource(self), "/_synapse/admin": AdminRestResource(self), **build_synapse_client_resource_tree(self), diff --git a/synapse/http/server.py b/synapse/http/server.py index 4fd5660a08..7bbbe7648b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -530,7 +530,7 @@ class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" def __init__(self, path: str): - resource.Resource.__init__(self) + super().__init__() self.url = path def render_GET(self, request: Request) -> bytes: @@ -539,7 +539,7 @@ class RootRedirect(resource.Resource): def getChild(self, name: str, request: Request) -> resource.Resource: if len(name) == 0: return self # select ourselves as the child to render - return resource.Resource.getChild(self, name, request) + return super().getChild(name, request) class OptionsResource(resource.Resource): @@ -556,7 +556,7 @@ class OptionsResource(resource.Resource): def getChildWithDefault(self, path: str, request: Request) -> resource.Resource: if request.method == b"OPTIONS": return self # select ourselves as the child to render - return resource.Resource.getChildWithDefault(self, path, request) + return super().getChildWithDefault(path, request) class RootOptionsRedirectResource(OptionsResource, RootRedirect): -- cgit 1.5.1 From 60fa4935b5d3ee26f9ebb4b25ec74bed26d3c98d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 20 Dec 2021 17:45:03 +0000 Subject: Improve opentracing for incoming HTTP requests (#11618) * remove `start_active_span_from_request` Instead, pull out a separate function, `span_context_from_request`, to extract the parent span, which we can then pass into `start_active_span` as normal. This seems to be clearer all round. * Remove redundant tags from `incoming-federation-request` These are all wrapped up inside a parent span generated in AsyncResource, so there's no point duplicating all the tags that are set there. * Leave request spans open until the request completes It may take some time for the response to be encoded into JSON, and that JSON to be streamed back to the client, and really we want that inside the top-level span, so let's hand responsibility for closure to the SynapseRequest. * opentracing logs for HTTP request events * changelog --- changelog.d/11618.misc | 1 + synapse/federation/transport/server/_base.py | 39 ++++++---------- synapse/http/site.py | 30 +++++++++++- synapse/logging/opentracing.py | 68 +++++++++------------------- 4 files changed, 65 insertions(+), 73 deletions(-) create mode 100644 changelog.d/11618.misc (limited to 'synapse/http') diff --git a/changelog.d/11618.misc b/changelog.d/11618.misc new file mode 100644 index 0000000000..4076b30bf7 --- /dev/null +++ b/changelog.d/11618.misc @@ -0,0 +1 @@ +Improve opentracing support for incoming HTTP requests. diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dc39e3537b..da1fbf8b63 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -22,13 +22,11 @@ from synapse.api.urls import FEDERATION_V1_PREFIX from synapse.http.server import HttpServer, ServletCallback from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging import opentracing from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( - SynapseTags, - start_active_span, - start_active_span_from_request, - tags, + set_tag, + span_context_from_request, + start_active_span_follows_from, whitelisted_homeserver, ) from synapse.server import HomeServer @@ -279,30 +277,19 @@ class BaseFederationServlet: logger.warning("authenticate_request failed: %s", e) raise - request_tags = { - SynapseTags.REQUEST_ID: request.get_request_id(), - tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, - tags.HTTP_METHOD: request.get_method(), - tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), - "authenticated_entity": origin, - "servlet_name": request.request_metrics.name, - } - - # Only accept the span context if the origin is authenticated - # and whitelisted + # update the active opentracing span with the authenticated entity + set_tag("authenticated_entity", origin) + + # if the origin is authenticated and whitelisted, link to its span context + context = None if origin and whitelisted_homeserver(origin): - scope = start_active_span_from_request( - request, "incoming-federation-request", tags=request_tags - ) - else: - scope = start_active_span( - "incoming-federation-request", tags=request_tags - ) + context = span_context_from_request(request) - with scope: - opentracing.inject_response_headers(request.responseHeaders) + scope = start_active_span_follows_from( + "incoming-federation-request", contexts=(context,) if context else () + ) + with scope: if origin and self.RATELIMIT: with ratelimiter.ratelimit(origin) as d: await d diff --git a/synapse/http/site.py b/synapse/http/site.py index 9f68d7e191..80f7a2ff58 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Any, Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union import attr from zope.interface import implementer @@ -35,6 +35,9 @@ from synapse.logging.context import ( ) from synapse.types import Requester +if TYPE_CHECKING: + import opentracing + logger = logging.getLogger(__name__) _next_request_seq = 0 @@ -81,6 +84,10 @@ class SynapseRequest(Request): # server name, for client requests this is the Requester object. self._requester: Optional[Union[Requester, str]] = None + # An opentracing span for this request. Will be closed when the request is + # completely processed. + self._opentracing_span: "Optional[opentracing.Span]" = None + # we can't yet create the logcontext, as we don't know the method. self.logcontext: Optional[LoggingContext] = None @@ -148,6 +155,13 @@ class SynapseRequest(Request): # If there's no authenticated entity, it was the requester. self.logcontext.request.authenticated_entity = authenticated_entity or requester + def set_opentracing_span(self, span: "opentracing.Span") -> None: + """attach an opentracing span to this request + + Doing so will cause the span to be closed when we finish processing the request + """ + self._opentracing_span = span + def get_request_id(self) -> str: return "%s-%i" % (self.get_method(), self.request_seq) @@ -286,6 +300,9 @@ class SynapseRequest(Request): self._processing_finished_time = time.time() self._is_processing = False + if self._opentracing_span: + self._opentracing_span.log_kv({"event": "finished processing"}) + # if we've already sent the response, log it now; otherwise, we wait for the # response to be sent. if self.finish_time is not None: @@ -299,6 +316,8 @@ class SynapseRequest(Request): """ self.finish_time = time.time() Request.finish(self) + if self._opentracing_span: + self._opentracing_span.log_kv({"event": "response sent"}) if not self._is_processing: assert self.logcontext is not None with PreserveLoggingContext(self.logcontext): @@ -333,6 +352,11 @@ class SynapseRequest(Request): with PreserveLoggingContext(self.logcontext): logger.info("Connection from client lost before response was sent") + if self._opentracing_span: + self._opentracing_span.log_kv( + {"event": "client connection lost", "reason": str(reason.value)} + ) + if not self._is_processing: self._finished_processing() @@ -421,6 +445,10 @@ class SynapseRequest(Request): usage.evt_db_fetch_count, ) + # complete the opentracing span, if any. + if self._opentracing_span: + self._opentracing_span.finish() + try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5d93ab07f1..6364290615 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -173,6 +173,7 @@ from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Typ import attr from twisted.internet import defer +from twisted.web.http import Request from twisted.web.http_headers import Headers from synapse.config import ConfigError @@ -490,48 +491,6 @@ def start_active_span_follows_from( return scope -def start_active_span_from_request( - request, - operation_name, - references=None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, -): - """ - Extracts a span context from a Twisted Request. - args: - headers (twisted.web.http.Request) - - For the other args see opentracing.tracer - - returns: - span_context (opentracing.span.SpanContext) - """ - # Twisted encodes the values as lists whereas opentracing doesn't. - # So, we take the first item in the list. - # Also, twisted uses byte arrays while opentracing expects strings. - - if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] - - header_dict = { - k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() - } - context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict) - - return opentracing.tracer.start_active_span( - operation_name, - child_of=context, - references=references, - tags=tags, - start_time=start_time, - ignore_active_span=ignore_active_span, - finish_on_close=finish_on_close, - ) - - def start_active_span_from_edu( edu_content, operation_name, @@ -743,6 +702,20 @@ def active_span_context_as_string(): return json_encoder.encode(carrier) +def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]": + """Extract an opentracing context from the headers on an HTTP request + + This is useful when we have received an HTTP request from another part of our + system, and want to link our spans to those of the remote system. + """ + if not opentracing: + return None + header_dict = { + k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() + } + return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict) + + @only_if_tracing def span_context_from_string(carrier): """ @@ -882,10 +855,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False): } request_name = request.request_metrics.name - if extract_context: - scope = start_active_span_from_request(request, request_name) - else: - scope = start_active_span(request_name) + context = span_context_from_request(request) if extract_context else None + + # we configure the scope not to finish the span immediately on exit, and instead + # pass the span into the SynapseRequest, which will finish it once we've finished + # sending the response to the client. + scope = start_active_span(request_name, child_of=context, finish_on_close=False) + request.set_opentracing_span(scope.span) with scope: inject_response_headers(request.responseHeaders) -- cgit 1.5.1 From 221595414751f7b8fd0c79772c5ac4ffefeca10a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 21 Dec 2021 11:10:36 +0000 Subject: Various opentracing enhancements (#11619) * Wrap `auth.get_user_by_req` in an opentracing span give `get_user_by_req` its own opentracing span, since it can result in a non-trivial number of sub-spans which it is useful to group together. This requires a bit of reorganisation because it also sets some tags (and may force tracing) on the servlet span. * Emit opentracing span for encoding json responses This can be a significant time sink. * Rename all sync spans with a prefix * Write an opentracing span for encoding sync response * opentracing span to group generate_room_entries * opentracing spans within sync.encode_response * changelog * Use the `trace` decorator instead of context managers --- changelog.d/11619.misc | 1 + synapse/api/auth.py | 53 +++++++++++++++++++++++++++++++-------------- synapse/handlers/sync.py | 7 +++--- synapse/http/server.py | 19 ++++++++++++++-- synapse/rest/client/sync.py | 6 +++++ 5 files changed, 65 insertions(+), 21 deletions(-) create mode 100644 changelog.d/11619.misc (limited to 'synapse/http') diff --git a/changelog.d/11619.misc b/changelog.d/11619.misc new file mode 100644 index 0000000000..2125cbddd2 --- /dev/null +++ b/changelog.d/11619.misc @@ -0,0 +1 @@ +A number of improvements to opentracing support. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0bf58dff08..4a32d430bd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -32,7 +32,7 @@ from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest -from synapse.logging import opentracing as opentracing +from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester, StateMap, UserID, create_requester from synapse.util.caches.lrucache import LruCache @@ -149,6 +149,42 @@ class Auth: is invalid. AuthError if access is denied for the user in the access token """ + parent_span = active_span() + with start_active_span("get_user_by_req"): + requester = await self._wrapped_get_user_by_req( + request, allow_guest, rights, allow_expired + ) + + if parent_span: + if requester.authenticated_entity in self._force_tracing_for_users: + # request tracing is enabled for this user, so we need to force it + # tracing on for the parent span (which will be the servlet span). + # + # It's too late for the get_user_by_req span to inherit the setting, + # so we also force it on for that. + force_tracing() + force_tracing(parent_span) + parent_span.set_tag( + "authenticated_entity", requester.authenticated_entity + ) + parent_span.set_tag("user_id", requester.user.to_string()) + if requester.device_id is not None: + parent_span.set_tag("device_id", requester.device_id) + if requester.app_service is not None: + parent_span.set_tag("appservice_id", requester.app_service.id) + return requester + + async def _wrapped_get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool, + rights: str, + allow_expired: bool, + ) -> Requester: + """Helper for get_user_by_req + + Once get_user_by_req has set up the opentracing span, this does the actual work. + """ try: ip_addr = request.getClientIP() user_agent = get_request_user_agent(request) @@ -177,14 +213,6 @@ class Auth: ) request.requester = user_id - if user_id in self._force_tracing_for_users: - opentracing.force_tracing() - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("user_id", user_id) - if device_id is not None: - opentracing.set_tag("device_id", device_id) - opentracing.set_tag("appservice_id", app_service.id) - return requester user_info = await self.get_user_by_access_token( @@ -242,13 +270,6 @@ class Auth: ) request.requester = requester - if user_info.token_owner in self._force_tracing_for_users: - opentracing.force_tracing() - opentracing.set_tag("authenticated_entity", user_info.token_owner) - opentracing.set_tag("user_id", user_info.user_id) - if device_id: - opentracing.set_tag("device_id", device_id) - return requester except KeyError: raise MissingClientTokenError() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bcd10cbb30..d24124d6ac 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -421,7 +421,7 @@ class SyncHandler: span to track the sync. See `generate_sync_result` for the next part of your indoctrination. """ - with start_active_span("current_sync_for_user"): + with start_active_span("sync.current_sync_for_user"): log_kv({"since_token": since_token}) sync_result = await self.generate_sync_result( sync_config, since_token, full_state @@ -1585,7 +1585,8 @@ class SyncHandler: ) logger.debug("Generated room entry for %s", room_entry.room_id) - await concurrently_execute(handle_room_entries, room_entries, 10) + with start_active_span("sync.generate_room_entries"): + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) @@ -2045,7 +2046,7 @@ class SyncHandler: since_token = room_builder.since_token upto_token = room_builder.upto_token - with start_active_span("generate_room_entry"): + with start_active_span("sync.generate_room_entry"): set_tag("room_id", room_id) log_kv({"events": len(events or ())}) diff --git a/synapse/http/server.py b/synapse/http/server.py index 7bbbe7648b..e302946591 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -58,12 +58,14 @@ from synapse.api.errors import ( ) from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background -from synapse.logging.opentracing import trace_servlet +from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: + import opentracing + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -759,7 +761,20 @@ async def _async_write_json_to_request_in_thread( expensive. """ - json_str = await defer_to_thread(request.reactor, json_encoder, json_object) + def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes: + # it might take a while for the threadpool to schedule us, so we write + # opentracing logs once we actually get scheduled, so that we can see how + # much that contributed. + if opentracing_span: + opentracing_span.log_kv({"event": "scheduled"}) + res = json_encoder(json_object) + if opentracing_span: + opentracing_span.log_kv({"event": "encoded"}) + return res + + with start_active_span("encode_json_response"): + span = active_span() + json_str = await defer_to_thread(request.reactor, encode, span) _write_bytes_to_request(request, json_str) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8c4b0f6e5d..e99a943d0d 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -48,6 +48,7 @@ from synapse.handlers.sync import ( from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest +from synapse.logging.opentracing import trace from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -222,6 +223,7 @@ class SyncRestServlet(RestServlet): logger.debug("Event formatting complete") return 200, response_content + @trace(opname="sync.encode_response") async def encode_response( self, time_now: int, @@ -332,6 +334,7 @@ class SyncRestServlet(RestServlet): ] } + @trace(opname="sync.encode_joined") async def encode_joined( self, rooms: List[JoinedSyncResult], @@ -368,6 +371,7 @@ class SyncRestServlet(RestServlet): return joined + @trace(opname="sync.encode_invited") async def encode_invited( self, rooms: List[InvitedSyncResult], @@ -406,6 +410,7 @@ class SyncRestServlet(RestServlet): return invited + @trace(opname="sync.encode_knocked") async def encode_knocked( self, rooms: List[KnockedSyncResult], @@ -460,6 +465,7 @@ class SyncRestServlet(RestServlet): return knocked + @trace(opname="sync.encode_archived") async def encode_archived( self, rooms: List[ArchivedSyncResult], -- cgit 1.5.1 From cbd82d0b2db069400b5d43373838817d8a0209e7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Dec 2021 13:47:12 -0500 Subject: Convert all namedtuples to attrs. (#11665) To improve type hints throughout the code. --- changelog.d/11665.misc | 1 + synapse/api/filtering.py | 3 +- synapse/config/repository.py | 34 +++---- synapse/federation/federation_base.py | 5 - synapse/federation/send_queue.py | 47 +++++----- synapse/handlers/appservice.py | 4 +- synapse/handlers/directory.py | 10 +- synapse/handlers/room_list.py | 22 ++--- synapse/handlers/typing.py | 14 ++- synapse/http/server.py | 10 +- synapse/replication/tcp/streams/_base.py | 129 +++++++++++++------------- synapse/replication/tcp/streams/federation.py | 15 ++- synapse/rest/media/v1/media_repository.py | 19 ++-- synapse/state/__init__.py | 5 +- synapse/storage/databases/main/directory.py | 10 +- synapse/storage/databases/main/events.py | 13 ++- synapse/storage/databases/main/room.py | 26 ++++-- synapse/storage/databases/main/search.py | 16 +++- synapse/storage/databases/main/state.py | 14 --- synapse/storage/databases/main/stream.py | 12 ++- synapse/types.py | 22 ++--- tests/replication/test_federation_ack.py | 6 +- 22 files changed, 231 insertions(+), 206 deletions(-) create mode 100644 changelog.d/11665.misc (limited to 'synapse/http') diff --git a/changelog.d/11665.misc b/changelog.d/11665.misc new file mode 100644 index 0000000000..e7cc8ff23f --- /dev/null +++ b/changelog.d/11665.misc @@ -0,0 +1 @@ +Convert `namedtuples` to `attrs`. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 13dd6ce248..d087c816db 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -351,8 +351,7 @@ class Filter: True if the event matches the filter. """ # We usually get the full "events" as dictionaries coming through, - # except for presence which actually gets passed around as its own - # namedtuple type. + # except for presence which actually gets passed around as its own type. if isinstance(event, UserPresenceState): user_id = event.user_id field_matchers = { diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b129b9dd68..1980351e77 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -14,10 +14,11 @@ import logging import os -from collections import namedtuple from typing import Dict, List, Tuple from urllib.request import getproxies_environment # type: ignore +import attr + from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict @@ -44,18 +45,20 @@ THUMBNAIL_SIZE_YAML = """\ HTTP_PROXY_SET_WARNING = """\ The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" -ThumbnailRequirement = namedtuple( - "ThumbnailRequirement", ["width", "height", "method", "media_type"] -) -MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", - ( - "store_local", # Whether to store newly uploaded local files - "store_remote", # Whether to store newly downloaded remote files - "store_synchronous", # Whether to wait for successful storage for local uploads - ), -) +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThumbnailRequirement: + width: int + height: int + method: str + media_type: str + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class MediaStorageProviderConfig: + store_local: bool # Whether to store newly uploaded local files + store_remote: bool # Whether to store newly downloaded remote files + store_synchronous: bool # Whether to wait for successful storage for local uploads def parse_thumbnail_requirements( @@ -66,11 +69,10 @@ def parse_thumbnail_requirements( method, and thumbnail media type to precalculate Args: - thumbnail_sizes(list): List of dicts with "width", "height", and - "method" keys + thumbnail_sizes: List of dicts with "width", "height", and "method" keys + Returns: - Dictionary mapping from media type string to list of - ThumbnailRequirement tuples. + Dictionary mapping from media type string to list of ThumbnailRequirement. """ requirements: Dict[str, List[ThumbnailRequirement]] = {} for size in thumbnail_sizes: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index f56344a3b9..4df90e02d7 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership @@ -104,10 +103,6 @@ class FederationBase: return pdu -class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): - pass - - async def _check_sigs_on_pdu( keyring: Keyring, room_version: RoomVersion, pdu: EventBase ) -> None: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 63289a5a33..0d7c4f5067 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -30,7 +30,6 @@ Events are replicated via a separate events stream. """ import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Dict, @@ -43,6 +42,7 @@ from typing import ( Type, ) +import attr from sortedcontainers import SortedDict from synapse.api.presence import UserPresenceState @@ -382,13 +382,11 @@ class BaseFederationRow: raise NotImplementedError() -class PresenceDestinationsRow( - BaseFederationRow, - namedtuple( - "PresenceDestinationsRow", - ("state", "destinations"), # UserPresenceState # list[str] - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PresenceDestinationsRow(BaseFederationRow): + state: UserPresenceState + destinations: List[str] + TypeId = "pd" @staticmethod @@ -404,17 +402,15 @@ class PresenceDestinationsRow( buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow( - BaseFederationRow, - namedtuple( - "KeyedEduRow", - ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class KeyedEduRow(BaseFederationRow): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ + key: Tuple[str, ...] # the edu key passed to send_edu + edu: Edu + TypeId = "k" @staticmethod @@ -428,9 +424,12 @@ class KeyedEduRow( buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EduRow(BaseFederationRow): """Streams EDUs that don't have keys. See KeyedEduRow""" + edu: Edu + TypeId = "e" @staticmethod @@ -453,14 +452,14 @@ _rowtypes: Tuple[Type[BaseFederationRow], ...] = ( TypeToRow = {Row.TypeId: Row for Row in _rowtypes} -ParsedFederationStreamData = namedtuple( - "ParsedFederationStreamData", - ( - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - ), -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ParsedFederationStreamData: + # list of tuples of UserPresenceState and destinations + presence_destinations: List[Tuple[UserPresenceState, List[str]]] + # dict of destination -> { key -> Edu } + keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]] + # dict of destination -> [Edu] + edus: Dict[str, List[Edu]] def process_rows_for_federation( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9abdad262b..7833e77e2b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -462,9 +462,9 @@ class ApplicationServicesHandler: Args: room_alias: The room alias to query. + Returns: - namedtuple: with keys "room_id" and "servers" or None if no - association can be found. + RoomAliasMapping or None if no association can be found. """ room_alias_str = room_alias.to_string() services = self.store.get_app_services() diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7ee5c47fd9..082f521791 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -278,13 +278,15 @@ class DirectoryHandler: users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} - servers = set(extra_servers) | set(servers) + servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. - if self.server_name in servers: - servers = [self.server_name] + [s for s in servers if s != self.server_name] + if self.server_name in servers_set: + servers = [self.server_name] + [ + s for s in servers_set if s != self.server_name + ] else: - servers = list(servers) + servers = list(servers_set) return {"room_id": room_id, "servers": servers} diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index ba7a14d651..1a33211a1f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,9 +13,9 @@ # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING, Any, Optional, Tuple +import attr import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -474,16 +474,12 @@ class RoomListHandler: ) -class RoomListNextBatch( - namedtuple( - "RoomListNextBatch", - ( - "last_joined_members", # The count to get rooms after/before - "last_room_id", # The room_id to get rooms after/before - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch - ), - ) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomListNextBatch: + last_joined_members: int # The count to get rooms after/before + last_room_id: str # The room_id to get rooms after/before + direction_is_forward: bool # True if this is a next_batch, false if prev_batch + KEY_DICT = { "last_joined_members": "m", "last_room_id": "r", @@ -502,12 +498,12 @@ class RoomListNextBatch( def to_token(self) -> str: return encode_base64( msgpack.dumps( - {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()} ) ) def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch": - return self._replace(**kwds) + return attr.evolve(self, **kwds) def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 1676ebd057..e43c22832d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -13,9 +13,10 @@ # limitations under the License. import logging import random -from collections import namedtuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +import attr + from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( @@ -37,7 +38,10 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomMember: + room_id: str + user_id: str # How often we expect remote servers to resend us presence. @@ -119,7 +123,7 @@ class FollowerTypingHandler: self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member: RoomMember) -> bool: - return member.user_id in self._room_typing.get(member.room_id, []) + return member.user_id in self._room_typing.get(member.room_id, set()) async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: @@ -166,9 +170,9 @@ class FollowerTypingHandler: for row in rows: self._room_serials[row.room_id] = token - prev_typing = set(self._room_typing.get(row.room_id, [])) + prev_typing = self._room_typing.get(row.room_id, set()) now_typing = set(row.user_ids) - self._room_typing[row.room_id] = row.user_ids + self._room_typing[row.room_id] = now_typing if self.federation: run_as_background_process( diff --git a/synapse/http/server.py b/synapse/http/server.py index e302946591..09b4125489 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,7 +14,6 @@ # limitations under the License. import abc -import collections import html import logging import types @@ -37,6 +36,7 @@ from typing import ( Union, ) +import attr import jinja2 from canonicaljson import encode_canonical_json from typing_extensions import Protocol @@ -354,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource): return_json_error(f, request) -_PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PathEntry: + pattern: Pattern + callback: ServletCallback + servlet_classname: str class JsonResource(DirectServeJsonResource): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 743a01da08..5a2d90c530 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -15,7 +15,6 @@ import heapq import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -30,6 +29,7 @@ from typing import ( import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -226,17 +226,14 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ - BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class BackfillStreamRow: + event_id: str + room_id: str + type: str + state_key: Optional[str] + redacts: Optional[str] + relates_to: Optional[str] NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -256,18 +253,15 @@ class BackfillStream(Stream): class PresenceStream(Stream): - PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PresenceStreamRow: + user_id: str + state: str + last_active_ts: int + last_federation_update_ts: int + last_user_sync_ts: int + status_msg: str + currently_active: bool NAME = "presence" ROW_TYPE = PresenceStreamRow @@ -302,7 +296,7 @@ class PresenceFederationStream(Stream): send. """ - @attr.s(slots=True, auto_attribs=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceFederationStreamRow: destination: str user_id: str @@ -320,9 +314,10 @@ class PresenceFederationStream(Stream): class TypingStream(Stream): - TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TypingStreamRow: + room_id: str + user_ids: List[str] NAME = "typing" ROW_TYPE = TypingStreamRow @@ -348,16 +343,13 @@ class TypingStream(Stream): class ReceiptsStream(Stream): - ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ReceiptsStreamRow: + room_id: str + receipt_type: str + user_id: str + event_id: str + data: dict NAME = "receipts" ROW_TYPE = ReceiptsStreamRow @@ -374,7 +366,9 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules""" - PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushRulesStreamRow: + user_id: str NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -396,10 +390,12 @@ class PushRulesStream(Stream): class PushersStream(Stream): """A user has added/changed/removed a pusher""" - PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushersStreamRow: + user_id: str + app_id: str + pushkey: str + deleted: bool NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -419,7 +415,7 @@ class CachesStream(Stream): the cache on the workers """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class CachesStreamRow: """Stream to inform workers they should invalidate their cache. @@ -430,9 +426,9 @@ class CachesStream(Stream): invalidation_ts: Timestamp of when the invalidation took place. """ - cache_func = attr.ib(type=str) - keys = attr.ib(type=Optional[List[Any]]) - invalidation_ts = attr.ib(type=int) + cache_func: str + keys: Optional[List[Any]] + invalidation_ts: int NAME = "caches" ROW_TYPE = CachesStreamRow @@ -451,9 +447,9 @@ class DeviceListsStream(Stream): told about a device update. """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: - entity = attr.ib(type=str) + entity: str NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -470,7 +466,9 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client""" - ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceStreamRow: + entity: str NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -487,9 +485,11 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room""" - TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TagAccountDataStreamRow: + user_id: str + room_id: str + data: JsonDict NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -506,10 +506,11 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed""" - AccountDataStreamRow = namedtuple( - "AccountDataStreamRow", - ("user_id", "room_id", "data_type"), # str # Optional[str] # str - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class AccountDataStreamRow: + user_id: str + room_id: Optional[str] + data_type: str NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -573,10 +574,12 @@ class AccountDataStream(Stream): class GroupServerStream(Stream): - GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class GroupsStreamRow: + group_id: str + user_id: str + type: str + content: JsonDict NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -593,7 +596,9 @@ class GroupServerStream(Stream): class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key""" - UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class UserSignatureStreamRow: + user_id: str NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 0600cdbf36..4046bdec69 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -12,14 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple +import attr + from synapse.replication.tcp.streams._base import ( Stream, current_token_without_instance, make_http_update_function, ) +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,13 +32,10 @@ class FederationStream(Stream): sending disabled. """ - FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class FederationStreamRow: + type: str # the type of data as defined in the BaseFederationRows + data: JsonDict # serialization of a federation.send_queue.BaseFederationRow NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 244ba261bb..71b9a34b14 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -739,14 +739,21 @@ class MediaRepository: # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails: Dict[Tuple[int, int, str], str] = {} - for r_width, r_height, r_method, r_type in requirements: - if r_method == "crop": - thumbnails.setdefault((r_width, r_height, r_type), r_method) - elif r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - thumbnails[(t_width, t_height, r_type)] = r_method + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method # Now we generate the thumbnails for each dimension, store it for (t_width, t_height, t_type), t_method in thumbnails.items(): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 446204dbe5..69ac8c3423 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq import logging -from collections import defaultdict, namedtuple +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, @@ -69,9 +69,6 @@ state_groups_histogram = Histogram( ) -KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) - - EVICTION_TIMEOUT_SECONDS = 60 * 60 diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index a3442814d7..f76c6121e8 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr + from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomAliasMapping: + room_id: str + room_alias: str + servers: List[str] class DirectoryWorkerStore(CacheInvalidationWorkerStore): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 81e67ece55..dd255aefb9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1976,14 +1976,17 @@ class PersistEventsStore: txn, self.store.get_retention_policy_for_room, (event.room_id,) ) - def store_event_search_txn(self, txn, event, key, value): + def store_event_search_txn( + self, txn: LoggingTransaction, event: EventBase, key: str, value: str + ) -> None: """Add event to the search table Args: - txn (cursor): - event (EventBase): - key (str): - value (str): + txn: The database transaction. + event: The event being added to the search table. + key: A key describing the search value (one of "content.name", + "content.topic", or "content.body") + value: The value from the event's content. """ self.store.store_search_entries_txn( txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4472335af9..c0e837854a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -13,11 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError @@ -43,9 +54,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RatelimitOverride: + messages_per_second: int + burst_count: int class RoomSortOrder(Enum): @@ -207,6 +219,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -284,7 +297,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ where_clauses = [] - query_args = [] + query_args: List[Union[str, int]] = [] if network_tuple: if network_tuple.appservice_id: @@ -293,6 +306,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f87acfb866..2d085a5764 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,9 +14,10 @@ import logging import re -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +import attr + from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -33,10 +34,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -SearchEntry = namedtuple( - "SearchEntry", - ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], -) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SearchEntry: + key: str + value: str + event_id: str + room_id: str + stream_ordering: Optional[int] + origin_server_ts: int def _clean_value_for_search(value: str) -> str: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4bc044fb16..7e5a6aae18 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,6 @@ # limitations under the License. import collections.abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership @@ -43,19 +42,6 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 9488fd5094..b0642ca69f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,9 +36,9 @@ what sort order was used: """ import abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +import attr from frozendict import frozendict from twisted.internet import defer @@ -74,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological" # Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventDictReturn: + event_id: str + topological_ordering: Optional[int] + stream_ordering: int def generate_pagination_where_clause( @@ -825,7 +827,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: - topo = row.topological_ordering + topo: Optional[int] = row.topological_ordering else: topo = None internal = event.internal_metadata diff --git a/synapse/types.py b/synapse/types.py index b06979e8e8..42aeaf6270 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -15,7 +15,6 @@ import abc import re import string -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -227,8 +226,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): localpart = attr.ib(type=str) domain = attr.ib(type=str) - # Because this class is a namedtuple of strings and booleans, it is deeply - # immutable. + # Because this is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -708,16 +706,18 @@ class PersistedEventPosition: return RoomStreamToken(None, self.stream) -class ThirdPartyInstanceID( - namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) def __iter__(self): raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - # Because this class is a namedtuple of strings, it is deeply immutable. + # Because this class is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -725,22 +725,18 @@ class ThirdPartyInstanceID( return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str) -> "ThirdPartyInstanceID": bits = s.split("|", 2) if len(bits) != 2: raise SynapseError(400, "Invalid ID %r" % (s,)) return cls(appservice_id=bits[0], network_id=bits[1]) - def to_string(self): + def to_string(self) -> str: return "%s|%s" % (self.appservice_id, self.network_id) __str__ = to_string - @classmethod - def create(cls, appservice_id, network_id): - return cls(appservice_id=appservice_id, network_id=network_id) - @attr.s(slots=True) class ReadReceipt: diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 04a869e295..1b6a4bf4b0 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -62,7 +62,11 @@ class FederationAckTestCase(HomeserverTestCase): "federation", "master", token=10, - rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])], + rows=[ + FederationStream.FederationStreamRow( + type="x", data={"test": [1, 2, 3]} + ) + ], ) ) -- cgit 1.5.1 From 0201c6371cdfa0e8245c59686c131e40384bbac2 Mon Sep 17 00:00:00 2001 From: Fr3shTea <31766876+Fr3shTea@users.noreply.github.com> Date: Wed, 5 Jan 2022 11:59:29 +0000 Subject: Fix SimpleHttpClient not sending Accept header in `get_json` (#11677) Co-authored-by: reivilibre --- changelog.d/11677.bugfix | 1 + synapse/http/client.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11677.bugfix (limited to 'synapse/http') diff --git a/changelog.d/11677.bugfix b/changelog.d/11677.bugfix new file mode 100644 index 0000000000..5691064a30 --- /dev/null +++ b/changelog.d/11677.bugfix @@ -0,0 +1 @@ +Fix wrong variable reference in `SimpleHttpClient.get_json` that results in the absence of the `Accept` header in the request. diff --git a/synapse/http/client.py b/synapse/http/client.py index fbbeceabeb..ca33b45cb2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -588,7 +588,7 @@ class SimpleHttpClient: if headers: actual_headers.update(headers) # type: ignore - body = await self.get_raw(uri, args, headers=headers) + body = await self.get_raw(uri, args, headers=actual_headers) return json_decoder.decode(body.decode("utf-8")) async def put_json( -- cgit 1.5.1