diff options
106 files changed, 774 insertions, 577 deletions
diff --git a/Cargo.lock b/Cargo.lock index 8a8099bc6d..428cabc39a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,9 +343,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" +checksum = "8e8b3801309262e8184d9687fb697586833e939767aea0dda89f5a8e650e8bd7" dependencies = [ "itoa", "ryu", diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14376.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/changelog.d/14393.bugfix b/changelog.d/14393.bugfix new file mode 100644 index 0000000000..97177bc62f --- /dev/null +++ b/changelog.d/14393.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.58.0 where a user with presence state 'org.matrix.msc3026.busy' would mistakenly be set to 'online' when calling `/sync` or `/events` on a worker process. \ No newline at end of file diff --git a/changelog.d/14400.misc b/changelog.d/14400.misc new file mode 100644 index 0000000000..6e025329c4 --- /dev/null +++ b/changelog.d/14400.misc @@ -0,0 +1 @@ +Remove the `worker_main_http_uri` configuration setting. This is now handled via internal replication. diff --git a/changelog.d/14403.misc b/changelog.d/14403.misc new file mode 100644 index 0000000000..ff28a2712a --- /dev/null +++ b/changelog.d/14403.misc @@ -0,0 +1 @@ +Faster joins: do not wait for full state when creating events to send. diff --git a/changelog.d/14404.misc b/changelog.d/14404.misc new file mode 100644 index 0000000000..b9ab525f2b --- /dev/null +++ b/changelog.d/14404.misc @@ -0,0 +1 @@ +Faster joins: filter out non local events when a room doesn't have its full state. diff --git a/changelog.d/14412.misc b/changelog.d/14412.misc new file mode 100644 index 0000000000..4da061d461 --- /dev/null +++ b/changelog.d/14412.misc @@ -0,0 +1 @@ +Remove duplicated type information from type hints. diff --git a/changelog.d/14449.misc b/changelog.d/14449.misc new file mode 100644 index 0000000000..320c0b6fae --- /dev/null +++ b/changelog.d/14449.misc @@ -0,0 +1 @@ +Fix type logic in TCP replication code that prevented correctly ignoring blank commands. \ No newline at end of file diff --git a/changelog.d/14452.misc b/changelog.d/14452.misc new file mode 100644 index 0000000000..cb190c0823 --- /dev/null +++ b/changelog.d/14452.misc @@ -0,0 +1 @@ +Enable mypy's [`strict_equality` check](https://mypy.readthedocs.io/en/stable/command_line.html#cmdoption-mypy-strict-equality) by default. \ No newline at end of file diff --git a/changelog.d/14468.misc b/changelog.d/14468.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14468.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/changelog.d/14476.misc b/changelog.d/14476.misc new file mode 100644 index 0000000000..6e025329c4 --- /dev/null +++ b/changelog.d/14476.misc @@ -0,0 +1 @@ +Remove the `worker_main_http_uri` configuration setting. This is now handled via internal replication. diff --git a/changelog.d/14479.misc b/changelog.d/14479.misc new file mode 100644 index 0000000000..08edd2f929 --- /dev/null +++ b/changelog.d/14479.misc @@ -0,0 +1 @@ +`scripts-dev/federation_client`: Fix routing on servers with `.well-known` files. \ No newline at end of file diff --git a/changelog.d/14487.misc b/changelog.d/14487.misc new file mode 100644 index 0000000000..f6b47a1d8e --- /dev/null +++ b/changelog.d/14487.misc @@ -0,0 +1 @@ +Reduce default third party invite rate limit to 216 invites per day. diff --git a/changelog.d/14490.misc b/changelog.d/14490.misc new file mode 100644 index 0000000000..c0a4daa885 --- /dev/null +++ b/changelog.d/14490.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 0.9 where it would fail to fetch server keys whose IDs contain a forward slash. diff --git a/changelog.d/14499.doc b/changelog.d/14499.doc new file mode 100644 index 0000000000..34ea57ef43 --- /dev/null +++ b/changelog.d/14499.doc @@ -0,0 +1 @@ +Fixed link to 'Synapse administration endpoints'. diff --git a/changelog.d/14500.misc b/changelog.d/14500.misc new file mode 100644 index 0000000000..c5d70a70f7 --- /dev/null +++ b/changelog.d/14500.misc @@ -0,0 +1 @@ +Bump pygithub from 1.56 to 1.57. diff --git a/changelog.d/14501.misc b/changelog.d/14501.misc new file mode 100644 index 0000000000..3c240d38b5 --- /dev/null +++ b/changelog.d/14501.misc @@ -0,0 +1 @@ +Bump sentry-sdk from 1.10.1 to 1.11.0. diff --git a/changelog.d/14502.misc b/changelog.d/14502.misc new file mode 100644 index 0000000000..86a19900f1 --- /dev/null +++ b/changelog.d/14502.misc @@ -0,0 +1 @@ +Bump types-pillow from 9.2.2.1 to 9.3.0.1. diff --git a/changelog.d/14503.misc b/changelog.d/14503.misc new file mode 100644 index 0000000000..e627d35cde --- /dev/null +++ b/changelog.d/14503.misc @@ -0,0 +1 @@ +Bump towncrier from 21.9.0 to 22.8.0. diff --git a/changelog.d/14504.misc b/changelog.d/14504.misc new file mode 100644 index 0000000000..e228ee46a5 --- /dev/null +++ b/changelog.d/14504.misc @@ -0,0 +1 @@ +Bump phonenumbers from 8.12.56 to 8.13.0. diff --git a/changelog.d/14505.misc b/changelog.d/14505.misc new file mode 100644 index 0000000000..45d97ec461 --- /dev/null +++ b/changelog.d/14505.misc @@ -0,0 +1 @@ +Bump serde_json from 1.0.87 to 1.0.88. diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 62b1bab297..c1e1544536 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -213,10 +213,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "listener_resources": ["client", "replication"], "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"], "shared_extra_conf": {}, - "worker_extra_conf": ( - "worker_main_http_uri: http://127.0.0.1:%d" - % (MAIN_PROCESS_HTTP_LISTENER_PORT,) - ), + "worker_extra_conf": "", }, "account_data": { "app": "synapse.app.generic_worker", diff --git a/docs/usage/administration/admin_api/README.md b/docs/usage/administration/admin_api/README.md index f11e0b19a6..c00de2dd44 100644 --- a/docs/usage/administration/admin_api/README.md +++ b/docs/usage/administration/admin_api/README.md @@ -19,7 +19,7 @@ already on your `$PATH` depending on how Synapse was installed. Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings. ## Making an Admin API request -For security reasons, we [recommend](reverse_proxy.md#synapse-administration-endpoints) +For security reasons, we [recommend](../../../reverse_proxy.md#synapse-administration-endpoints) that the Admin API (`/_synapse/admin/...`) should be hidden from public view using a reverse proxy. This means you should typically query the Admin API from a terminal on the machine which runs Synapse. diff --git a/docs/workers.md b/docs/workers.md index 7ee8801161..27e54c5846 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -135,8 +135,8 @@ In the config file for each worker, you must specify: [`worker_replication_http_port`](usage/configuration/config_documentation.md#worker_replication_http_port)). * If handling HTTP requests, a [`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option with an `http` listener. - * If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for - the main process (`worker_main_http_uri`). + * **Synapse 1.72 and older:** if handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for + the main process (`worker_main_http_uri`). This config option is no longer required and is ignored when running Synapse 1.73 and newer. For example: @@ -221,7 +221,6 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ # Encryption requests - # Note that ^/_matrix/client/(r0|v3|unstable)/keys/upload/ requires `worker_main_http_uri` ^/_matrix/client/(r0|v3|unstable)/keys/query$ ^/_matrix/client/(r0|v3|unstable)/keys/changes$ ^/_matrix/client/(r0|v3|unstable)/keys/claim$ @@ -376,7 +375,7 @@ responsible for - persisting them to the DB, and finally - updating the events stream. -Because load is sharded in this way, you *must* restart all worker instances when +Because load is sharded in this way, you *must* restart all worker instances when adding or removing event persisters. An `event_persister` should not be mistaken for an `event_creator`. diff --git a/mypy.ini b/mypy.ini index 8f1141a239..4cd61e0484 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,7 @@ warn_unused_ignores = True local_partial_types = True no_implicit_optional = True disallow_untyped_defs = True +strict_equality = True files = docker/, @@ -117,6 +118,9 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True +[mypy-tests.storage.test_id_generators] +disallow_untyped_defs = True + [mypy-tests.storage.test_profile] disallow_untyped_defs = True diff --git a/poetry.lock b/poetry.lock index 8d468adf12..d9e4803a5f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -663,7 +663,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" [[package]] name = "phonenumbers" -version = "8.12.56" +version = "8.13.0" description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers." category = "main" optional = false @@ -814,15 +814,15 @@ python-versions = ">=3.6" [[package]] name = "pygithub" -version = "1.56" +version = "1.57" description = "Use the full Github API v3" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] deprecated = "*" -pyjwt = ">=2.0" +pyjwt = ">=2.4.0" pynacl = ">=1.4.0" requests = ">=2.14.0" @@ -1076,7 +1076,7 @@ doc = ["Sphinx", "sphinx-rtd-theme"] [[package]] name = "sentry-sdk" -version = "1.10.1" +version = "1.11.0" description = "Python client for Sentry (https://sentry.io)" category = "main" optional = true @@ -1098,6 +1098,7 @@ fastapi = ["fastapi (>=0.79.0)"] flask = ["blinker (>=1.1)", "flask (>=0.11)"] httpx = ["httpx (>=0.16.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] +pymongo = ["pymongo (>=3.1)"] pyspark = ["pyspark (>=2.4.4)"] quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] rq = ["rq (>=0.6)"] @@ -1256,11 +1257,11 @@ python-versions = ">= 3.5" [[package]] name = "towncrier" -version = "21.9.0" +version = "22.8.0" description = "Building newsfiles for your project." category = "dev" optional = false -python-versions = "*" +python-versions = ">=3.7" [package.dependencies] click = "*" @@ -1268,7 +1269,7 @@ click-default-group = "*" incremental = "*" jinja2 = "*" setuptools = "*" -tomli = {version = "*", markers = "python_version >= \"3.6\""} +tomli = "*" [package.extras] dev = ["packaging"] @@ -1439,7 +1440,7 @@ python-versions = "*" [[package]] name = "types-pillow" -version = "9.2.2.1" +version = "9.3.0.1" description = "Typing stubs for Pillow" category = "dev" optional = false @@ -2257,8 +2258,8 @@ pathspec = [ {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, ] phonenumbers = [ - {file = "phonenumbers-8.12.56-py2.py3-none-any.whl", hash = "sha256:80a7422cf0999a6f9b7a2e6cfbdbbfcc56ab5b75414dc3b805bbec91276b64a3"}, - {file = "phonenumbers-8.12.56.tar.gz", hash = "sha256:82a4f226c930d02dcdf6d4b29e4cfd8678991fe65c2efd5fdd143557186f0868"}, + {file = "phonenumbers-8.13.0-py2.py3-none-any.whl", hash = "sha256:dbaea9e4005a976bcf18fbe2bb87cb9cd0a3f119136f04188ac412d7741cebf0"}, + {file = "phonenumbers-8.13.0.tar.gz", hash = "sha256:93745d7afd38e246660bb601b07deac54eeb76c8e5e43f5e83333b0383a0a1e4"}, ] pillow = [ {file = "Pillow-9.3.0-1-cp37-cp37m-win32.whl", hash = "sha256:e6ea6b856a74d560d9326c0f5895ef8050126acfdc7ca08ad703eb0081e82b74"}, @@ -2419,8 +2420,8 @@ pyflakes = [ {file = "pyflakes-2.5.0.tar.gz", hash = "sha256:491feb020dca48ccc562a8c0cbe8df07ee13078df59813b83959cbdada312ea3"}, ] pygithub = [ - {file = "PyGithub-1.56-py3-none-any.whl", hash = "sha256:d15f13d82165306da8a68aefc0f848a6f6432d5febbff13b60a94758ce3ef8b5"}, - {file = "PyGithub-1.56.tar.gz", hash = "sha256:80c6d85cf0f9418ffeb840fd105840af694c4f17e102970badbaf678251f2a01"}, + {file = "PyGithub-1.57-py3-none-any.whl", hash = "sha256:5822febeac2391f1306c55a99af2bc8f86c8bf82ded000030cd02c18f31b731f"}, + {file = "PyGithub-1.57.tar.gz", hash = "sha256:c273f252b278fb81f1769505cc6921bdb6791e1cebd6ac850cc97dad13c31ff3"}, ] pygments = [ {file = "Pygments-2.11.2-py3-none-any.whl", hash = "sha256:44238f1b60a76d78fc8ca0528ee429702aae011c265fe6a8dd8b63049ae41c65"}, @@ -2568,8 +2569,8 @@ semantic-version = [ {file = "semantic_version-2.10.0.tar.gz", hash = "sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c"}, ] sentry-sdk = [ - {file = "sentry-sdk-1.10.1.tar.gz", hash = "sha256:105faf7bd7b7fa25653404619ee261527266b14103fe1389e0ce077bd23a9691"}, - {file = "sentry_sdk-1.10.1-py2.py3-none-any.whl", hash = "sha256:06c0fa9ccfdc80d7e3b5d2021978d6eb9351fa49db9b5847cf4d1f2a473414ad"}, + {file = "sentry-sdk-1.11.0.tar.gz", hash = "sha256:e7b78a1ddf97a5f715a50ab8c3f7a93f78b114c67307785ee828ef67a5d6f117"}, + {file = "sentry_sdk-1.11.0-py2.py3-none-any.whl", hash = "sha256:f467e6c7fac23d4d42bc83eb049c400f756cd2d65ab44f0cc1165d0c7c3d40bc"}, ] service-identity = [ {file = "service-identity-21.1.0.tar.gz", hash = "sha256:6e6c6086ca271dc11b033d17c3a8bea9f24ebff920c587da090afc9519419d34"}, @@ -2720,8 +2721,8 @@ tornado = [ {file = "tornado-6.1.tar.gz", hash = "sha256:33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791"}, ] towncrier = [ - {file = "towncrier-21.9.0-py2.py3-none-any.whl", hash = "sha256:fc5a88a2a54988e3a8ed2b60d553599da8330f65722cc607c839614ed87e0f92"}, - {file = "towncrier-21.9.0.tar.gz", hash = "sha256:9cb6f45c16e1a1eec9d0e7651165e7be60cd0ab81d13a5c96ca97a498ae87f48"}, + {file = "towncrier-22.8.0-py2.py3-none-any.whl", hash = "sha256:3b780c3d966e1b26414830aec3d15000654b31e64e024f3e5fd128b4c6eb8f47"}, + {file = "towncrier-22.8.0.tar.gz", hash = "sha256:7d3839b033859b45fb55df82b74cfd702431933c0cc9f287a5a7ea3e05d042cb"}, ] treq = [ {file = "treq-22.2.0-py3-none-any.whl", hash = "sha256:27d95b07c5c14be3e7b280416139b036087617ad5595be913b1f9b3ce981b9b2"}, @@ -2808,8 +2809,8 @@ types-opentracing = [ {file = "types_opentracing-2.4.10-py3-none-any.whl", hash = "sha256:66d9cfbbdc4a6f8ca8189a15ad26f0fe41cee84c07057759c5d194e2505b84c2"}, ] types-pillow = [ - {file = "types-Pillow-9.2.2.1.tar.gz", hash = "sha256:85c139e06e1c46ec5f9c634d5c54a156b0958d5d0e8be024ed353db0c804b426"}, - {file = "types_Pillow-9.2.2.1-py3-none-any.whl", hash = "sha256:3a6a871cade8428433a21ef459bb0a65532b87d05f9e836a0664431ce445bdcf"}, + {file = "types-Pillow-9.3.0.1.tar.gz", hash = "sha256:f3b7cada3fa496c78d75253c6b1f07a843d625f42e5639b320a72acaff6f7cfb"}, + {file = "types_Pillow-9.3.0.1-py3-none-any.whl", hash = "sha256:79837755fe9659f29efd1016e9903ac4a500e0c73260483f07296bd6ca47668b"}, ] types-psycopg2 = [ {file = "types-psycopg2-2.9.21.1.tar.gz", hash = "sha256:f5532cf15afdc6b5ebb1e59b7d896617217321f488fd1fbd74e7efb94decfab6"}, diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 763dd02c47..b1d5e2e616 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -46,11 +46,12 @@ import signedjson.key import signedjson.types import srvlookup import yaml +from requests import PreparedRequest, Response from requests.adapters import HTTPAdapter from urllib3 import HTTPConnectionPool # uncomment the following to enable debug logging of http requests -# from httplib import HTTPConnection +# from http.client import HTTPConnection # HTTPConnection.debuglevel = 1 @@ -103,6 +104,7 @@ def request( destination: str, path: str, content: Optional[str], + verify_tls: bool, ) -> requests.Response: if method is None: if content is None: @@ -141,7 +143,6 @@ def request( s.mount("matrix://", MatrixConnectionAdapter()) headers: Dict[str, str] = { - "Host": destination, "Authorization": authorization_headers[0], } @@ -152,7 +153,7 @@ def request( method=method, url=dest, headers=headers, - verify=False, + verify=verify_tls, data=content, stream=True, ) @@ -203,6 +204,12 @@ def main() -> None: parser.add_argument("--body", help="Data to send as the body of the HTTP request") parser.add_argument( + "--insecure", + action="store_true", + help="Disable TLS certificate verification", + ) + + parser.add_argument( "path", help="request path, including the '/_matrix/federation/...' prefix." ) @@ -227,6 +234,7 @@ def main() -> None: args.destination, args.path, content=args.body, + verify_tls=not args.insecure, ) sys.stderr.write("Status Code: %d\n" % (result.status_code,)) @@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None: class MatrixConnectionAdapter(HTTPAdapter): + def send( + self, + request: PreparedRequest, + *args: Any, + **kwargs: Any, + ) -> Response: + # overrides the send() method in the base class. + + # We need to look for .well-known redirects before passing the request up to + # HTTPAdapter.send(). + assert isinstance(request.url, str) + parsed = urlparse.urlsplit(request.url) + server_name = parsed.netloc + well_known = self._get_well_known(parsed.netloc) + + if well_known: + server_name = well_known + + # replace the scheme in the uri with https, so that cert verification is done + # also replace the hostname if we got a .well-known result + request.url = urlparse.urlunsplit( + ("https", server_name, parsed.path, parsed.query, parsed.fragment) + ) + + # at this point we also add the host header (otherwise urllib will add one + # based on the `host` from the connection returned by `get_connection`, + # which will be wrong if there is an SRV record). + request.headers["Host"] = server_name + + return super().send(request, *args, **kwargs) + + def get_connection( + self, url: str, proxies: Optional[Dict[str, str]] = None + ) -> HTTPConnectionPool: + # overrides the get_connection() method in the base class + parsed = urlparse.urlsplit(url) + (host, port, ssl_server_name) = self._lookup(parsed.netloc) + print( + f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr + ) + return self.poolmanager.connection_from_host( + host, + port=port, + scheme="https", + pool_kwargs={"server_hostname": ssl_server_name}, + ) + @staticmethod - def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]: - if s[-1] == "]": + def _lookup(server_name: str) -> Tuple[str, int, str]: + """ + Do an SRV lookup on a server name and return the host:port to connect to + Given the server_name (after any .well-known lookup), return the host, port and + the ssl server name + """ + if server_name[-1] == "]": # ipv6 literal (with no port) - return s, 8448 + return server_name, 8448, server_name - if ":" in s: - out = s.rsplit(":", 1) + if ":" in server_name: + # explicit port + out = server_name.rsplit(":", 1) try: port = int(out[1]) except ValueError: - raise ValueError("Invalid host:port '%s'" % s) - return out[0], port - - # try a .well-known lookup - if not skip_well_known: - well_known = MatrixConnectionAdapter.get_well_known(s) - if well_known: - return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True) + raise ValueError("Invalid host:port '%s'" % (server_name,)) + return out[0], port, out[0] try: - srv = srvlookup.lookup("matrix", "tcp", s)[0] - return srv.host, srv.port + srv = srvlookup.lookup("matrix", "tcp", server_name)[0] + print( + f"SRV lookup on _matrix._tcp.{server_name} gave {srv}", + file=sys.stderr, + ) + return srv.host, srv.port, server_name except Exception: - return s, 8448 + return server_name, 8448, server_name @staticmethod - def get_well_known(server_name: str) -> Optional[str]: - uri = "https://%s/.well-known/matrix/server" % (server_name,) - print("fetching %s" % (uri,), file=sys.stderr) + def _get_well_known(server_name: str) -> Optional[str]: + if ":" in server_name: + # explicit port, or ipv6 literal. Either way, no .well-known + return None + + # TODO: check for ipv4 literals + + uri = f"https://{server_name}/.well-known/matrix/server" + print(f"fetching {uri}", file=sys.stderr) try: resp = requests.get(uri) @@ -304,19 +369,6 @@ class MatrixConnectionAdapter(HTTPAdapter): print("Invalid response from %s: %s" % (uri, e), file=sys.stderr) return None - def get_connection( - self, url: str, proxies: Optional[Dict[str, str]] = None - ) -> HTTPConnectionPool: - parsed = urlparse.urlparse(url) - - (host, port) = self.lookup(parsed.netloc) - netloc = "%s:%d" % (host, port) - print("Connecting to %s" % (netloc,), file=sys.stderr) - url = urlparse.urlunparse( - ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) - ) - return super().get_connection(url, proxies) - if __name__ == "__main__": main() diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 400dd12aba..e2cfcea0f2 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -713,7 +713,7 @@ class HttpResponseException(CodeMessageException): set to the reason code from the HTTP response. Returns: - SynapseError: + The error converted to a SynapseError. """ # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1d9aef45c2..74909b7d4a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -14,14 +14,12 @@ # limitations under the License. import logging import sys -from typing import Dict, List, Optional, Tuple +from typing import Dict, List -from twisted.internet import address from twisted.web.resource import Resource import synapse import synapse.events -from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.api.urls import ( CLIENT_API_PREFIX, FEDERATION_PREFIX, @@ -43,8 +41,6 @@ from synapse.config.logger import setup_logging from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource, OptionsResource -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -70,12 +66,12 @@ from synapse.rest.client import ( versions, voip, ) -from synapse.rest.client._base import client_patterns from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet from synapse.rest.client.devices import DevicesRestServlet from synapse.rest.client.keys import ( KeyChangesServlet, KeyQueryServlet, + KeyUploadServlet, OneTimeKeyServlet, ) from synapse.rest.client.register import ( @@ -132,107 +128,12 @@ from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.types import JsonDict from synapse.util import SYNAPSE_VERSION from synapse.util.httpresourcetree import create_resource_tree logger = logging.getLogger("synapse.app.generic_worker") -class KeyUploadServlet(RestServlet): - """An implementation of the `KeyUploadServlet` that responds to read only - requests, but otherwise proxies through to the master instance. - """ - - PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") - - def __init__(self, hs: HomeServer): - """ - Args: - hs: server - """ - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.http_client = hs.get_simple_http_client() - self.main_uri = hs.config.worker.worker_main_http_uri - - async def on_POST( - self, request: SynapseRequest, device_id: Optional[str] - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - if device_id is not None: - # passing the device_id here is deprecated; however, we allow it - # for now for compatibility with older clients. - if requester.device_id is not None and device_id != requester.device_id: - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) - else: - device_id = requester.device_id - - if device_id is None: - raise SynapseError( - 400, "To upload keys, you must pass device_id when authenticating" - ) - - if body: - # They're actually trying to upload something, proxy to main synapse. - - # Proxy headers from the original request, such as the auth headers - # (in case the access token is there) and the original IP / - # User-Agent of the request. - headers: Dict[bytes, List[bytes]] = { - header: list(request.requestHeaders.getRawHeaders(header, [])) - for header in (b"Authorization", b"User-Agent") - } - # Add the previous hop to the X-Forwarded-For header. - x_forwarded_for = list( - request.requestHeaders.getRawHeaders(b"X-Forwarded-For", []) - ) - # we use request.client here, since we want the previous hop, not the - # original client (as returned by request.getClientAddress()). - if isinstance(request.client, (address.IPv4Address, address.IPv6Address)): - previous_host = request.client.host.encode("ascii") - # If the header exists, add to the comma-separated list of the first - # instance of the header. Otherwise, generate a new header. - if x_forwarded_for: - x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host] - x_forwarded_for.extend(x_forwarded_for[1:]) - else: - x_forwarded_for = [previous_host] - headers[b"X-Forwarded-For"] = x_forwarded_for - - # Replicate the original X-Forwarded-Proto header. Note that - # XForwardedForRequest overrides isSecure() to give us the original protocol - # used by the client, as opposed to the protocol used by our upstream proxy - # - which is what we want here. - headers[b"X-Forwarded-Proto"] = [ - b"https" if request.isSecure() else b"http" - ] - - try: - result = await self.http_client.post_json_get_json( - self.main_uri + request.uri.decode("ascii"), body, headers=headers - ) - except HttpResponseException as e: - raise e.to_synapse_error() from e - except RequestSendFailed as e: - raise SynapseError(502, "Failed to talk to master") from e - - return 200, result - else: - # Just interested in counts. - result = await self.store.count_e2e_one_time_keys(user_id, device_id) - return 200, {"one_time_key_counts": result} - - class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 94d1150415..5468b963a2 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -317,10 +317,9 @@ def setup_logging( Set up the logging subsystem. Args: - config (LoggingConfig | synapse.config.worker.WorkerConfig): - configuration data + config: configuration data - use_worker_options (bool): True to use the 'worker_log_config' option + use_worker_options: True to use the 'worker_log_config' option instead of 'log_config'. logBeginner: The Twisted logBeginner to use. diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 1ed001e105..5c13fe428a 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -150,8 +150,5 @@ class RatelimitConfig(Config): self.rc_third_party_invite = RatelimitSettings( config.get("rc_third_party_invite", {}), - defaults={ - "per_second": self.rc_message.per_second, - "burst_count": self.rc_message.burst_count, - }, + defaults={"per_second": 0.0025, "burst_count": 5}, ) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 88b3168cbc..913b83e174 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -162,7 +162,13 @@ class WorkerConfig(Config): self.worker_name = config.get("worker_name", self.worker_app) self.instance_name = self.worker_name or "master" + # FIXME: Remove this check after a suitable amount of time. self.worker_main_http_uri = config.get("worker_main_http_uri", None) + if self.worker_main_http_uri is not None: + logger.warning( + "The config option worker_main_http_uri is unused since Synapse 1.73. " + "It can be safely removed from your configuration." + ) # This option is really only here to support `--manhole` command line # argument. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c88afb2986..ed15f88350 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -213,7 +213,7 @@ class Keyring: def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int]] - ) -> List[defer.Deferred]: + ) -> List["defer.Deferred[None]"]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. @@ -226,10 +226,9 @@ class Keyring: valid. Returns: - List<Deferred[None]>: for each input triplet, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. + For each input triplet, a deferred indicating success or failure to + verify each json object's signature for the given server_name. The + deferreds run their callbacks in the sentinel logcontext. """ return [ run_in_background( @@ -858,7 +857,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): response = await self.client.get_json( destination=server_name, path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), + + urllib.parse.quote(requested_key_id, safe=""), ignore_backoff=True, # we only give the remote server 10s to respond. It should be an # easy request to handle, so if it doesn't reply within 10s, it's diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 030c3ca408..8aca9a3ab9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -597,8 +597,7 @@ def _event_type_from_format_version( format_version: The event format version Returns: - type: A type that can be initialized as per the initializer of - `FrozenEvent` + A type that can be initialized as per the initializer of `FrozenEvent` """ if format_version == EventFormatVersions.ROOM_V1_V2: diff --git a/synapse/events/builder.py b/synapse/events/builder.py index e2ee10dd3d..d62906043f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -128,6 +128,7 @@ class EventBuilder: state_filter=StateFilter.from_types( auth_types_for_event(self.room_version, self) ), + await_full_state=False, ) auth_event_ids = self._event_auth_handler.compute_auth_events( self, state_ids diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 084c45a95c..3ae5e8634c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -505,6 +505,7 @@ class PerDestinationQueue: new_pdus = await filter_events_for_server( self._storage_controllers, self._destination, + self._server_name, new_pdus, redact=False, ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cd39d4d111..a3cfc701cd 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -280,12 +280,11 @@ class TransportLayerClient: Note that this does not append any events to any graphs. Args: - destination (str): address of remote homeserver - room_id (str): room to join/leave - user_id (str): user to be joined/left - membership (str): one of join/leave - params (dict[str, str|Iterable[str]]): Query parameters to include in the - request. + destination: address of remote homeserver + room_id: room to join/leave + user_id: user to be joined/left + membership: one of join/leave + params: Query parameters to include in the request. Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 1db8009d6c..cdaf0d5de7 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -224,10 +224,10 @@ class BaseFederationServlet: With arguments: - origin (unicode|None): The authenticated server_name of the calling server, + origin (str|None): The authenticated server_name of the calling server, unless REQUIRE_AUTH is set to False and authentication failed. - content (unicode|None): decoded json body of the request. None if the + content (str|None): decoded json body of the request. None if the request was a GET. query (dict[bytes, list[bytes]]): Query params from the request. url-decoded diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a9912c467d..bf1221f523 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -870,7 +870,7 @@ class E2eKeysHandler: - signatures of the user's master key by the user's devices. Args: - user_id (string): the user uploading the keys + user_id: the user uploading the keys signatures (dict[string, dict]): map of devices to signed keys Returns: diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 28dc08c22a..83f53ceb88 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -377,8 +377,9 @@ class E2eRoomKeysHandler: """Deletes a given version of the user's e2e_room_keys backup Args: - user_id(str): the user whose current backup version we're deleting - version(str): the version id of the backup being deleted + user_id: the user whose current backup version we're deleting + version: Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. Raises: NotFoundError: if this backup version doesn't exist """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5fc3b8bc8c..d92582fd5c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -379,6 +379,7 @@ class FederationHandler: filtered_extremities = await filter_events_for_server( self._storage_controllers, self.server_name, + self.server_name, events_to_check, redact=False, check_history_visibility_only=True, @@ -1231,7 +1232,9 @@ class FederationHandler: async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Synapse asks for 100 events per backfill request. Do not allow more. limit = min(limit, 100) @@ -1252,7 +1255,7 @@ class FederationHandler: ) events = await filter_events_for_server( - self._storage_controllers, origin, events + self._storage_controllers, origin, self.server_name, events ) return events @@ -1283,7 +1286,7 @@ class FederationHandler: await self._event_auth_handler.assert_host_in_room(event.room_id, origin) events = await filter_events_for_server( - self._storage_controllers, origin, [event] + self._storage_controllers, origin, self.server_name, [event] ) event = events[0] return event @@ -1296,7 +1299,9 @@ class FederationHandler: latest_events: List[str], limit: int, ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Only allow up to 20 events to be retrieved per request. limit = min(limit, 20) @@ -1309,7 +1314,7 @@ class FederationHandler: ) missing_events = await filter_events_for_server( - self._storage_controllers, origin, missing_events + self._storage_controllers, origin, self.server_name, missing_events ) return missing_events @@ -1596,8 +1601,8 @@ class FederationHandler: Fetch the complexity of a remote room over federation. Args: - remote_room_hosts (list[str]): The remote servers to ask. - room_id (str): The room ID to ask about. + remote_room_hosts: The remote servers to ask. + room_id: The room ID to ask about. Returns: Dict contains the complexity diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 93d09e9939..848e46eb9b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -711,7 +711,7 @@ class IdentityHandler: inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. - id_access_token (str): The access token to authenticate to the identity + id_access_token: The access token to authenticate to the identity server with Returns: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 867973dcca..41c675f408 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -787,7 +787,7 @@ class OidcProvider: Must include an ``access_token`` field. Returns: - UserInfo: an object representing the user. + an object representing the user. """ logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0066d63987..cf08737d11 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -201,7 +201,7 @@ class BasePresenceHandler(abc.ABC): """Get the current presence state for multiple users. Returns: - dict: `user_id` -> `UserPresenceState` + A mapping of `user_id` -> `UserPresenceState` """ states = {} missing = [] @@ -478,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler): return _NullContextManager() prev_state = await self.current_state_for_user(user_id) - if prev_state != PresenceState.BUSY: + if prev_state.state != PresenceState.BUSY: # We set state here but pass ignore_status_msg = True as we don't want to # cause the status message to be cleared. # Note that this causes last_active_ts to be incremented which is not diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 9602f0d0bb..874860d461 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -441,7 +441,7 @@ class DefaultSamlMappingProvider: client_redirect_url: where the client wants to redirect to Returns: - dict: A dict containing new user attributes. Possible keys: + A dict containing new user attributes. Possible keys: * mxid_localpart (str): Required. The localpart of the user's mxid * displayname (str): The displayname of the user * emails (list[str]): Any emails for the user @@ -483,7 +483,7 @@ class DefaultSamlMappingProvider: Args: config: A dictionary containing configuration options for this provider Returns: - SamlConfig: A custom config object for this module + A custom config object for this module """ # Parse config options and use defaults where necessary mxid_source_attribute = config.get("mxid_source_attribute", "uid") diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 6a9f6635d2..8729630581 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -45,8 +45,7 @@ class AdditionalResource(DirectServeJsonResource): Args: hs: homeserver - handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): - function to be called to handle the request. + handler: function to be called to handle the request. """ super().__init__() self._handler = handler diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 2f0177f1e2..0359231e7d 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -155,11 +155,10 @@ class MatrixFederationAgent: a file for a file upload). Or None if the request is to have no body. Returns: - Deferred[twisted.web.iweb.IResponse]: - fires when the header of the response has been received (regardless of the - response status code). Fails if there is any problem which prevents that - response from being received (including problems that prevent the request - from being sent). + A deferred which fires when the header of the response has been received + (regardless of the response status code). Fails if there is any problem + which prevents that response from being received (including problems that + prevent the request from being sent). """ # We use urlparse as that will set `port` to None if there is no # explicit port. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3c35b1d2c7..b92f1d3d1a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -951,8 +951,7 @@ class MatrixFederationHttpClient: args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The - result will be the decoded JSON body. + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: HttpResponseException: If we get an HTTP response code >= 300 diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 1f8227896f..18899bc6d1 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -34,7 +34,7 @@ from twisted.web.client import ( ) from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS +from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials @@ -134,7 +134,7 @@ class ProxyAgent(_AgentBase): uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> defer.Deferred: + ) -> "defer.Deferred[IResponse]": """ Issue a request to the server indicated by the given uri. @@ -157,17 +157,17 @@ class ProxyAgent(_AgentBase): a file upload). Or, None if the request is to have no body. Returns: - Deferred[IResponse]: completes when the header of the response has - been received (regardless of the response status code). + A deferred which completes when the header of the response has + been received (regardless of the response status code). - Can fail with: - SchemeNotSupported: if the uri is not http or https + Can fail with: + SchemeNotSupported: if the uri is not http or https - twisted.internet.error.TimeoutError if the server we are connecting - to (proxy or destination) does not accept a connection before - connectTimeout. + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. - ... other things too. + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): diff --git a/synapse/http/server.py b/synapse/http/server.py index b26e34bceb..051a1899a0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -267,7 +267,7 @@ class HttpServer(Protocol): request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. This should return either tuple of (code, response), or None. - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ diff --git a/synapse/http/site.py b/synapse/http/site.py index 3dbd541fed..6a1dbf7f33 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -400,7 +400,7 @@ class SynapseRequest(Request): be sure to call finished_processing. Args: - servlet_name (str): the name of the servlet which will be + servlet_name: the name of the servlet which will be processing this request. This is used in the metrics. It is possible to update this afterwards by updating diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 6a08ffed64..f62bea968f 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -117,8 +117,7 @@ class ContextResourceUsage: """Create a new ContextResourceUsage Args: - copy_from (ContextResourceUsage|None): if not None, an object to - copy stats from + copy_from: if not None, an object to copy stats from """ if copy_from is None: self.reset() @@ -162,7 +161,7 @@ class ContextResourceUsage: """Add another ContextResourceUsage's stats to this one's. Args: - other (ContextResourceUsage): the other resource usage object + other: the other resource usage object """ self.ru_utime += other.ru_utime self.ru_stime += other.ru_stime @@ -342,7 +341,7 @@ class LoggingContext: called directly. Returns: - LoggingContext: the current logging context + The current logging context """ warnings.warn( "synapse.logging.context.LoggingContext.current_context() is deprecated " @@ -362,7 +361,8 @@ class LoggingContext: called directly. Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -474,8 +474,7 @@ class LoggingContext: """Get resources used by this logcontext so far. Returns: - ContextResourceUsage: a *copy* of the object tracking resource - usage so far + A *copy* of the object tracking resource usage so far """ # we always return a copy, for consistency res = self._resource_usage.copy() @@ -663,7 +662,8 @@ def current_context() -> LoggingContextOrSentinel: def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: """Set the current logging context in thread local storage Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -700,7 +700,7 @@ def nested_logging_context(suffix: str) -> LoggingContext: suffix: suffix to add to the parent context's 'name'. Returns: - LoggingContext: new logging context. + A new logging context. """ curr_context = current_context() if not curr_context: @@ -898,20 +898,19 @@ def defer_to_thread( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked, and whose threadpool we should use for the - function. + reactor: The reactor in whose main thread the Deferred will be invoked, + and whose threadpool we should use for the function. Normally this will be hs.get_reactor(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) @@ -939,20 +938,20 @@ def defer_to_threadpool( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked. Normally this will be hs.get_reactor(). + reactor: The reactor in whose main thread the Deferred will be invoked. + Normally this will be hs.get_reactor(). - threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for - running `f`. Normally this will be hs.get_reactor().getThreadPool(). + threadpool: The threadpool to use for running `f`. Normally this will be + hs.get_reactor().getThreadPool(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ curr_context = current_context() diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ce5a2a338..b69060854f 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -721,7 +721,7 @@ def inject_header_dict( destination: address of entity receiving the span context. Must be given unless check_destination is False. The context will only be injected if the destination matches the opentracing whitelist - check_destination (bool): If false, destination will be ignored and the context + check_destination: If false, destination will be ignored and the context will always be injected. Note: @@ -780,7 +780,7 @@ def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str destination: the name of the remote server. Returns: - dict: the active span's context if opentracing is enabled, otherwise empty. + the active span's context if opentracing is enabled, otherwise empty. """ if destination and not whitelisted_homeserver(destination): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 30e689d00d..1adc1fd64f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -787,7 +787,7 @@ class ModuleApi: Added in Synapse v0.25.0. Args: - access_token(str): access token + access_token: access token Returns: twisted.internet.defer.Deferred - resolves once the access token @@ -832,7 +832,7 @@ class ModuleApi: **kwargs: named args to be passed to func Returns: - Deferred[object]: result of func + Result of func """ # type-ignore: See https://github.com/python/mypy/issues/8862 return defer.ensureDeferred( @@ -924,8 +924,7 @@ class ModuleApi: to represent 'any') of the room state to acquire. Returns: - twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: - The filtered state events in the room. + The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( self._storage_controllers.state.get_current_state_ids( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 5e661f8c73..3f4d3fc51a 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -153,7 +153,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): argument list. Returns: - dict: If POST/PUT request then dictionary must be JSON serialisable, + If POST/PUT request then dictionary must be JSON serialisable, otherwise must be appropriate for adding as query args. """ return {} diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 3d63645726..c21629def8 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -78,5 +79,71 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices +class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): + """Ask master to upload keys for the user and send them out over federation to + update other servers. + + For now, only the master is permitted to handle key upload requests; + any worker can handle key query requests (since they're read-only). + + Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on + the main process to accomplish this. + + Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload + Request format(borrowed and expanded from KeyUploadServlet): + + POST /_synapse/replication/upload_keys_for_user + + { + "user_id": "<user_id>", + "device_id": "<device_id>", + "keys": { + ....this part can be found in KeyUploadServlet in rest/client/keys.py.... + } + } + + Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet + + """ + + NAME = "upload_keys_for_user" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: + + return { + "user_id": user_id, + "device_id": device_id, + "keys": keys, + } + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + content = parse_json_object_from_request(request) + + user_id = content["user_id"] + device_id = content["device_id"] + keys = content["keys"] + + results = await self.e2e_keys_handler.upload_keys_for_user( + user_id, device_id, keys + ) + + return 200, results + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserDevicesResyncRestServlet(hs).register(http_server) + ReplicationUploadKeysForUserRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed4..0000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 7763ffb2d0..56a5c21910 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -245,7 +245,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self._parse_and_dispatch_line(line) def _parse_and_dispatch_line(self, line: bytes) -> None: - if line.strip() == "": + if line.strip() == b"": # Ignore blank lines return diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 1951b8a9f2..6e0c44be2a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -903,8 +903,9 @@ class PushersRestServlet(RestServlet): @user:server/pushers Returns: - pushers: Dictionary containing pushers information. - total: Number of pushers in dictionary `pushers`. + A dictionary with keys: + pushers: Dictionary containing pushers information. + total: Number of pushers in dictionary `pushers`. """ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index f653d2a3e1..ee038c7192 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -27,6 +27,7 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict, StreamToken from synapse.util.cancellation import cancellable @@ -43,24 +44,48 @@ class KeyUploadServlet(RestServlet): Content-Type: application/json { - "device_keys": { - "user_id": "<user_id>", - "device_id": "<device_id>", - "valid_until_ts": <millisecond_timestamp>, - "algorithms": [ - "m.olm.curve25519-aes-sha2", - ] - "keys": { - "<algorithm>:<device_id>": "<key_base64>", + "device_keys": { + "user_id": "<user_id>", + "device_id": "<device_id>", + "valid_until_ts": <millisecond_timestamp>, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ] + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures:" { + "<user_id>" { + "<algorithm>:<device_id>": "<signature_base64>" + } + } + }, + "fallback_keys": { + "<algorithm>:<device_id>": "<key_base64>", + "signed_<algorithm>:<device_id>": { + "fallback": true, + "key": "<key_base64>", + "signatures": { + "<user_id>": { + "<algorithm>:<device_id>": "<key_base64>" + } + } + } + } + "one_time_keys": { + "<algorithm>:<key_id>": "<key_base64>" }, - "signatures:" { - "<user_id>" { - "<algorithm>:<device_id>": "<signature_base64>" - } } }, - "one_time_keys": { - "<algorithm>:<key_id>": "<key_base64>" - }, } + + response, e.g.: + + { + "one_time_key_counts": { + "curve25519": 10, + "signed_curve25519": 20 + } + } + """ PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") @@ -71,6 +96,13 @@ class KeyUploadServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() + if hs.config.worker.worker_app is None: + # if main process + self.key_uploader = self.e2e_keys_handler.upload_keys_for_user + else: + # then a worker + self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) + async def on_POST( self, request: SynapseRequest, device_id: Optional[str] ) -> Tuple[int, JsonDict]: @@ -109,8 +141,8 @@ class KeyUploadServlet(RestServlet): 400, "To upload keys, you must pass device_id when authenticating" ) - result = await self.e2e_keys_handler.upload_keys_for_user( - user_id, device_id, body + result = await self.key_uploader( + user_id=user_id, device_id=device_id, keys=body ) return 200, result diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 05706b598c..8adced41e5 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -350,7 +350,7 @@ class LoginRestServlet(RestServlet): auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: - result: Dictionary of account information after successful login. + Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 328c0c5477..40b0d39eb2 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -344,8 +344,8 @@ class MediaRepository: download from remote server. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 9b93b9b4f6..a48a4de92a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -138,7 +138,7 @@ class Thumbnailer: """Rescales the image to the given dimensions. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ with self._resize(width, height) as scaled: return self._encode_image(scaled, output_type) @@ -155,7 +155,7 @@ class Thumbnailer: max_height: The largest possible height. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_width = width diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 698ca742ed..94025ba41f 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -113,9 +113,8 @@ def copy_with_str_subst(x: Any, substitutions: Any) -> Any: """Deep-copy a structure, carrying out string substitutions on any strings Args: - x (object): structure to be copied - substitutions (object): substitutions to be made - passed into the - string '%' operator + x: structure to be copied + substitutions: substitutions to be made - passed into the string '%' operator Returns: copy of x diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 3134cd2d3d..a31a2c99a7 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -170,11 +170,13 @@ class ResourceLimitsServerNotices: room_id: The room id of the server notices room Returns: - bool: Is the room currently blocked - list: The list of pinned event IDs that are unrelated to limit blocking - This list can be used as a convenience in the case where the block - is to be lifted and the remaining pinned event references need to be - preserved + Tuple of: + Is the room currently blocked + + The list of pinned event IDs that are unrelated to limit blocking + This list can be used as a convenience in the case where the block + is to be lifted and the remaining pinned event references need to be + preserved """ currently_blocked = False pinned_state_event = None diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6f3dd0463e..833ffec3de 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -190,6 +190,7 @@ class StateHandler: room_id: str, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Fetch the state after each of the given event IDs. Resolve them and return. @@ -200,13 +201,18 @@ class StateHandler: Args: room_id: the room_id containing the given events. event_ids: the events whose state should be fetched and resolved. + await_full_state: if `True`, will block if we do not yet have complete state + at the given `event_id`s, regardless of whether `state_filter` is + satisfied by partial state. Returns: the state dict (a mapping from (event_type, state_key) -> event_id) which holds the resolution of the states after the given event IDs. """ logger.debug("calling resolve_state_groups from compute_state_after_events") - ret = await self.resolve_state_groups_for_events(room_id, event_ids) + ret = await self.resolve_state_groups_for_events( + room_id, event_ids, await_full_state + ) return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 48976dc570..33ffef521b 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -204,9 +204,8 @@ class _EventPeristenceQueue(Generic[_PersistResult]): process to to so, calling the per_item_callback for each item. Args: - room_id (str): - task (_EventPersistQueueTask): A _PersistEventsTask or - _UpdateCurrentStateTask to process. + room_id: + task: A _PersistEventsTask or _UpdateCurrentStateTask to process. Returns: the result returned by the `_per_item_callback` passed to diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4717c9728a..0dc44b246c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -569,15 +569,15 @@ class DatabasePool: retcols=["update_name"], desc="check_background_updates", ) - updates = [x["update_name"] for x in updates] + background_update_names = [x["update_name"] for x in updates] for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: + if update_name not in background_update_names: logger.debug("Now safe to upsert in %s", table) self._unsafe_to_upsert_tables.discard(table) # If there's any updates still running, reschedule to run. - if updates: + if background_update_names: self._clock.call_later( 15.0, run_as_background_process, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a..282687ebce 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index aa58c2adc3..57230df5ae 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). @@ -535,7 +525,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): limit: Maximum number of device updates to return Returns: - List: List of device update tuples: + List of device update tuples: - user_id - device_id - stream_id diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index af59be6b48..6240f9a75e 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -391,10 +391,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): Returns: A dict giving the info metadata for this backup version, with fields including: - version(str) - algorithm(str) - auth_data(object): opaque dict supplied by the client - etag(int): tag of the keys in the backup + version (str) + algorithm (str) + auth_data (object): opaque dict supplied by the client + etag (int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2a4f58ed92..cf33e73e2b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -412,10 +412,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """Retrieve a number of one-time keys for a user Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve + user_id: id of user to get keys for + device_id: id of device to get keys for + key_ids: list of key ids (excluding algorithm) to retrieve Returns: A map from (algorithm, key_id) to json string for key diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c4acff5be6..d68f127f9b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1279,9 +1279,10 @@ class PersistEventsStore: Pick the earliest non-outlier if there is one, else the earliest one. Args: - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts: + Returns: - list[(EventBase, EventContext)]: filtered list + filtered list """ new_events_and_contexts: OrderedDict[ str, Tuple[EventBase, EventContext] @@ -1307,9 +1308,8 @@ class PersistEventsStore: """Update min_depth for each room Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: @@ -1580,13 +1580,11 @@ class PersistEventsStore: """Update all the miscellaneous tables for new events Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + txn: db connection + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. inhibit_local_membership_updates: Stop the local_current_membership from being updated by these events. This should be set to True for backfilled events because backfilled events in the past do diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 467d20253d..01e935edef 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( @@ -1589,7 +1582,7 @@ class EventsWorkerStore(SQLBaseStore): room_id: The room ID to query. Returns: - dict[str:float] of complexity version to complexity. + Map of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index efd136a864..db9a24db5e 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -217,7 +217,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: - reserved_users (tuple): reserved users to preserve + reserved_users: reserved users to preserve """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) @@ -370,8 +370,8 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): should not appear in the MAU stats). Args: - txn (cursor): - user_id (str): user to add/update + txn: + user_id: user to add/update """ assert ( self._update_on_this_worker @@ -401,7 +401,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): add the user to the monthly active tables Args: - user_id(str): the user_id to query + user_id: the user_id to query """ assert ( self._update_on_this_worker diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f6127..12ad44dbb3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -111,14 +110,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d45..fee37b9ce4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index fbf27497ec..a580e4bdda 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 5167089e03..31f0f2bd3d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -953,7 +953,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Returns user id from threepid Args: - txn (cursor): + txn: medium: threepid medium e.g. email address: threepid address e.g. me@example.com @@ -1283,8 +1283,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Sets an expiration date to the account with the given user ID. Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be + user_id: User ID to set an expiration date for. + use_delta: If set to False, the expiration date for the user will be now + validity period. If set to True, this expiration date will be a random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d97f8f60e..4fbaefad73 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2057,7 +2057,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): Args: report_id: ID of reported event in database Returns: - event_report: json list of information from event report + JSON dict of information from an event report or None if the + report does not exist. """ def _get_event_report_txn( @@ -2130,8 +2131,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None Returns: - event_reports: json list of event reports - count: total number of event reports matching the filter criteria + Tuple of: + json list of event reports + total number of event reports matching the filter criteria """ def _get_event_reports_paginate_txn( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ddb25b5cea..698d6f7515 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -185,9 +185,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): - who should be in the user_directory. Args: - progress (dict) - batch_size (int): Maximum number of state events to process - per cycle. + progress + batch_size: Maximum number of state events to process per cycle. Returns: number of events processed. @@ -708,10 +707,10 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns the rooms that a user is in. Args: - user_id(str): Must be a local user + user_id: Must be a local user Returns: - list: user_id + List of room IDs """ rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b66..0d7108f01b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if not self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step diff --git a/synapse/types.py b/synapse/types.py index 773f0438d5..f2d436ddc3 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -143,8 +143,8 @@ class Requester: Requester. Args: - store (DataStore): Used to convert AS ID to AS object - input (dict): A dict produced by `serialize` + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` Returns: Requester diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7f1d41eb3c..d24c4f68c4 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -217,7 +217,8 @@ async def concurrently_execute( limit: Maximum number of conccurent executions. Returns: - Deferred: Resolved when all function invocations have finished. + None, when all function invocations have finished. The return values + from those functions are discarded. """ it = iter(args) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f7c3a6794e..9387632d0d 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -197,7 +197,7 @@ def register_cache( resize_callback: A function which can be called to resize the cache. Returns: - CacheMetric: an object which provides inc_{hits,misses,evictions} methods + an object which provides inc_{hits,misses,evictions} methods """ if resizable: if not resize_callback: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index bcb1cba362..bf7bd351e0 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -153,7 +153,7 @@ class DeferredCache(Generic[KT, VT]): Args: key: callback: Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics + update_metrics: whether to update the cache hit rate metrics Returns: A Deferred which completes with the result. Note that this may later fail diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index fa91479c97..5eaf70c7ab 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -169,10 +169,11 @@ class DictionaryCache(Generic[KT, DKT, DV]): if it is in the cache. Returns: - DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry` - will contain include the keys that are in the cache. If None then - will either return the full dict if in the cache, or the empty - dict (with `full` set to False) if it isn't. + If `dict_keys` is not None then `DictionaryEntry` will contain include + the keys that are in the cache. + + If None then will either return the full dict if in the cache, or the + empty dict (with `full` set to False) if it isn't. """ if dict_keys is None: # The caller wants the full set of dictionary keys for this cache key diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index c6a5d0dfc0..01ad02af67 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -207,7 +207,7 @@ class ExpiringCache(Generic[KT, VT]): items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ new_size = int(self._original_max_size * factor) if new_size != self._max_size: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index aa93109d13..dcf0eac3bf 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -389,11 +389,11 @@ class LruCache(Generic[KT, VT]): cache_name: The name of this cache, for the prometheus metrics. If unset, no metrics will be reported on this cache. - cache_type (type): + cache_type: type of underlying cache to be used. Typically one of dict or TreeCache. - size_callback (func(V) -> int | None): + size_callback: metrics_collection_callback: metrics collection callback. This is called early in the metrics @@ -403,7 +403,7 @@ class LruCache(Generic[KT, VT]): Ignored if cache_name is None. - apply_cache_factor_from_config (bool): If true, `max_size` will be + apply_cache_factor_from_config: If true, `max_size` will be multiplied by a cache factor derived from the homeserver config clock: @@ -796,7 +796,7 @@ class LruCache(Generic[KT, VT]): items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ if not self.apply_cache_factor_from_config: return False diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 9f64fed0d7..2aceb1a47f 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -183,7 +183,7 @@ class FederationRateLimiter: # Handle request ... Args: - host (str): Origin of incoming request. + host: Origin of incoming request. Returns: context manager which returns a deferred. diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 1e9c2faa64..54bc7589fd 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -48,7 +48,7 @@ async def check_3pid_allowed( registration: whether we want to bind the 3PID as part of registering a new user. Returns: - bool: whether the 3PID medium/address is allowed to be added to this HS + whether the 3PID medium/address is allowed to be added to this HS """ if not await hs.get_password_auth_provider().is_3pid_allowed( medium, address, registration diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 177e198e7e..b1ec7f4bd8 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -90,10 +90,10 @@ class WheelTimer(Generic[T]): """Fetch any objects that have timed out Args: - now (ms): Current time in msec + now: Current time in msec Returns: - list: List of objects that have timed out + List of objects that have timed out """ now_key = int(now / self.bucket_size) diff --git a/synapse/visibility.py b/synapse/visibility.py index 40a9c5b53f..b443857571 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -563,7 +563,8 @@ def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: async def filter_events_for_server( storage: StorageControllers, - server_name: str, + target_server_name: str, + local_server_name: str, events: List[EventBase], redact: bool = True, check_history_visibility_only: bool = False, @@ -603,7 +604,7 @@ async def filter_events_for_server( # if the server is either in the room or has been invited # into the room. for ev in memberships.values(): - assert get_domain_from_id(ev.state_key) == server_name + assert get_domain_from_id(ev.state_key) == target_server_name memtype = ev.membership if memtype == Membership.JOIN: @@ -622,6 +623,24 @@ async def filter_events_for_server( # to no users having been erased. erased_senders = {} + # Filter out non-local events when we are in the middle of a partial join, since our servers + # list can be out of date and we could leak events to servers not in the room anymore. + # This can also be true for local events but we consider it to be an acceptable risk. + + # We do this check as a first step and before retrieving membership events because + # otherwise a room could be fully joined after we retrieve those, which would then bypass + # this check but would base the filtering on an outdated view of the membership events. + + partial_state_invisible_events = set() + if not check_history_visibility_only: + for e in events: + sender_domain = get_domain_from_id(e.sender) + if ( + sender_domain != local_server_name + and await storage.main.is_partial_state_room(e.room_id) + ): + partial_state_invisible_events.add(e) + # Let's check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). @@ -636,7 +655,7 @@ async def filter_events_for_server( if event_to_history_vis[e.event_id] not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) ], - server_name, + target_server_name, ) to_return = [] @@ -645,6 +664,10 @@ async def filter_events_for_server( visible = check_event_is_visible( event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) ) + + if e in partial_state_invisible_events: + visible = False + if visible and not erased: to_return.append(e) elif redact: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 820a1a54e2..63628aa6b0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -469,6 +469,18 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) + def test_keyid_containing_forward_slash(self) -> None: + """We should url-encode any url unsafe chars in key ids. + + Detects https://github.com/matrix-org/synapse/issues/14488. + """ + fetcher = ServerKeyFetcher(self.hs) + self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0)) + + self.http_client.get_json.assert_called_once() + args, kwargs = self.http_client.get_json.call_args + self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato") + class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c96dc6caf2..c5981ff965 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -15,6 +15,7 @@ from typing import Optional from unittest.mock import Mock, call +from parameterized import parameterized from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState @@ -37,6 +38,7 @@ from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase class PresenceUpdateTestCase(unittest.HomeserverTestCase): @@ -505,7 +507,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(state, new_state) -class PresenceHandlerTestCase(unittest.HomeserverTestCase): +class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def prepare(self, reactor, clock, hs): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() @@ -716,20 +718,47 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): # our status message should be the same as it was before self.assertEqual(state.status_msg, status_msg) - def test_set_presence_from_syncing_keeps_busy(self): - """Test that presence set by syncing doesn't affect busy status""" - # while this isn't the default - self.presence_handler._busy_presence_enabled = True + @parameterized.expand([(False,), (True,)]) + @unittest.override_config( + { + "experimental_features": { + "msc3026_enabled": True, + }, + } + ) + def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): + """Test that presence set by syncing doesn't affect busy status + Args: + test_with_workers: If True, check the presence state of the user by calling + /sync against a worker, rather than the main process. + """ user_id = "@test:server" status_msg = "I'm busy!" + # By default, we call /sync against the main process. + worker_to_sync_against = self.hs + if test_with_workers: + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "presence_writer"} + ) + + # Set presence to BUSY self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg) + # Perform a sync with a presence state other than busy. This should NOT change + # our presence status; we only change from busy if we explicitly set it via + # /presence/*. self.get_success( - self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE) + worker_to_sync_against.get_presence_handler().user_syncing( + user_id, True, PresenceState.ONLINE + ) ) + # Check against the main process that the user's presence did not change. state = self.get_success( self.presence_handler.get_state(UserID.from_string(user_id)) ) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index e74f7f5b48..093537adef 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import os.path import subprocess +from typing import List from zope.interface import implementer @@ -70,14 +71,14 @@ subjectAltName = %(sanentries)s """ -def create_test_cert_file(sanlist): +def create_test_cert_file(sanlist: List[bytes]) -> str: """build an x509 certificate file Args: - sanlist: list[bytes]: a list of subjectAltName values for the cert + sanlist: a list of subjectAltName values for the cert Returns: - str: the path to the file + The path to the file """ global cert_file_count csr_filename = "server.csr" diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 02cef6f876..058ca57e55 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -778,8 +778,11 @@ def _test_sending_local_online_presence_to_local_user( worker process. The test users will still sync with the main process. The purpose of testing with a worker is to check whether a Synapse module running on a worker can inform other workers/ the main process that they should include additional presence when a user next syncs. + If this argument is True, `test_case` MUST be an instance of BaseMultiWorkerStreamTestCase. """ if test_with_workers: + assert isinstance(test_case, BaseMultiWorkerStreamTestCase) + # Create a worker process to make module_api calls against worker_hs = test_case.make_worker_hs( "synapse.app.generic_worker", {"worker_name": "presence_writer"} diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 121f3d8d65..3029a16dda 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -542,8 +542,13 @@ class FakeRedisPubSubProtocol(Protocol): self.send("OK") elif command == b"GET": self.send(None) + + # Connection keep-alives. + elif command == b"PING": + self.send("PONG") + else: - raise Exception("Unknown command") + raise Exception(f"Unknown command: {command}") def send(self, msg): """Send a message back to the client.""" diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 96f3880923..dce71f7334 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -143,6 +143,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): self.persist(type="m.room.create", key="", creator=USER_ID) self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + assert event.internal_metadata.stream_ordering is not None self.replicate() @@ -230,6 +231,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): j2 = self.persist( type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) + assert j2.internal_metadata.stream_ordering is not None self.replicate() expected_pos = PersistedEventPosition( @@ -287,6 +289,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): ) ) self.replicate() + assert j2.internal_metadata.stream_ordering is not None event_source = RoomEventSource(self.hs) event_source.store = self.slaved_store @@ -336,10 +339,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): event_id = 0 - def persist(self, backfill=False, **kwargs): + def persist(self, backfill=False, **kwargs) -> FrozenEvent: """ Returns: - synapse.events.FrozenEvent: The event that was persisted. + The event that was persisted. """ event, context = self.build_event(**kwargs) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 13aa5eb51a..96cdf2c45b 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -15,8 +15,9 @@ import logging import os from typing import Optional, Tuple +from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from twisted.web.server import Request @@ -102,7 +103,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): ) # fish the test server back out of the server-side TLS protocol. - http_server = server_tls_protocol.wrappedProtocol + http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment] # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) @@ -238,16 +239,15 @@ def get_connection_factory(): return test_server_connection_factory -def _build_test_server(connection_creator): +def _build_test_server( + connection_creator: IOpenSSLServerConnectionCreator, +) -> TLSMemoryBIOProtocol: """Construct a test server This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + connection_creator: thing to build SSL connections Returns: TLSMemoryBIOProtocol diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index bf403045e9..7cbc40736c 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -350,14 +351,15 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.assertTrue(notice_in_room, "No server notice in room") - def _trigger_notice_and_join(self): + def _trigger_notice_and_join(self) -> Tuple[str, str, str]: """Creates enough active users to hit the MAU limit and trigger a system notice about it, then joins the system notices room with one of the users created. Returns: - user_id (str): The ID of the user that joined the room. - tok (str): The access token of the user that joined the room. - room_id (str): The ID of the room that's been joined. + A tuple of: + user_id: The ID of the user that joined the room. + tok: The access token of the user that joined the room. + room_id: The ID of the room that's been joined. """ user_id = None tok = None diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 2d8d1f860f..d6a2b8d274 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -16,15 +16,157 @@ from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import IncorrectDatabaseSetup -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS +class StreamIdGeneratorTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + data TEXT + ); + """ + ) + txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") + + def _create_id_generator(self) -> StreamIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: + return StreamIdGenerator( + db_conn=conn, + table="foobar", + column="stream_id", + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def test_initial_value(self) -> None: + """Check that we read the current token from the DB.""" + id_gen = self._create_id_generator() + self.assertEqual(id_gen.get_current_token(), 123) + + def test_single_gen_next(self) -> None: + """Check that we correctly increment the current token from the DB.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + async with id_gen.get_next() as next_id: + # We haven't persisted `next_id` yet; current token is still 123 + self.assertEqual(id_gen.get_current_token(), 123) + # But we did learn what the next value is + self.assertEqual(next_id, 124) + + # Once the context manager closes we assume that the `next_id` has been + # written to the DB. + self.assertEqual(id_gen.get_current_token(), 124) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts(self) -> None: + """Check that we handle overlapping calls to gen_next sensibly.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist each in turn. + await ctx1.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 124) + await ctx2.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 125) + await ctx3.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: + """Check that we handle overlapping calls to gen_next, even when their IDs + created and persisted in different orders.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist them in a different order, starting with 126 from ctx3. + await ctx3.__aexit__(None, None, None) + # We haven't persisted 124 from ctx1 yet---current token is still 123. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now persist 124 from ctx1. + await ctx1.__aexit__(None, None, None) + # Current token is then 124, waiting for 125 to be persisted. + self.assertEqual(id_gen.get_current_token(), 124) + + # Finally persist 125 from ctx2. + await ctx2.__aexit__(None, None, None) + # Current token is then 126 (skipping over 125). + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_gen_next_while_still_waiting_for_persistence(self) -> None: + """Check that we handle overlapping calls to gen_next.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request two new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + + # Persist ctx2 first. + await ctx2.__aexit__(None, None, None) + # Still waiting on ctx1's ID to be persisted. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now request a third stream ID. It should be 126 (the smallest ID that + # we've not yet handed out.) + self.assertEqual(await ctx3.__aenter__(), 126) + + self.get_success(test_gen_next()) + + class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" @@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("master", 3) # Now we add a row *without* updating the stream ID - def _insert(txn): + def _insert(txn: Cursor) -> None: txn.execute("INSERT INTO foobar VALUES (26, 'master')") self.get_success(self.db_pool.runInteraction("_insert", _insert)) @@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name: str, number: int, update_stream_table: bool = True, - ): + ) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ diff --git a/tests/test_visibility.py b/tests/test_visibility.py index c385b2f8d4..d0b9ad5454 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -61,7 +61,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "hs", events_to_filter ) ) @@ -83,7 +83,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier] + self._storage_controllers, "remote_hs", "hs", [outlier] ) ), [outlier], @@ -94,7 +94,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier, evt] + self._storage_controllers, "remote_hs", "local_hs", [outlier, evt] ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") @@ -106,7 +106,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # be redacted) filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "other_server", [outlier, evt] + self._storage_controllers, "other_server", "local_hs", [outlier, evt] ) ) self.assertEqual(filtered[0], outlier) @@ -141,7 +141,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... and the filtering happens. filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "local_hs", events_to_filter ) ) diff --git a/tests/unittest.py b/tests/unittest.py index 5116be338e..a120c2976c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -360,13 +360,13 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock): """ Make and return a homeserver. Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. + clock: The Clock, associated with the reactor. Returns: A homeserver suitable for testing. @@ -426,9 +426,8 @@ class HomeserverTestCase(TestCase): Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. - homeserver (synapse.server.HomeServer): The HomeServer to test - against. + clock: The Clock, associated with the reactor. + homeserver: The HomeServer to test against. Function to optionally be overridden in subclasses. """ @@ -452,11 +451,10 @@ class HomeserverTestCase(TestCase): given content. Args: - method (bytes/unicode): The HTTP request method ("verb"). - path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. - escaped UTF-8 & spaces and such). - content (bytes or dict): The body of the request. JSON-encoded, if - a dict. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces + and such). content (bytes or dict): The body of the request. + JSON-encoded, if a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake |