diff options
author | Andrew Morgan <andrew@amorgan.xyz> | 2020-10-19 19:12:39 +0100 |
---|---|---|
committer | Andrew Morgan <andrew@amorgan.xyz> | 2020-10-19 19:12:39 +0100 |
commit | 687d30b2eddd4feb530879c9499f248197ffba00 (patch) | |
tree | e3ee7460dcf934f1450359d29179079724a57f53 | |
parent | Merge commit '3c01724b3' into anoa/dinsic_release_1_21_x (diff) | |
parent | Remove `ChainedIdGenerator`. (#8123) (diff) | |
download | synapse-687d30b2eddd4feb530879c9499f248197ffba00.tar.xz |
Merge commit 'c9c544cda' into anoa/dinsic_release_1_21_x
* commit 'c9c544cda': Remove `ChainedIdGenerator`. (#8123) Switch the JSON byte producer from a pull to a push producer. (#8116) Updated docs: Added note about missing 308 redirect support. (#8120) Be stricter about JSON that is accepted by Synapse (#8106) Convert runWithConnection to async. (#8121) Remove the unused inlineCallbacks code-paths in the caching code (#8119) Separate `get_current_token` into two. (#8113) Convert events worker database to async/await. (#8071) Add a link to the matrix-synapse-rest-password-provider. (#8111)
60 files changed, 409 insertions, 419 deletions
diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/changelog.d/8106.bugfix b/changelog.d/8106.bugfix new file mode 100644 index 0000000000..c46c60448f --- /dev/null +++ b/changelog.d/8106.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where invalid JSON would be accepted by Synapse. diff --git a/changelog.d/8111.doc b/changelog.d/8111.doc new file mode 100644 index 0000000000..d3f7435452 --- /dev/null +++ b/changelog.d/8111.doc @@ -0,0 +1 @@ +Link to matrix-synapse-rest-password-provider in the password provider documentation. diff --git a/changelog.d/8113.misc b/changelog.d/8113.misc new file mode 100644 index 0000000000..00bec4f8ef --- /dev/null +++ b/changelog.d/8113.misc @@ -0,0 +1 @@ +Separate `get_current_token` into two since there are two different use cases for it. diff --git a/changelog.d/8116.feature b/changelog.d/8116.feature new file mode 100644 index 0000000000..b1eaf1e78a --- /dev/null +++ b/changelog.d/8116.feature @@ -0,0 +1 @@ +Iteratively encode JSON to avoid blocking the reactor. diff --git a/changelog.d/8119.misc b/changelog.d/8119.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8119.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/changelog.d/8120.doc b/changelog.d/8120.doc new file mode 100644 index 0000000000..877ef79fd2 --- /dev/null +++ b/changelog.d/8120.doc @@ -0,0 +1 @@ +Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole. \ No newline at end of file diff --git a/changelog.d/8121.misc b/changelog.d/8121.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8121.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/changelog.d/8123.misc b/changelog.d/8123.misc new file mode 100644 index 0000000000..7245122896 --- /dev/null +++ b/changelog.d/8123.misc @@ -0,0 +1 @@ +Remove `ChainedIdGenerator`. diff --git a/docs/federate.md b/docs/federate.md index a0786b9cf7..b15cd724d1 100644 --- a/docs/federate.md +++ b/docs/federate.md @@ -47,6 +47,18 @@ you invite them to. This can be caused by an incorrectly-configured reverse proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly configure a reverse proxy. +### Known issues + +**HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features +in the HTTP library used by Synapse, 308 redirects are currently not followed by +federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This +may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`), +and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect +codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871) +might be helpful. For other users, switching to a `301 Moved Permanently` code may be +an option. 308 redirect codes will be supported properly in a future +release of Synapse. + ## Running a demo federation of Synapses If you want to get up and running quickly with a trio of homeservers in a diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index fef1d47e85..7d98d9f255 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -14,6 +14,7 @@ password auth provider module implementations: * [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/) * [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth) +* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider) ## Required methods diff --git a/synapse/api/errors.py b/synapse/api/errors.py index a3b0c0a3e4..28a078a7b4 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -22,10 +22,10 @@ import typing from http import HTTPStatus from typing import Dict, List, Optional, Union -from canonicaljson import json - from twisted.web import http +from synapse.util import json_decoder + if typing.TYPE_CHECKING: from synapse.types import JsonDict @@ -594,7 +594,7 @@ class HttpResponseException(CodeMessageException): # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text try: - j = json.loads(self.response.decode("utf-8")) + j = json_decoder.decode(self.response.decode("utf-8")) except ValueError: j = {} diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee62..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 11c5d63298..630f571cd4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,7 +28,6 @@ from typing import ( Union, ) -from canonicaljson import json from prometheus_client import Counter, Histogram from twisted.internet import defer @@ -63,7 +62,7 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, unwrapFirstError +from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -551,7 +550,7 @@ class FederationServer(FederationBase): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_str) + key_id: json_decoder.decode(json_str) } logger.info( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index c7f6cb3d73..9bd534a313 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -15,8 +15,6 @@ import logging from typing import TYPE_CHECKING, List, Tuple -from canonicaljson import json - from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -28,6 +26,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.util import json_decoder from synapse.util.metrics import measure_func if TYPE_CHECKING: @@ -71,7 +70,7 @@ class TransactionManager(object): for edu in pending_edus: context = edu.get_context() if context: - span_contexts.append(extract_text_map(json.loads(context))) + span_contexts.append(extract_text_map(json_decoder.decode(context))) if keep_destination: edu.strip_context() diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 84169c1022..d8def45e38 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -19,7 +19,7 @@ import logging from typing import Dict, List, Optional, Tuple import attr -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from signedjson.key import VerifyKey, decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 @@ -35,7 +35,7 @@ from synapse.types import ( get_domain_from_id, get_verify_key_from_cross_signing_key, ) -from synapse.util import unwrapFirstError +from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -404,7 +404,7 @@ class E2eKeysHandler(object): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) + key_id: json_decoder.decode(json_bytes) } @trace @@ -1186,7 +1186,7 @@ def _exception_to_failure(e): def _one_time_keys_match(old_key_json, new_key): - old_key = json.loads(old_key_json) + old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly if not isinstance(old_key, dict) or not isinstance(new_key, dict): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3b5681eb06..29863c029b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1787,9 +1787,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1815,9 +1813,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2165,9 +2161,9 @@ class FederationHandler(BaseHandler): auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2183,9 +2179,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index ef930dba55..b5676b248b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,9 +21,6 @@ import logging import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple -from canonicaljson import json - -from twisted.internet import defer from twisted.internet.error import TimeoutError from synapse.api.errors import ( @@ -37,6 +34,7 @@ from synapse.api.errors import ( from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, Requester +from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -197,7 +195,7 @@ class IdentityHandler(BaseHandler): except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: - data = json.loads(e.msg) # XXX WAT? + data = json_decoder.decode(e.msg) # XXX WAT? return data logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) @@ -620,18 +618,19 @@ class IdentityHandler(BaseHandler): # the CS API. They should be consolidated with those in RoomMemberHandler # https://github.com/matrix-org/synapse-dinsic/issues/25 - @defer.inlineCallbacks - def proxy_lookup_3pid(self, id_server, medium, address): + async def proxy_lookup_3pid( + self, id_server: str, medium: str, address: str + ) -> JsonDict: """Looks up a 3pid in the passed identity server. Args: - id_server (str): The server name (including port, if required) + id_server: The server name (including port, if required) of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). + medium: The type of the third party identifier (e.g. "email"). + address: The third party identifier (e.g. "foo@example.com"). Returns: - Deferred[dict]: The result of the lookup. See + The result of the lookup. See https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup for details """ @@ -643,16 +642,11 @@ class IdentityHandler(BaseHandler): id_server_url = self.rewrite_id_server_url(id_server, add_https=True) try: - data = yield self.http_client.get_json( + data = await self.http_client.get_json( "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), {"medium": medium, "address": address}, ) - if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - yield self._verify_any_signature(data, id_server) - except HttpResponseException as e: logger.info("Proxied lookup failed: %r", e) raise e.to_synapse_error() @@ -662,18 +656,19 @@ class IdentityHandler(BaseHandler): return data - @defer.inlineCallbacks - def proxy_bulk_lookup_3pid(self, id_server, threepids): + async def proxy_bulk_lookup_3pid( + self, id_server: str, threepids: List[List[str]] + ) -> JsonDict: """Looks up given 3pids in the passed identity server. Args: - id_server (str): The server name (including port, if required) + id_server: The server name (including port, if required) of the identity server to use. - threepids ([[str, str]]): The third party identifiers to lookup, as + threepids: The third party identifiers to lookup, as a list of 2-string sized lists ([medium, address]). Returns: - Deferred[dict]: The result of the lookup. See + The result of the lookup. See https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup for details """ @@ -685,7 +680,7 @@ class IdentityHandler(BaseHandler): id_server_url = self.rewrite_id_server_url(id_server, add_https=True) try: - data = yield self.http_client.post_json_get_json( + data = await self.http_client.post_json_get_json( "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,), {"threepids": threepids}, ) @@ -697,7 +692,7 @@ class IdentityHandler(BaseHandler): logger.info("Failed to contact %s: %s", id_server, e) raise ProxiedRequestError(503, "Failed to contact identity server") - defer.returnValue(data) + return data async def lookup_3pid( self, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 465b2a19bc..d5b12403f9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -17,7 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet.interfaces import IDelayedCall @@ -55,6 +55,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -867,7 +868,7 @@ class EventCreationHandler(object): # Ensure that we can round trip before trying to persist in db try: dump = frozendict_json_encoder.encode(event.content) - json.loads(dump) + json_decoder.decode(dump) except Exception: logger.exception("Failed to encode content: %r", event.content) raise @@ -963,7 +964,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1083,8 +1084,8 @@ class EventCreationHandler(object): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 87d28a7ae9..dd3703cbd2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode @@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -367,7 +367,7 @@ class OidcHandler: # and check for an error field. If not, we respond with a generic # error message. try: - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) error = resp["error"] description = resp.get("error_description", error) except (ValueError, KeyError): @@ -384,7 +384,7 @@ class OidcHandler: # Since it is a not a 5xx code, body should be a valid JSON. It will # raise if not. - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) if "error" in resp: error = resp["error"] diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b05aa89455..4f3198896e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -133,8 +133,12 @@ class BaseProfileHandler(BaseHandler): body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname} signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0]) try: - yield self.http_client.post_json_get_json(url, signed_body) - yield self.store.update_replication_batch_for_host(host, batchnum) + yield defer.ensureDeferred( + self.http_client.post_json_get_json(url, signed_body) + ) + yield defer.ensureDeferred( + self.store.update_replication_batch_for_host(host, batchnum) + ) logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host) except Exception: # This will get retried when the looping call next comes around diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 3cd821340b..ca9b644d0b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -747,7 +747,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index a011e9fe29..9146dc1a3b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -16,13 +16,12 @@ import logging from typing import Any -from canonicaljson import json - from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError from synapse.config.emailconfig import ThreepidBehaviour +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response - resp_body = json.loads(data.decode("utf-8")) + resp_body = json_decoder.decode(data.decode("utf-8")) if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly diff --git a/synapse/http/client.py b/synapse/http/client.py index 8aeb70cdec..dad01a8e56 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -19,7 +19,7 @@ import urllib from io import BytesIO import treq -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from netaddr import IPAddress from prometheus_client import Counter from zope.interface import implementer, provider @@ -47,6 +47,7 @@ from synapse.http import ( from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred logger = logging.getLogger(__name__) @@ -391,7 +392,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body @@ -433,7 +434,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body @@ -463,7 +464,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) body = await self.get_raw(uri, args, headers=headers) - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) async def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. @@ -506,7 +507,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 89a3b041ce..f794315deb 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import random import time @@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers from synapse.logging.context import make_deferred_yieldable -from synapse.util import Clock +from synapse.util import Clock, json_decoder from synapse.util.caches.ttlcache import TTLCache from synapse.util.metrics import Measure @@ -181,7 +180,7 @@ class WellKnownResolver(object): if response.code != 200: raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode("utf-8")) + parsed_body = json_decoder.decode(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) result = parsed_body["m.server"].encode("ascii") diff --git a/synapse/http/server.py b/synapse/http/server.py index 37fdf14405..8d791bd2ca 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -500,7 +500,7 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect): pass -@implementer(interfaces.IPullProducer) +@implementer(interfaces.IPushProducer) class _ByteProducer: """ Iteratively write bytes to the request. @@ -515,52 +515,64 @@ class _ByteProducer: ): self._request = request self._iterator = iterator + self._paused = False - def start(self) -> None: - self._request.registerProducer(self, False) + # Register the producer and start producing data. + self._request.registerProducer(self, True) + self.resumeProducing() def _send_data(self, data: List[bytes]) -> None: """ - Send a list of strings as a response to the request. + Send a list of bytes as a chunk of a response. """ if not data: return self._request.write(b"".join(data)) + def pauseProducing(self) -> None: + self._paused = True + def resumeProducing(self) -> None: # We've stopped producing in the meantime (note that this might be # re-entrant after calling write). if not self._request: return - # Get the next chunk and write it to the request. - # - # The output of the JSON encoder is coalesced until min_chunk_size is - # reached. (This is because JSON encoders produce a very small output - # per iteration.) - # - # Note that buffer stores a list of bytes (instead of appending to - # bytes) to hopefully avoid many allocations. - buffer = [] - buffered_bytes = 0 - while buffered_bytes < self.min_chunk_size: - try: - data = next(self._iterator) - buffer.append(data) - buffered_bytes += len(data) - except StopIteration: - # The entire JSON object has been serialized, write any - # remaining data, finalize the producer and the request, and - # clean-up any references. - self._send_data(buffer) - self._request.unregisterProducer() - self._request.finish() - self.stopProducing() - return - - self._send_data(buffer) + self._paused = False + + # Write until there's backpressure telling us to stop. + while not self._paused: + # Get the next chunk and write it to the request. + # + # The output of the JSON encoder is buffered and coalesced until + # min_chunk_size is reached. This is because JSON encoders produce + # very small output per iteration and the Request object converts + # each call to write() to a separate chunk. Without this there would + # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n"). + # + # Note that buffer stores a list of bytes (instead of appending to + # bytes) to hopefully avoid many allocations. + buffer = [] + buffered_bytes = 0 + while buffered_bytes < self.min_chunk_size: + try: + data = next(self._iterator) + buffer.append(data) + buffered_bytes += len(data) + except StopIteration: + # The entire JSON object has been serialized, write any + # remaining data, finalize the producer and the request, and + # clean-up any references. + self._send_data(buffer) + self._request.unregisterProducer() + self._request.finish() + self.stopProducing() + return + + self._send_data(buffer) def stopProducing(self) -> None: + # Clear a circular reference. self._request = None @@ -620,8 +632,7 @@ def respond_with_json( if send_cors: set_cors_headers(request) - producer = _ByteProducer(request, encoder(json_object)) - producer.start() + _ByteProducer(request, encoder(json_object)) return NOT_DONE_YET diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index a34e5ead88..53acba56cb 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -17,9 +17,8 @@ import logging -from canonicaljson import json - from synapse.api.errors import Codes, SynapseError +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): return None try: - content = json.loads(content_bytes.decode("utf-8")) + content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 21dbd9f415..abe532d350 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -177,6 +177,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.config import ConfigError +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.http.site import SynapseRequest @@ -499,7 +500,9 @@ def start_active_span_from_edu( if opentracing is None: return _noop_context_manager() - carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {}) + carrier = json_decoder.decode(edu_content.get("context", "{}")).get( + "opentracing", {} + ) context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) _references = [ opentracing.child_of(span_context_from_string(x)) @@ -699,7 +702,7 @@ def span_context_from_string(carrier): Returns: The active span context decoded from a string. """ - carrier = json.loads(carrier) + carrier = json_decoder.decode(carrier) return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index f766d16db6..4cd7932e5b 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -175,7 +175,7 @@ def run_as_background_process(desc: str, func, *args, **kwargs): It returns a Deferred which completes when the function completes, but it doesn't follow the synapse logcontext rules, which makes it appropriate for passing to clock.looping_call and friends (or for firing-and-forgetting in the middle of a - normal synapse inlineCallbacks function). + normal synapse async function). Args: desc: a description for this background process type diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 9d1d173b2f..d43eaf3a29 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -33,3 +33,11 @@ class SlavedIdTracker(object): int """ return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 590187df46..90d90833f9 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -21,16 +22,13 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_push_rules_stream_token(self): - return ( - self._push_rules_stream_id_gen.get_current_token(), - self._stream_id_gen.get_current_token(), - ) - def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): + # We assert this for the benefit of mypy + assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) + if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(token) for row in rows: diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index d853e4447e..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -21,9 +21,7 @@ import abc import logging from typing import Tuple, Type -from canonicaljson import json - -from synapse.util import json_encoder as _json_encoder +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -134,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -359,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -367,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 7a42de3f7d..8c3caf30c9 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -352,7 +352,7 @@ class PushRulesStream(Stream): ) def _current_token(self, instance_name: str) -> int: - push_rules_token, _ = self.store.get_push_rules_stream_token() + push_rules_token = self.store.get_max_push_rules_stream_id() return push_rules_token @@ -405,7 +405,7 @@ class CachesStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, + store.get_cache_stream_token_for_writer, store.get_all_updated_caches, ) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index e2df638cc5..e781a3bcf4 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet): return 200, {} def notify_user(self, user_id): - stream_id, _ = self.store.get_push_rules_stream_token() + stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) async def set_rule_attr(self, user_id, spec, val): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index ed82144475..bc914d920e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ import re from typing import List, Optional from urllib import parse as urlparse -from canonicaljson import json - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.util import json_decoder MYPY = False if MYPY: @@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 7363d8d989..6b945e1849 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -1005,12 +1005,11 @@ class ThreepidLookupRestServlet(RestServlet): self.auth = hs.get_auth() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): """Proxy a /_matrix/identity/api/v1/lookup request to an identity server """ - yield self.auth.get_user_by_req(request) + await self.auth.get_user_by_req(request) # Verify query parameters query_params = request.args @@ -1023,9 +1022,9 @@ class ThreepidLookupRestServlet(RestServlet): # Proxy the request to the identity server. lookup_3pid handles checking # if the lookup is allowed so we don't need to do it here. - ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address) + ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address) - defer.returnValue((200, ret)) + return 200, ret class ThreepidBulkLookupRestServlet(RestServlet): @@ -1036,12 +1035,11 @@ class ThreepidBulkLookupRestServlet(RestServlet): self.auth = hs.get_auth() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity server """ - yield self.auth.get_user_by_req(request) + await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) @@ -1049,11 +1047,11 @@ class ThreepidBulkLookupRestServlet(RestServlet): # Proxy the request to the identity server. lookup_3pid handles checking # if the lookup is allowed so we don't need to do it here. - ret = yield self.identity_handler.proxy_bulk_lookup_3pid( + ret = await self.identity_handler.proxy_bulk_lookup_3pid( body["id_server"], body["threepids"] ) - defer.returnValue((200, ret)) + return 200, ret def assert_valid_next_link(hs: "HomeServer", next_link: str): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd63..96488b131a 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -16,8 +16,6 @@ import itertools import logging -from canonicaljson import json - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken +from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet): filter_collection = DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: - filter_object = json.loads(filter_id) + filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( filter_object, self.hs.config.filter_timeline_limit ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e266204f95..5db7f81c2d 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,19 +15,19 @@ import logging from typing import Dict, Set -from canonicaljson import json from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.util import json_decoder logger = logging.getLogger(__name__) class RemoteKey(DirectServeJsonResource): - """HTTP resource for retreiving the TLS certificate and NACL signature + """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks that the NACL signature for the remote server is valid. Returns a dict of @@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource): # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) + # If there is a cache miss, request the missing keys, then recurse (and + # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: await self.fetcher.get_keys(cache_misses) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json in json_results: - key_json = json.loads(key_json.decode("utf-8")) + key_json = json_decoder.decode(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: key_json = sign_json(key_json, self.config.server_name, signing_key) diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..4d9b13ac04 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ class SpamCheckerApi(object): state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d3884667..dba8d91eef 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ class StateResolutionStore(object): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. """ return self.store.get_events( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6814bf5fcf..ab49d227de 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,12 +19,11 @@ import random from abc import ABCMeta from typing import Any, Optional -from canonicaljson import json - from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.types import Collection, get_domain_from_id +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -99,13 +98,13 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # Decode it to a Unicode string before feeding it to json.loads, since + # Decode it to a Unicode string before feeding it to the JSON decoder, since # Python 3.5 does not support deserializing bytes. if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") try: - return json.loads(db_content) + return json_decoder.decode(db_content) except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 8a9e06efcf..b9aef96b08 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -516,14 +516,16 @@ class DatabasePool(object): logger.warning("Starting db txn '%s' from sentinel context", desc) try: - result = yield self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - **kwargs + result = yield defer.ensureDeferred( + self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs + ) ) for after_callback, after_args, after_kwargs in after_callbacks: @@ -535,8 +537,7 @@ class DatabasePool(object): return result - @defer.inlineCallbacks - def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any): + async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any: """Wraps the .runWithConnection() method on the underlying db_pool. Arguments: @@ -547,7 +548,7 @@ class DatabasePool(object): kwargs: named args to pass to `func` Returns: - Deferred: The result of func + The result of func """ parent_context = current_context() # type: Optional[LoggingContextOrSentinel] if not parent_context: @@ -570,12 +571,10 @@ class DatabasePool(object): return func(conn, *args, **kwargs) - result = yield make_deferred_yieldable( + return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) ) - return result - @staticmethod def cursor_to_dict(cursor): """Converts a SQL cursor into an list of dicts. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 10de446065..1e7637a6f5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): }, ) - def get_cache_stream_token(self, instance_name): + def get_cache_stream_token_for_writer(self, instance_name: str) -> int: if self._cache_id_gen: - return self._cache_id_gen.get_current_token(instance_name) + return self._cache_id_gen.get_current_token_for_writer(instance_name) else: return 0 diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 431bd76693..4826be630c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. Args: @@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8c63a0dc4d..4a3333c0db 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - @defer.inlineCallbacks - def get_event( + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. @@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -574,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - d = db_to_json(row["json"]) - internal_metadata = db_to_json(row["internal_metadata"]) + # If the event or metadata cannot be parsed, log the error and act + # as if the event is unknown. + try: + d = db_to_json(row["json"]) + except ValueError: + logger.error("Unable to parse json from event: %s", event_id) + continue + try: + internal_metadata = db_to_json(row["internal_metadata"]) + except ValueError: + logger.error( + "Unable to parse internal_metadata from event: %s", event_id + ) + continue format_version = row["format_version"] if format_version is None: @@ -650,8 +684,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -660,7 +693,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -683,7 +716,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -842,33 +875,29 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield defer.ensureDeferred( - self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) + rows = await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", ) return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -884,7 +913,7 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -914,8 +943,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -926,9 +954,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1165,9 +1193,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index c2289a9557..a585e54812 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -82,9 +82,9 @@ class PushRulesWorkerStore( super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + self._push_rules_stream_id_gen = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) # type: Union[StreamIdGenerator, SlavedIdTracker] else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" @@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids + with self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + if before or after: await self.db_pool.runInteraction( "_add_push_rule_relative_txn", @@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids + with self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + await self.db_pool.runInteraction( "delete_push_rule", delete_push_rule_txn, @@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids + with self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, @@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids + with self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + await self.db_pool.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, @@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_push_rules_stream_token(self): - """Get the position of the push rules stream. - Returns a pair of a stream id for the push_rules stream and the - room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_current_token() - def get_max_push_rules_stream_id(self): - return self.get_push_rules_stream_token()[0] + return self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4377bddb8c..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], str]: - """Get new room events in stream ordering since `from_key`. Args: diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index e2ddd01290..0bf772d4d1 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -16,7 +16,7 @@ import contextlib import threading from collections import deque -from typing import Dict, Set, Tuple +from typing import Dict, Set from typing_extensions import Deque @@ -158,63 +158,13 @@ class StreamIdGenerator(object): return self._current + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. -class ChainedIdGenerator(object): - """Used to generate new stream ids where the stream must be kept in sync - with another stream. It generates pairs of IDs, the first element is an - integer ID for this stream, the second element is the ID for the stream - that this stream needs to be kept in sync with.""" - - def __init__(self, chained_generator, db_conn, table, column): - self.chained_generator = chained_generator - self._table = table - self._lock = threading.Lock() - self._current_max = _load_current_id(db_conn, table, column) - self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] - - def get_next(self): - """ - Usage: - with stream_id_gen.get_next() as (stream_id, chained_id): - # ... persist event ... - """ - with self._lock: - self._current_max += 1 - next_id = self._current_max - chained_id = self.chained_generator.get_current_token() - - self._unfinished_ids.append((next_id, chained_id)) - - @contextlib.contextmanager - def manager(): - try: - yield (next_id, chained_id) - finally: - with self._lock: - self._unfinished_ids.remove((next_id, chained_id)) - - return manager() - - def get_current_token(self): - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. + For streams with single writers this is equivalent to + `get_current_token`. """ - with self._lock: - if self._unfinished_ids: - stream_id, chained_id = self._unfinished_ids[0] - return stream_id - 1, chained_id - - return self._current_max, self.chained_generator.get_current_token() - - def advance(self, token: int): - """Stub implementation for advancing the token when receiving updates - over replication; raises an exception as this instance should be the - only source of updates. - """ - - raise Exception( - "Attempted to advance token on source for table %r", self._table - ) + return self.get_current_token() class MultiWriterIdGenerator: @@ -298,7 +248,7 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. - assert self.get_current_token() < next_id + assert self.get_current_token_for_writer(self._instance_name) < next_id with self._lock: self._unfinished_ids.add(next_id) @@ -344,16 +294,18 @@ class MultiWriterIdGenerator: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, next_id) - def get_current_token(self, instance_name: str = None) -> int: - """Gets the current position of a named writer (defaults to current - instance). - - Returns 0 if we don't have a position for the named writer (likely due - to it being a new writer). + def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. """ - if instance_name is None: - instance_name = self._instance_name + # Currently we don't support this operation, as it's not obvious how to + # condense the stream positions of multiple writers into a single int. + raise NotImplementedError() + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + """ with self._lock: return self._current_positions.get(instance_name, 0) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 393e34b9fb..7ab46f42bf 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -39,7 +39,7 @@ class EventSources(object): self.store = hs.get_datastore() def get_current_token(self) -> StreamToken: - push_rules_key, _ = self.store.get_push_rules_stream_token() + push_rules_key = self.store.get_max_push_rules_stream_id() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() groups_key = self.store.get_group_stream_token() diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b3f76428b6..b2a22dbd5c 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -25,8 +25,18 @@ from synapse.logging import context logger = logging.getLogger(__name__) -# Create a custom encoder to reduce the whitespace produced by JSON encoding. -json_encoder = json.JSONEncoder(separators=(",", ":")) + +def _reject_invalid_json(val): + """Do not allow Infinity, -Infinity, or NaN values in JSON.""" + raise json.JSONDecodeError("Invalid JSON value: '%s'" % val) + + +# Create a custom encoder to reduce the whitespace produced by JSON encoding and +# ensure that valid JSON is produced. +json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":")) + +# Create a custom decoder to reject Python extensions to JSON. +json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) def unwrapFirstError(failure): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index c2d72a82cf..49d9fddcf0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -285,16 +285,9 @@ class Cache(object): class _CacheDescriptorBase(object): - def __init__( - self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False - ): + def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -364,7 +357,7 @@ class CacheDescriptor(_CacheDescriptorBase): invalidated) by adding a special "cache_context" argument to the function and passing that as a kwarg to all caches called. For example:: - @cachedInlineCallbacks(cache_context=True) + @cached(cache_context=True) def foo(self, key, cache_context): r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) @@ -382,17 +375,11 @@ class CacheDescriptor(_CacheDescriptorBase): max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False, ): - super(CacheDescriptor, self).__init__( - orig, - num_args=num_args, - inlineCallbacks=inlineCallbacks, - cache_context=cache_context, - ) + super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries self.tree = tree @@ -465,9 +452,7 @@ class CacheDescriptor(_CacheDescriptorBase): observer = defer.succeed(cached_result_d) except KeyError: - ret = defer.maybeDeferred( - preserve_fn(self.function_to_call), obj, *args, **kwargs - ) + ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) def onErr(f): cache.invalidate(cache_key) @@ -510,9 +495,7 @@ class CacheListDescriptor(_CacheDescriptorBase): of results. """ - def __init__( - self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False - ): + def __init__(self, orig, cached_method_name, list_name, num_args=None): """ Args: orig (function) @@ -521,12 +504,8 @@ class CacheListDescriptor(_CacheDescriptorBase): num_args (int): number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all named args of the function. - inlineCallbacks (bool): Whether orig is a generator that should - be wrapped by defer.inlineCallbacks """ - super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks - ) + super().__init__(orig, num_args=num_args) self.list_name = list_name @@ -631,7 +610,7 @@ class CacheListDescriptor(_CacheDescriptorBase): cached_defers.append( defer.maybeDeferred( - preserve_fn(self.function_to_call), **args_to_call + preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback) ) @@ -695,21 +674,7 @@ def cached( ) -def cachedInlineCallbacks( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - tree=tree, - inlineCallbacks=True, - cache_context=cache_context, - iterable=iterable, - ) - - -def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=None): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -725,8 +690,6 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal do batch lookups in the cache. num_args (int): Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. - inlineCallbacks (bool): Should the function be wrapped in an - `defer.inlineCallbacks`? Example: @@ -744,5 +707,4 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, - inlineCallbacks=inlineCallbacks, ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index ac629b8a28..d7f0c19c4c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) + ) displayname = yield defer.ensureDeferred( self.handler.get_displayname(self.frank) @@ -112,10 +114,17 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) + ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) # Setting displayname a second time is forbidden @@ -158,7 +167,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): yield defer.ensureDeferred(self.store.create_profile("caroline")) - yield self.store.set_profile_displayname("caroline", "Caroline", 1) + yield defer.ensureDeferred( + self.store.set_profile_displayname("caroline", "Caroline", 1) + ) response = yield defer.ensureDeferred( self.query_handlers["profile"]( @@ -170,8 +181,10 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png", 1 + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png", 1 + ) ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -211,8 +224,10 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png", 1 + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png", 1 + ) ) self.assertEquals( diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index db52725cfe..2668662c9e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index fce79b38f2..ecf697e5e0 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -182,7 +182,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 8e9a650f9f..43639ca286 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ "3" ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() # All entries within time frame self.assertEqual( @@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): 3, ) # Oldest room to expire - self.pump(1) + self.pump(1.01) self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() self.assertEqual( len( diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index e845410dae..7a05194653 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) def test_multi_instance(self): """Test that reads and writes from multiple processes are handled @@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen = self._create_id_generator("second") self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token("first"), 3) - self.assertEqual(first_id_gen.get_current_token("second"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 619ee0da15..745fa15e26 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -34,10 +34,12 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): - yield self.store.register_user(self.user.to_string(), "pass") + yield defer.ensureDeferred( + self.store.register_user(self.user.to_string(), "pass") + ) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield self.store.set_profile_displayname( - self.user.localpart, self.displayname, 1 + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.user.localpart, self.displayname, 1) ) users, total = yield self.store.get_users_paginate( diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 4d2b9e0d64..0363735d4f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): assert current_context().request == "c1" # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() assert current_context().request == "c1" return self.mock(args1, arg2) @@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() return self.mock(args1, arg2) obj = Cls() |