Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 5
-rw-r--r--  synapse/app/generic_worker.py | 2
-rw-r--r--  synapse/config/experimental.py | 19
-rw-r--r--  synapse/config/homeserver.py | 2
-rw-r--r--  synapse/config/registration.py | 2
-rw-r--r--  synapse/config/saml2.py | 8
-rw-r--r--  synapse/config/tls.py | 50
-rw-r--r--  synapse/config/tracer.py | 37
-rw-r--r--  synapse/crypto/keyring.py | 98
-rw-r--r--  synapse/federation/federation_base.py | 17
-rw-r--r--  synapse/federation/federation_client.py | 28
-rw-r--r--  synapse/federation/transport/client.py | 92
-rw-r--r--  synapse/federation/transport/server.py | 6
-rw-r--r--  synapse/handlers/account_validity.py | 55
-rw-r--r--  synapse/handlers/event_auth.py | 121
-rw-r--r--  synapse/handlers/federation.py | 46
-rw-r--r--  synapse/handlers/presence.py | 136
-rw-r--r--  synapse/handlers/room_member.py | 20
-rw-r--r--  synapse/handlers/send_email.py | 98
-rw-r--r--  synapse/handlers/space_summary.py | 258
-rw-r--r--  synapse/http/client.py | 7
-rw-r--r--  synapse/http/matrixfederationclient.py | 172
-rw-r--r--  synapse/http/site.py | 4
-rw-r--r--  synapse/module_api/__init__.py | 63
-rw-r--r--  synapse/push/mailer.py | 53
-rw-r--r--  synapse/python_dependencies.py | 1
-rw-r--r--  synapse/replication/http/presence.py | 11
-rw-r--r--  synapse/replication/slave/storage/client_ips.py | 2
-rw-r--r--  synapse/replication/slave/storage/transactions.py | 21
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py | 8
-rw-r--r--  synapse/rest/key/v2/local_key_resource.py | 8
-rw-r--r--  synapse/rest/key/v2/remote_key_resource.py | 3
-rw-r--r--  synapse/rest/media/v1/media_repository.py | 2
-rw-r--r--  synapse/rest/media/v1/thumbnailer.py | 9
-rw-r--r--  synapse/server.py | 5
-rw-r--r--  synapse/storage/_base.py | 2
-rw-r--r--  synapse/storage/databases/main/__init__.py | 4
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 2
-rw-r--r--  synapse/storage/databases/main/devices.py | 4
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 4
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 1
-rw-r--r--  synapse/storage/databases/main/keys.py | 2
-rw-r--r--  synapse/storage/databases/main/presence.py | 76
-rw-r--r--  synapse/storage/databases/main/registration.py | 3
-rw-r--r--  synapse/storage/databases/main/transactions.py | 66
-rw-r--r--  synapse/storage/databases/main/user_erasure_store.py | 13
-rw-r--r--  synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql | 34
-rw-r--r--  synapse/storage/state.py | 2
-rw-r--r--  synapse/util/batching_queue.py | 153
-rw-r--r--  synapse/util/caches/deferred_cache.py | 2
-rw-r--r--  synapse/util/caches/descriptors.py | 15
-rw-r--r--  synapse/util/caches/lrucache.py | 10
-rw-r--r--  synapse/util/caches/treecache.py | 104
-rw-r--r--  synapse/util/hash.py | 10
-rw-r--r--  synapse/util/iterutils.py | 11
-rw-r--r--  synapse/util/module_loader.py | 9
-rw-r--r--  synapse/util/msisdn.py | 10
-rw-r--r--  synapse/util/retryutils.py | 8
-rw-r--r--  synapse/util/stringutils.py | 23
60 files changed, 1452 insertions(+), 587 deletions(-)
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 7498a6016f..d9843a1708 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass

-__version__ = "1.34.0"
+__version__ = "1.35.0"

 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index c3c776a9f9..b2e60c6aa7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -87,6 +87,7 @@ class Auth:
         )
         self._track_appservice_user_ips = hs.config.track_appservice_user_ips
         self._macaroon_secret_key = hs.config.macaroon_secret_key
+        self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users

     async def check_from_context(
         self, room_version: str, event, context, do_sig_check=True
@@ -208,6 +209,8 @@ class Auth:
             opentracing.set_tag("authenticated_entity", user_id)
             opentracing.set_tag("user_id", user_id)
             opentracing.set_tag("appservice_id", app_service.id)
+            if user_id in self._force_tracing_for_users:
+                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)

             return requester

@@ -260,6 +263,8 @@ class Auth:
             opentracing.set_tag("user_id", user_info.user_id)
             if device_id:
                 opentracing.set_tag("device_id", device_id)
+            if user_info.token_owner in self._force_tracing_for_users:
+                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)

             return requester
         except KeyError:
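A note on the mechanism: Jaeger honours the standard "sampling.priority" span tag, so setting it to 1 asks the sampler to keep the trace even when probabilistic sampling would otherwise drop it. A minimal sketch using the stock opentracing package (the tracer setup and the force_tracing_for_users value here are stand-ins, not Synapse code):

import opentracing
from opentracing.ext import tags

force_tracing_for_users = {"@admin:example.com"}  # stand-in for the new config option

def maybe_force_tracing(user_id: str) -> None:
    span = opentracing.global_tracer().active_span
    if span is not None and user_id in force_tracing_for_users:
        # sampling.priority > 0 hints the tracing backend to retain this trace.
        span.set_tag(tags.SAMPLING_PRIORITY, 1)
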
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f730cdbd78..91ad326f19 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -61,7 +61,6 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.rest.admin import register_servlets_for_media_repo
 from synapse.rest.client.v1 import events, login, presence, room
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -237,7 +236,6 @@ class GenericWorkerSlavedStore(
     DirectoryStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
-    SlavedTransactionStore,
     SlavedProfileStore,
     SlavedClientIpStore,
     SlavedFilteringStore,
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 7137e3d323..ea692f699d 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -35,9 +35,26 @@ class ExperimentalConfig(Config):
         self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool

         # Spaces (MSC1772, MSC2946, MSC3083, etc)
-        self.spaces_enabled = experimental.get("spaces_enabled", False)  # type: bool
+        self.spaces_enabled = experimental.get("spaces_enabled", True)  # type: bool
         if self.spaces_enabled:
             KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083

         # MSC3026 (busy presence state)
         self.msc3026_enabled = experimental.get("msc3026_enabled", False)  # type: bool
+
+    def generate_config_section(self, **kwargs):
+        return """\
+        # Enable experimental features in Synapse.
+        #
+        # Experimental features might break or be removed without a deprecation
+        # period.
+        #
+        experimental_features:
+          # Support for Spaces (MSC1772), which enables the following:
+          #
+          # * The Spaces Summary API (MSC2946).
+          # * Restricting room membership based on space membership (MSC3083).
+          #
+          # Uncomment to disable support for Spaces.
+          #spaces_enabled: false
+        """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index c23b66c88c..5ae0f55bcc 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -57,7 +57,6 @@ class HomeServerConfig(RootConfig):

     config_classes = [
         ServerConfig,
-        ExperimentalConfig,
         TlsConfig,
         FederationConfig,
         CacheConfig,
@@ -94,4 +93,5 @@ class HomeServerConfig(RootConfig):
         TracerConfig,
         WorkerConfig,
         RedisConfig,
+        ExperimentalConfig,
     ]
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 632ef0d796..eecc0478a7 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -477,4 +477,4 @@ class RegistrationConfig(Config):

     def read_arguments(self, args):
         if args.enable_registration is not None:
-            self.enable_registration = bool(strtobool(str(args.enable_registration)))
+            self.enable_registration = strtobool(str(args.enable_registration))
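For context, distutils.util.strtobool returns 0 or 1 rather than a bool, and both behave correctly in truth tests, which is presumably why the extra bool() wrapper was dropped:

from distutils.util import strtobool

print(strtobool("yes"), strtobool("false"))  # 1 0 -- ints, but truthy/falsy as expected
print(bool(strtobool("yes")))                # True -- the now-removed extra conversion
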
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 3d1218c8d1..05e983625d 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -164,7 +164,13 @@ class SAML2Config(Config):
         config_path = saml2_config.get("config_path", None)
         if config_path is not None:
             mod = load_python_module(config_path)
-            _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
+            config = getattr(mod, "CONFIG", None)
+            if config is None:
+                raise ConfigError(
+                    "Config path specified by saml2_config.config_path does not "
+                    "have a CONFIG property."
+                )
+            _dict_merge(merge_dict=config, into_dict=saml2_config_dict)

         import saml2.config
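The module referenced by saml2_config.config_path must now expose a CONFIG attribute, or startup fails with the new ConfigError. A hypothetical minimal example of such a module (the key names follow pysaml2's SP configuration conventions; the values are placeholders):

# my_saml_config.py -- hypothetical module pointed at by saml2_config.config_path
CONFIG = {
    "entityid": "https://example.com/saml/metadata.xml",
    "service": {
        "sp": {
            "allow_unsolicited": True,
        },
    },
}
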
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 7df4e4c3e6..26f1150ca5 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -16,11 +16,8 @@ import logging
 import os
 import warnings
 from datetime import datetime
-from hashlib import sha256
 from typing import List, Optional, Pattern

-from unpaddedbase64 import encode_base64
-
 from OpenSSL import SSL, crypto
 from twisted.internet._sslverify import Certificate, trustRootFromCertificates

@@ -83,13 +80,6 @@ class TlsConfig(Config):
                 "configured."
             )

-        self._original_tls_fingerprints = config.get("tls_fingerprints", [])
-
-        if self._original_tls_fingerprints is None:
-            self._original_tls_fingerprints = []
-
-        self.tls_fingerprints = list(self._original_tls_fingerprints)
-
         # Whether to verify certificates on outbound federation traffic
         self.federation_verify_certificates = config.get(
             "federation_verify_certificates", True
@@ -248,19 +238,6 @@ class TlsConfig(Config):
                 e,
             )

-        self.tls_fingerprints = list(self._original_tls_fingerprints)
-
-        if self.tls_certificate:
-            # Check that our own certificate is included in the list of fingerprints
-            # and include it if it is not.
-            x509_certificate_bytes = crypto.dump_certificate(
-                crypto.FILETYPE_ASN1, self.tls_certificate
-            )
-            sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
-            sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints}
-            if sha256_fingerprint not in sha256_fingerprints:
-                self.tls_fingerprints.append({"sha256": sha256_fingerprint})
-
     def generate_config_section(
         self,
         config_dir_path,
@@ -443,33 +420,6 @@ class TlsConfig(Config):
         # If unspecified, we will use CONFDIR/client.key.
         #
         account_key_file: %(default_acme_account_file)s
-
-        # List of allowed TLS fingerprints for this server to publish along
-        # with the signing keys for this server. Other matrix servers that
-        # make HTTPS requests to this server will check that the TLS
-        # certificates returned by this server match one of the fingerprints.
-        #
-        # Synapse automatically adds the fingerprint of its own certificate
-        # to the list. So if federation traffic is handled directly by synapse
-        # then no modification to the list is required.
-        #
-        # If synapse is run behind a load balancer that handles the TLS then it
-        # will be necessary to add the fingerprints of the certificates used by
-        # the loadbalancers to this list if they are different to the one
-        # synapse is using.
-        #
-        # Homeservers are permitted to cache the list of TLS fingerprints
-        # returned in the key responses up to the "valid_until_ts" returned in
-        # key. It may be necessary to publish the fingerprints of a new
-        # certificate and wait until the "valid_until_ts" of the previous key
-        # responses have passed before deploying it.
-        #
-        # You can calculate a fingerprint from a given TLS listener via:
-        # openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
-        #    openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
-        # or by checking matrix.org/federationtester/api/report?server_name=$host
-        #
-        #tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
         """
         # Lowercase the string representation of boolean values
         % {
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index db22b5b19f..d0ea17261f 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Set
+
 from synapse.python_dependencies import DependencyException, check_requirements

 from ._base import Config, ConfigError
@@ -32,6 +34,8 @@ class TracerConfig(Config):
             {"sampler": {"type": "const", "param": 1}, "logging": False},
         )

+        self.force_tracing_for_users: Set[str] = set()
+
         if not self.opentracer_enabled:
             return

@@ -48,6 +52,19 @@ class TracerConfig(Config):
         if not isinstance(self.opentracer_whitelist, list):
             raise ConfigError("Tracer homeserver_whitelist config is malformed")

+        force_tracing_for_users = opentracing_config.get("force_tracing_for_users", [])
+        if not isinstance(force_tracing_for_users, list):
+            raise ConfigError(
+                "Expected a list", ("opentracing", "force_tracing_for_users")
+            )
+        for i, u in enumerate(force_tracing_for_users):
+            if not isinstance(u, str):
+                raise ConfigError(
+                    "Expected a string",
+                    ("opentracing", "force_tracing_for_users", f"index {i}"),
+                )
+            self.force_tracing_for_users.add(u)
+
     def generate_config_section(cls, **kwargs):
         return """\
         ## Opentracing ##
@@ -64,7 +81,8 @@ class TracerConfig(Config):
        #enabled: true

         # The list of homeservers we wish to send and receive span contexts and span baggage.
-        # See docs/opentracing.rst
+        # See docs/opentracing.rst.
+        #
         # This is a list of regexes which are matched against the server_name of the
         # homeserver.
         #
@@ -73,19 +91,26 @@ class TracerConfig(Config):
         #homeserver_whitelist:
         #  - ".*"

+        # A list of the matrix IDs of users whose requests will always be traced,
+        # even if the tracing system would otherwise drop the traces due to
+        # probabilistic sampling.
+        #
+        # By default, the list is empty.
+        #
+        #force_tracing_for_users:
+        #  - "@user1:server_name"
+        #  - "@user2:server_name"
+
         # Jaeger can be configured to sample traces at different rates.
         # All configuration options provided by Jaeger can be set here.
-        # Jaeger's configuration mostly related to trace sampling which
+        # Jaeger's configuration is mostly related to trace sampling which
         # is documented here:
-        # https://www.jaegertracing.io/docs/1.13/sampling/.
+        # https://www.jaegertracing.io/docs/latest/sampling/.
         #
         #jaeger_config:
         #  sampler:
         #    type: const
         #    param: 1
-
-        # Logging whether spans were started and reported
-        #
         #  logging:
         #    false
         """
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..6fc0712978 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -17,7 +17,7 @@ import abc
 import logging
 import urllib
 from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple

 import attr
 from signedjson.key import (
@@ -42,6 +42,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config.key import TrustedKeyServer
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
@@ -69,7 +71,11 @@ class VerifyJsonRequest:

     Attributes:
         server_name: The name of the server to verify against.

-        json_object: The JSON object to verify.
+        get_json_object: A callback to fetch the JSON object to verify.
+            A callback is used to allow deferring the creation of the JSON
+            object to verify until needed, e.g. for events we can defer
+            creating the redacted copy. This reduces the memory usage when
+            there are large numbers of in flight requests.

         minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
@@ -88,14 +94,50 @@ class VerifyJsonRequest:
     """

     server_name = attr.ib(type=str)
-    json_object = attr.ib(type=JsonDict)
+    get_json_object = attr.ib(type=Callable[[], JsonDict])
     minimum_valid_until_ts = attr.ib(type=int)
     request_name = attr.ib(type=str)
-    key_ids = attr.ib(init=False, type=List[str])
+    key_ids = attr.ib(type=List[str])
     key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)

-    def __attrs_post_init__(self):
-        self.key_ids = signature_ids(self.json_object, self.server_name)
+    @staticmethod
+    def from_json_object(
+        server_name: str,
+        json_object: JsonDict,
+        minimum_valid_until_ms: int,
+        request_name: str,
+    ):
+        """Create a VerifyJsonRequest to verify all signatures on a signed JSON
+        object for the given server.
+        """
+        key_ids = signature_ids(json_object, server_name)
+        return VerifyJsonRequest(
+            server_name,
+            lambda: json_object,
+            minimum_valid_until_ms,
+            request_name=request_name,
+            key_ids=key_ids,
+        )
+
+    @staticmethod
+    def from_event(
+        server_name: str,
+        event: EventBase,
+        minimum_valid_until_ms: int,
+    ):
+        """Create a VerifyJsonRequest to verify all signatures on an event
+        object for the given server.
+        """
+        key_ids = list(event.signatures.get(server_name, []))
+        return VerifyJsonRequest(
+            server_name,
+            # We defer creating the redacted json object, as it uses a lot more
+            # memory than the Event object itself.
+            lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+            minimum_valid_until_ms,
+            request_name=event.event_id,
+            key_ids=key_ids,
+        )


 class KeyLookupError(ValueError):
@@ -147,8 +189,13 @@ class Keyring:
             Deferred[None]: completes if the the object was correctly signed, otherwise
                 errbacks with an error
         """
-        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
-        requests = (req,)
+        request = VerifyJsonRequest.from_json_object(
+            server_name,
+            json_object,
+            validity_time,
+            request_name,
+        )
+        requests = (request,)
         return make_deferred_yieldable(self._verify_objects(requests)[0])

     def verify_json_objects_for_server(
@@ -175,10 +222,41 @@ class Keyring:
                 logcontext.
         """
         return self._verify_objects(
-            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+            VerifyJsonRequest.from_json_object(
+                server_name, json_object, validity_time, request_name
+            )
             for server_name, json_object, validity_time, request_name in server_and_json
         )

+    def verify_events_for_server(
+        self, server_and_events: Iterable[Tuple[str, EventBase, int]]
+    ) -> List[defer.Deferred]:
+        """Bulk verification of signatures on events.
+
+        Args:
+            server_and_events:
+                Iterable of `(server_name, event, validity_time)` tuples.
+
+                `server_name` is which server we are verifying the signature for
+                on the event.
+
+                `event` is the event that we'll verify the signatures of for
+                the given `server_name`.
+
+                `validity_time` is a timestamp at which the signing key must be
+                valid.
+
+        Returns:
+            List<Deferred[None]>: for each input triplet, a deferred indicating success
+                or failure to verify each event's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
+        """
+        return self._verify_objects(
+            VerifyJsonRequest.from_event(server_name, event, validity_time)
+            for server_name, event, validity_time in server_and_events
+        )
+
     def _verify_objects(
         self, verify_requests: Iterable[VerifyJsonRequest]
     ) -> List[defer.Deferred]:
@@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
     with PreserveLoggingContext():
         _, key_id, verify_key = await verify_request.key_ready

-    json_object = verify_request.json_object
+    json_object = verify_request.get_json_object()

     try:
         verify_signed_json(json_object, server_name, verify_key)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase:
         return deferreds


-class PduToCheckSig(
-    namedtuple(
-        "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
-    )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
     pass

@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
     pdus_to_check = [
         PduToCheckSig(
             pdu=p,
-            redacted_pdu_json=prune_event(p).get_pdu_json(),
             sender_domain=get_domain_from_id(p.sender),
             deferreds=[],
         )
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
     # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]

-    more_deferreds = keyring.verify_json_objects_for_server(
+    more_deferreds = keyring.verify_events_for_server(
         [
             (
                 p.sender_domain,
-                p.redacted_pdu_json,
+                p.pdu,
                 p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                p.pdu.event_id,
             )
             for p in pdus_to_check_sender
         ]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
         if p.sender_domain != get_domain_from_id(p.pdu.event_id)
     ]

-    more_deferreds = keyring.verify_json_objects_for_server(
+    more_deferreds = keyring.verify_events_for_server(
         [
             (
                 get_domain_from_id(p.pdu.event_id),
-                p.redacted_pdu_json,
+                p.pdu,
                 p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                p.pdu.event_id,
             )
             for p in pdus_to_check_event_id
         ]
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 5125441df5..3feb60da2a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -57,6 +57,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.transport.client import SendJoinResponse
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict, get_domain_from_id
@@ -667,19 +668,10 @@ class FederationClient(FederationBase):
         """

         async def send_request(destination) -> Dict[str, Any]:
-            content = await self._do_send_join(destination, pdu)
+            response = await self._do_send_join(room_version, destination, pdu)

-            logger.debug("Got content: %s", content)
-
-            state = [
-                event_from_pdu_json(p, room_version, outlier=True)
-                for p in content.get("state", [])
-            ]
-
-            auth_chain = [
-                event_from_pdu_json(p, room_version, outlier=True)
-                for p in content.get("auth_chain", [])
-            ]
+            state = response.state
+            auth_chain = response.auth_events

             pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}

@@ -754,11 +746,14 @@ class FederationClient(FederationBase):

         return await self._try_destination_list("send_join", destinations, send_request)

-    async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
+    async def _do_send_join(
+        self, room_version: RoomVersion, destination: str, pdu: EventBase
+    ) -> SendJoinResponse:
         time_now = self._clock.time_msec()

         try:
             return await self.transport_layer.send_join_v2(
+                room_version=room_version,
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
@@ -773,17 +768,14 @@ class FederationClient(FederationBase):

         logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")

-        resp = await self.transport_layer.send_join_v1(
+        return await self.transport_layer.send_join_v1(
+            room_version=room_version,
             destination=destination,
             room_id=pdu.room_id,
             event_id=pdu.event_id,
             content=pdu.get_pdu_json(time_now),
         )

-        # We expect the v1 API to respond with [200, content], so we only return the
-        # content.
-        return resp[1]
-
     async def send_invite(
         self,
         destination: str,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index a42e055501..bf5b541deb 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -19,18 +19,29 @@ import logging
 import urllib
 from typing import Any, Dict, List, Optional

+import attr
+import ijson
+
 from synapse.api.constants import Membership
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.room_versions import RoomVersion
 from synapse.api.urls import (
     FEDERATION_UNSTABLE_PREFIX,
     FEDERATION_V1_PREFIX,
     FEDERATION_V2_PREFIX,
 )
+from synapse.events import EventBase, make_event_from_dict
+from synapse.http.matrixfederationclient import ByteParser
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict

 logger = logging.getLogger(__name__)

+# Send join responses can be huge, so we set a separate limit here. The response
+# is parsed in a streaming manner, which helps alleviate the issue of memory
+# usage a bit.
+MAX_RESPONSE_SIZE_SEND_JOIN = 500 * 1024 * 1024
+

 class TransportLayerClient:
     """Sends federation HTTP requests to other servers"""
@@ -253,21 +264,38 @@ class TransportLayerClient:
         return content

     @log_function
-    async def send_join_v1(self, destination, room_id, event_id, content):
+    async def send_join_v1(
+        self,
+        room_version,
+        destination,
+        room_id,
+        event_id,
+        content,
+    ) -> "SendJoinResponse":
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)

         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            parser=SendJoinParser(room_version, v1_api=True),
+            max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
         )

         return response

     @log_function
-    async def send_join_v2(self, destination, room_id, event_id, content):
+    async def send_join_v2(
+        self, room_version, destination, room_id, event_id, content
+    ) -> "SendJoinResponse":
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)

         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            parser=SendJoinParser(room_version, v1_api=False),
+            max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
         )

         return response
@@ -1119,3 +1147,59 @@ def _create_v2_path(path, *args):
         str
     """
     return _create_path(FEDERATION_V2_PREFIX, path, *args)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class SendJoinResponse:
+    """The parsed response of a `/send_join` request."""
+
+    auth_events: List[EventBase]
+    state: List[EventBase]
+
+
+@ijson.coroutine
+def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
+    """Helper function for use with `ijson.items_coro` to parse an array of
+    events and add them to the given list.
+    """
+
+    while True:
+        obj = yield
+        event = make_event_from_dict(obj, room_version)
+        events.append(event)
+
+
+class SendJoinParser(ByteParser[SendJoinResponse]):
+    """A parser for the response to `/send_join` requests.
+
+    Args:
+        room_version: The version of the room.
+        v1_api: Whether the response is in the v1 format.
+    """
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self, room_version: RoomVersion, v1_api: bool):
+        self._response = SendJoinResponse([], [])
+
+        # The V1 API has the shape of `[200, {...}]`, which we handle by
+        # prefixing with `item.*`.
+        prefix = "item." if v1_api else ""
+
+        self._coro_state = ijson.items_coro(
+            _event_list_parser(room_version, self._response.state),
+            prefix + "state.item",
+        )
+        self._coro_auth = ijson.items_coro(
+            _event_list_parser(room_version, self._response.auth_events),
+            prefix + "auth_chain.item",
+        )
+
+    def write(self, data: bytes) -> int:
+        self._coro_state.send(data)
+        self._coro_auth.send(data)
+
+        return len(data)
+
+    def finish(self) -> SendJoinResponse:
+        return self._response
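To make the streaming behaviour concrete, here is a tiny standalone demo of the ijson.items_coro pattern used by SendJoinParser: response bytes are pushed in as they arrive, and each completed array element is handed to the target coroutine, so the full document never needs to be held in memory at once. (The JSON payload is made up for the demo.)

import ijson

@ijson.coroutine
def collect(out):
    while True:
        out.append((yield))

events = []
parser = ijson.items_coro(collect(events), "state.item")

# Feed the body in arbitrary chunks, as SendJoinParser.write() does above.
for chunk in (b'{"state": [{"type": "m.room.create"},', b' {"type": "m.room.member"}]}'):
    parser.send(chunk)
parser.close()

print(events)  # [{'type': 'm.room.create'}, {'type': 'm.room.member'}]
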
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index dd5ab5160a..086d999d98 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -162,7 +162,7 @@ class Authenticator:
             # If we get a valid signed request from the other side, its probably
             # alive
             retry_timings = await self.store.get_destination_retry_timings(origin)
-            if retry_timings and retry_timings["retry_last_ts"]:
+            if retry_timings and retry_timings.retry_last_ts:
                 run_in_background(self._reset_retry_timings, origin)

         return origin
@@ -1479,7 +1479,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
         )

         return 200, await self.handler.federation_space_summary(
-            room_id, suggested_only, max_rooms_per_space, exclude_rooms
+            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
         )

     # TODO When switching to the stable endpoint, remove the POST handler.
@@ -1509,7 +1509,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
         )

         return 200, await self.handler.federation_space_summary(
-            room_id, suggested_only, max_rooms_per_space, exclude_rooms
+            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
         )
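The switch from retry_timings["retry_last_ts"] to retry_timings.retry_last_ts reflects the store now returning an attrs object rather than a dict (the class itself lives in synapse/storage/databases/main/transactions.py, which this diff also touches but does not show here). A sketch of the shape implied by this call site — the name and fields are inferred, not quoted:

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:  # assumed shape, based on the usage above
    failure_ts: int
    retry_last_ts: int
    retry_interval: int

timings = DestinationRetryTimings(
    failure_ts=0, retry_last_ts=1_620_000_000_000, retry_interval=60_000
)
if timings and timings.retry_last_ts:
    pass  # the destination answered recently; reset its retry schedule
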
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 022789ea5f..640c2e9fd6 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,12 +15,9 @@
 import email.mime.multipart
 import email.utils
 import logging
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
 from typing import TYPE_CHECKING, List, Optional, Tuple

 from synapse.api.errors import StoreError, SynapseError
-from synapse.logging.context import make_deferred_yieldable
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -39,9 +36,11 @@ class AccountValidityHandler:
         self.hs = hs
         self.config = hs.config
         self.store = self.hs.get_datastore()
-        self.sendmail = self.hs.get_sendmail()
+        self.send_email_handler = self.hs.get_send_email_handler()
         self.clock = self.hs.get_clock()

+        self._app_name = self.hs.config.email_app_name
+
         self._account_validity_enabled = (
             hs.config.account_validity.account_validity_enabled
         )
@@ -69,23 +68,10 @@ class AccountValidityHandler:
             self._template_text = (
                 hs.config.account_validity.account_validity_template_text
             )
-            account_validity_renew_email_subject = (
+            self._renew_email_subject = (
                 hs.config.account_validity.account_validity_renew_email_subject
             )

-            try:
-                app_name = hs.config.email_app_name
-
-                self._subject = account_validity_renew_email_subject % {"app": app_name}
-
-                self._from_string = hs.config.email_notif_from % {"app": app_name}
-            except Exception:
-                # If substitution failed, fall back to the bare strings.
-                self._subject = account_validity_renew_email_subject
-                self._from_string = hs.config.email_notif_from
-
-            self._raw_from = email.utils.parseaddr(self._from_string)[1]
-
         # Check the renewal emails to send and send them every 30min.
         if hs.config.run_background_tasks:
             self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
@@ -177,38 +163,17 @@ class AccountValidityHandler:
         }

         html_text = self._template_html.render(**template_vars)
-        html_part = MIMEText(html_text, "html", "utf8")
-
         plain_text = self._template_text.render(**template_vars)
-        text_part = MIMEText(plain_text, "plain", "utf8")

         for address in addresses:
             raw_to = email.utils.parseaddr(address)[1]

-            multipart_msg = MIMEMultipart("alternative")
-            multipart_msg["Subject"] = self._subject
-            multipart_msg["From"] = self._from_string
-            multipart_msg["To"] = address
-            multipart_msg["Date"] = email.utils.formatdate()
-            multipart_msg["Message-ID"] = email.utils.make_msgid()
-            multipart_msg.attach(text_part)
-            multipart_msg.attach(html_part)
-
-            logger.info("Sending renewal email to %s", address)
-
-            await make_deferred_yieldable(
-                self.sendmail(
-                    self.hs.config.email_smtp_host,
-                    self._raw_from,
-                    raw_to,
-                    multipart_msg.as_string().encode("utf8"),
-                    reactor=self.hs.get_reactor(),
-                    port=self.hs.config.email_smtp_port,
-                    requireAuthentication=self.hs.config.email_smtp_user is not None,
-                    username=self.hs.config.email_smtp_user,
-                    password=self.hs.config.email_smtp_pass,
-                    requireTransportSecurity=self.hs.config.require_transport_security,
-                )
+            await self.send_email_handler.send_email(
+                email_address=raw_to,
+                subject=self._renew_email_subject,
+                app_name=self._app_name,
+                html=html_text,
+                text=plain_text,
             )

         await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index eff639f407..a0df16a32f 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Collection, Optional

-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
 from synapse.types import StateMap

 if TYPE_CHECKING:
@@ -29,46 +31,104 @@ class EventAuthHandler:
     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastore()

-    async def can_join_without_invite(
-        self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
-    ) -> bool:
+    async def check_restricted_join_rules(
+        self,
+        state_ids: StateMap[str],
+        room_version: RoomVersion,
+        user_id: str,
+        prev_member_event: Optional[EventBase],
+    ) -> None:
         """
-        Check whether a user can join a room without an invite.
+        Check whether a user can join a room without an invite due to restricted join rules.

         When joining a room with restricted join rules (as defined in MSC3083),
-        the membership of spaces must be checked during join.
+        the membership of spaces must be checked during a room join.

         Args:
             state_ids: The state of the room as it currently is.
             room_version: The room version of the room being joined.
             user_id: The user joining the room.
+            prev_member_event: The current membership event for this user.
+
+        Raises:
+            AuthError if the user cannot join the room.
+        """
+        # If the member is invited or currently joined, then nothing to do.
+        if prev_member_event and (
+            prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
+        ):
+            return
+
+        # This is not a room with a restricted join rule, so we don't need to do the
+        # restricted room specific checks.
+        #
+        # Note: We'll be applying the standard join rule checks later, which will
+        # catch the cases of e.g. trying to join private rooms without an invite.
+        if not await self.has_restricted_join_rules(state_ids, room_version):
+            return
+
+        # Get the spaces which allow access to this room and check if the user is
+        # in any of them.
+        allowed_spaces = await self.get_spaces_that_allow_join(state_ids)
+        if not await self.is_user_in_rooms(allowed_spaces, user_id):
+            raise AuthError(
+                403,
+                "You do not belong to any of the required spaces to join this room.",
+            )
+
+    async def has_restricted_join_rules(
+        self, state_ids: StateMap[str], room_version: RoomVersion
+    ) -> bool:
+        """
+        Return if the room has the proper join rules set for access via spaces.
+
+        Args:
+            state_ids: The state of the room as it currently is.
+            room_version: The room version of the room to query.

         Returns:
-            True if the user can join the room, false otherwise.
+            True if the proper room version and join rules are set for restricted access.
         """
         # This only applies to room versions which support the new join rule.
         if not room_version.msc3083_join_rules:
-            return True
+            return False

         # If there's no join rule, then it defaults to invite (so this doesn't apply).
         join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
         if not join_rules_event_id:
-            return True
+            return False
+
+        # If the join rule is not restricted, this doesn't apply.
+        join_rules_event = await self._store.get_event(join_rules_event_id)
+        return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED
+
+    async def get_spaces_that_allow_join(
+        self, state_ids: StateMap[str]
+    ) -> Collection[str]:
+        """
+        Generate a list of spaces which allow access to a room.
+
+        Args:
+            state_ids: The state of the room as it currently is.
+
+        Returns:
+            A collection of spaces which provide membership to the room.
+        """
+        # If there's no join rule, then it defaults to invite (so this doesn't apply).
+        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+        if not join_rules_event_id:
+            return ()

         # If the join rule is not restricted, this doesn't apply.
         join_rules_event = await self._store.get_event(join_rules_event_id)
-        if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
-            return True

         # If allowed is of the wrong form, then only allow invited users.
         allowed_spaces = join_rules_event.content.get("allow", [])
         if not isinstance(allowed_spaces, list):
-            return False
-
-        # Get the list of joined rooms and see if there's an overlap.
-        joined_rooms = await self._store.get_rooms_for_user(user_id)
+            return ()

         # Pull out the other room IDs, invalid data gets filtered.
+        result = []
         for space in allowed_spaces:
             if not isinstance(space, dict):
                 continue
@@ -77,10 +137,31 @@ class EventAuthHandler:
             if not isinstance(space_id, str):
                 continue

-            # The user was joined to one of the spaces specified, they can join
-            # this room!
-            if space_id in joined_rooms:
+            result.append(space_id)
+
+        return result
+
+    async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
+        """
+        Check whether a user is a member of any of the provided rooms.
+
+        Args:
+            room_ids: The rooms to check for membership.
+            user_id: The user to check.
+
+        Returns:
+            True if the user is in any of the rooms, false otherwise.
+        """
+        if not room_ids:
+            return False
+
+        # Get the list of joined rooms and see if there's an overlap.
+        joined_rooms = await self._store.get_rooms_for_user(user_id)
+
+        # Check each room and see if the user is in it.
+        for room_id in room_ids:
+            if room_id in joined_rooms:
                 return True

-        # The user was not in any of the required spaces.
+        # The user was not in any of the rooms.
         return False
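For reference, the "allow" list these helpers walk looks roughly like this under MSC3083. The hunk above elides the line that extracts the room ID from each entry, so the "space" field name below is an assumption based on the MSC, not quoted from the diff:

# Hypothetical m.room.join_rules event content under MSC3083:
join_rules_content = {
    "join_rule": "restricted",
    "allow": [
        {"type": "m.room_membership", "space": "!space:example.com"},
        "not-a-dict-gets-skipped",
    ],
}

# Mirrors the filtering in get_spaces_that_allow_join: non-dict entries and
# non-string room IDs are dropped.
allowed = [
    entry["space"]
    for entry in join_rules_content.get("allow", [])
    if isinstance(entry, dict) and isinstance(entry.get("space"), str)
]
print(allowed)  # ['!space:example.com']
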
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6a5c33f212..36652289a4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -92,6 +92,7 @@ from synapse.types import (
     get_domain_from_id,
 )
 from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.iterutils import batch_iter
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
@@ -1741,28 +1742,17 @@ class FederationHandler(BaseHandler):
         # Check if the user is already in the room or invited to the room.
         user_id = event.state_key
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        newly_joined = True
-        user_is_invited = False
+        prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
-            newly_joined = prev_member_event.membership != Membership.JOIN
-            user_is_invited = prev_member_event.membership == Membership.INVITE
-
-        # If the member is not already in the room, and not invited, check if
-        # they should be allowed access via membership in a space.
-        if (
-            newly_joined
-            and not user_is_invited
-            and not await self._event_auth_handler.can_join_without_invite(
-                prev_state_ids,
-                event.room_version,
-                user_id,
-            )
-        ):
-            raise AuthError(
-                403,
-                "You do not belong to any of the required spaces to join this room.",
-            )
+
+        # Check if the member should be allowed access via membership in a space.
+        await self._event_auth_handler.check_restricted_join_rules(
+            prev_state_ids,
+            event.room_version,
+            user_id,
+            prev_member_event,
+        )

         # Persist the event.
         await self._auth_and_persist_event(origin, event, context)
@@ -3258,13 +3248,15 @@ class FederationHandler(BaseHandler):
         """
         instance = self.config.worker.events_shard_config.get_instance(room_id)
         if instance != self._instance_name:
-            result = await self._send_events(
-                instance_name=instance,
-                store=self.store,
-                room_id=room_id,
-                event_and_contexts=event_and_contexts,
-                backfilled=backfilled,
-            )
+            # Limit the number of events sent over federation.
+            for batch in batch_iter(event_and_contexts, 1000):
+                result = await self._send_events(
+                    instance_name=instance,
+                    store=self.store,
+                    room_id=room_id,
+                    event_and_contexts=batch,
+                    backfilled=backfilled,
+                )
+
             return result["max_stream_id"]
         else:
             assert self.storage.persistence
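batch_iter, imported above, is what caps each replication call at 1000 events: it slices any iterable into fixed-size tuples, with a final short batch. Its behaviour, for illustration:

from synapse.util.iterutils import batch_iter

for batch in batch_iter(range(2500), 1000):
    print(len(batch))  # 1000, then 1000, then 500 -- each batch is a tuple
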
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6fd1f34289..f5a049d754 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -222,9 +222,21 @@ class BasePresenceHandler(abc.ABC):

     @abc.abstractmethod
     async def set_state(
-        self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+        self,
+        target_user: UserID,
+        state: JsonDict,
+        ignore_status_msg: bool = False,
+        force_notify: bool = False,
     ) -> None:
-        """Set the presence state of the user. """
+        """Set the presence state of the user.
+
+        Args:
+            target_user: The ID of the user to set the presence state of.
+            state: The presence state as a JSON dictionary.
+            ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+                If False, the user's current status will be updated.
+            force_notify: Whether to force notification of the update to clients.
+        """

     @abc.abstractmethod
     async def bump_presence_active_time(self, user: UserID):
@@ -296,6 +308,51 @@ class BasePresenceHandler(abc.ABC):
         for destinations, states in hosts_and_states:
             self._federation.send_presence_to_destinations(states, destinations)

+    async def send_full_presence_to_users(self, user_ids: Collection[str]):
+        """
+        Adds to the list of users who should receive a full snapshot of presence
+        upon their next sync. Note that this only works for local users.
+
+        Then, grabs the current presence state for a given set of users and adds it
+        to the top of the presence stream.
+
+        Args:
+            user_ids: The IDs of the local users to send full presence to.
+        """
+        # Retrieve one of the users from the given set
+        if not user_ids:
+            raise Exception(
+                "send_full_presence_to_users must be called with at least one user"
+            )
+        user_id = next(iter(user_ids))
+
+        # Mark all users as receiving full presence on their next sync
+        await self.store.add_users_to_send_full_presence_to(user_ids)
+
+        # Add a new entry to the presence stream. Since we use stream tokens to determine whether a
+        # local user should receive a full snapshot of presence when they sync, we need to bump the
+        # presence stream so that subsequent syncs with no presence activity in between won't result
+        # in the client receiving multiple full snapshots of presence.
+        #
+        # If we bump the stream ID, then the user will get a higher stream token next sync, and thus
+        # correctly won't receive a second snapshot.
+
+        # Get the current presence state for one of the users (defaults to offline if not found)
+        current_presence_state = await self.get_state(UserID.from_string(user_id))
+
+        # Convert the UserPresenceState object into a serializable dict
+        state = {
+            "presence": current_presence_state.state,
+            "status_message": current_presence_state.status_msg,
+        }
+
+        # Copy the presence state to the tip of the presence stream.
+        #
+        # We set force_notify=True here so that this presence update is guaranteed to
+        # increment the presence stream ID (which resending the current user's presence
+        # otherwise would not do).
+        await self.set_state(UserID.from_string(user_id), state, force_notify=True)
+

 class _NullContextManager(ContextManager[None]):
     """A context manager which does nothing."""
@@ -480,8 +537,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
         target_user: UserID,
         state: JsonDict,
         ignore_status_msg: bool = False,
+        force_notify: bool = False,
     ) -> None:
-        """Set the presence state of the user."""
+        """Set the presence state of the user.
+
+        Args:
+            target_user: The ID of the user to set the presence state of.
+            state: The presence state as a JSON dictionary.
+            ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+                If False, the user's current status will be updated.
+            force_notify: Whether to force notification of the update to clients.
+        """
         presence = state["presence"]

         valid_presence = (
@@ -508,6 +574,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             user_id=user_id,
             state=state,
             ignore_status_msg=ignore_status_msg,
+            force_notify=force_notify,
         )

     async def bump_presence_active_time(self, user: UserID) -> None:
@@ -677,13 +744,19 @@ class PresenceHandler(BasePresenceHandler):
             [self.user_to_current_state[user_id] for user_id in unpersisted]
         )

-    async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
+    async def _update_states(
+        self, new_states: Iterable[UserPresenceState], force_notify: bool = False
+    ) -> None:
         """Updates presence of users. Sets the appropriate timeouts. Pokes
         the notifier and federation if and only if the changed presence state
         should be sent to clients/servers.

         Args:
             new_states: The new user presence state updates to process.
+            force_notify: Whether to force notifying clients of this presence state update,
+                even if it doesn't change the state of a user's presence (e.g. online -> online).
+                This is currently used to bump the max presence stream ID without changing any
+                user's presence (see PresenceHandler.add_users_to_send_full_presence_to).
         """
         now = self.clock.time_msec()

@@ -720,6 +793,9 @@ class PresenceHandler(BasePresenceHandler):
                 now=now,
             )

+            if force_notify:
+                should_notify = True
+
             self.user_to_current_state[user_id] = new_state

             if should_notify:
@@ -1058,9 +1134,21 @@ class PresenceHandler(BasePresenceHandler):
         await self._update_states(updates)

     async def set_state(
-        self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+        self,
+        target_user: UserID,
+        state: JsonDict,
+        ignore_status_msg: bool = False,
+        force_notify: bool = False,
     ) -> None:
-        """Set the presence state of the user."""
+        """Set the presence state of the user.
+
+        Args:
+            target_user: The ID of the user to set the presence state of.
+            state: The presence state as a JSON dictionary.
+            ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+                If False, the user's current status will be updated.
+            force_notify: Whether to force notification of the update to clients.
+        """
         status_msg = state.get("status_msg", None)
         presence = state["presence"]

@@ -1091,7 +1179,9 @@ class PresenceHandler(BasePresenceHandler):
         ):
             new_fields["last_active_ts"] = self.clock.time_msec()

-        await self._update_states([prev_state.copy_and_replace(**new_fields)])
+        await self._update_states(
+            [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify
+        )

     async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
         """Returns whether a user can see another user's presence."""
@@ -1389,11 +1479,10 @@ class PresenceEventSource:
         #
         #   Presence -> Notifier -> PresenceEventSource -> Presence
         #
-        # Same with get_module_api, get_presence_router
+        # Same with get_presence_router:
         #
         #   AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
         self.get_presence_handler = hs.get_presence_handler
-        self.get_module_api = hs.get_module_api
        self.get_presence_router = hs.get_presence_router
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -1424,16 +1513,21 @@ class PresenceEventSource:
             stream_change_cache = self.store.presence_stream_cache

             with Measure(self.clock, "presence.get_new_events"):
-                if user_id in self.get_module_api()._send_full_presence_to_local_users:
-                    # This user has been specified by a module to receive all current, online
-                    # user presence. Removing from_key and setting include_offline to false
-                    # will do effectively this.
-                    from_key = None
-                    include_offline = False
-
                 if from_key is not None:
                     from_key = int(from_key)

+                    # Check if this user should receive all current, online user presence. We only
+                    # bother to do this if from_key is set, as otherwise the user will receive all
+                    # user presence anyways.
+                    if await self.store.should_user_receive_full_presence_with_token(
+                        user_id, from_key
+                    ):
+                        # This user has been specified by a module to receive all current, online
+                        # user presence. Removing from_key and setting include_offline to false
+                        # will effectively do this.
+                        from_key = None
+                        include_offline = False
+
                 max_token = self.store.get_current_presence_token()
                 if from_key == max_token:
                     # This is necessary as due to the way stream ID generators work
@@ -1467,12 +1561,6 @@ class PresenceEventSource:
                     user_id, include_offline, from_key
                 )

-                # Remove the user from the list of users to receive all presence
-                if user_id in self.get_module_api()._send_full_presence_to_local_users:
-                    self.get_module_api()._send_full_presence_to_local_users.remove(
-                        user_id
-                    )
-
                 return presence_updates, max_token

             # Make mypy happy. users_interested_in should now be a set
@@ -1522,10 +1610,6 @@ class PresenceEventSource:
         )
         presence_updates = list(users_to_state.values())

-        # Remove the user from the list of users to receive all presence
-        if user_id in self.get_module_api()._send_full_presence_to_local_users:
-            self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
-
         if not include_offline:
             # Filter out offline presence states
             presence_updates = self._filter_offline_presence_state(presence_updates)
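A compressed model of the bookkeeping these presence changes add: a user is stamped with the stream position at which they were marked for a full snapshot, the forced update bumps the stream so the next sync token moves past the mark, and the mark is cleared once the snapshot is served. The class below is an illustration of that flow, not Synapse's implementation:

from typing import Dict, Iterable

class FullPresenceMarks:
    def __init__(self) -> None:
        self._marks: Dict[str, int] = {}  # user_id -> stream position when marked
        self.stream_id = 0

    def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
        for user_id in user_ids:
            self._marks[user_id] = self.stream_id
        # The force_notify=True update above plays this role: it guarantees the
        # stream advances, so a later sync token lands beyond the mark.
        self.stream_id += 1

    def should_receive_full_presence(self, user_id: str, from_token: int) -> bool:
        mark = self._marks.get(user_id)
        if mark is None or from_token > mark:
            return False
        del self._marks[user_id]  # serve the full snapshot exactly once
        return True
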
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 835d5862c1..61900d87df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -307,25 +307,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):

         if event.membership == Membership.JOIN:
             newly_joined = True
-            user_is_invited = False
+            prev_member_event = None
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
-                user_is_invited = prev_member_event.membership == Membership.INVITE

-            # If the member is not already in the room and is not accepting an invite,
-            # check if they should be allowed access via membership in a space.
-            if (
-                newly_joined
-                and not user_is_invited
-                and not await self.event_auth_handler.can_join_without_invite(
-                    prev_state_ids, event.room_version, user_id
-                )
-            ):
-                raise AuthError(
-                    403,
-                    "You do not belong to any of the required spaces to join this room.",
-                )
+            # Check if the member should be allowed access via membership in a space.
+            await self.event_auth_handler.check_restricted_join_rules(
+                prev_state_ids, event.room_version, user_id, prev_member_event
+            )

             # Only rate-limit if the user actually joined the room, otherwise we'll end
             # up blocking profile updates.
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
new file mode 100644
index 0000000000..e9f6aef06f
--- /dev/null
+++ b/synapse/handlers/send_email.py
@@ -0,0 +1,98 @@
+# Copyright 2021 The Matrix.org C.I.C. Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+import logging
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import TYPE_CHECKING
+
+from synapse.logging.context import make_deferred_yieldable
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SendEmailHandler:
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+
+        self._sendmail = hs.get_sendmail()
+        self._reactor = hs.get_reactor()
+
+        self._from = hs.config.email.email_notif_from
+        self._smtp_host = hs.config.email.email_smtp_host
+        self._smtp_port = hs.config.email.email_smtp_port
+        self._smtp_user = hs.config.email.email_smtp_user
+        self._smtp_pass = hs.config.email.email_smtp_pass
+        self._require_transport_security = hs.config.email.require_transport_security
+
+    async def send_email(
+        self,
+        email_address: str,
+        subject: str,
+        app_name: str,
+        html: str,
+        text: str,
+    ) -> None:
+        """Send a multipart email with the given information.
+
+        Args:
+            email_address: The address to send the email to.
+            subject: The email's subject.
+            app_name: The app name to include in the From header.
+            html: The HTML content to include in the email.
+            text: The plain text content to include in the email.
+        """
+        try:
+            from_string = self._from % {"app": app_name}
+        except (KeyError, TypeError):
+            from_string = self._from
+
+        raw_from = email.utils.parseaddr(from_string)[1]
+        raw_to = email.utils.parseaddr(email_address)[1]
+
+        if raw_to == "":
+            raise RuntimeError("Invalid 'to' address")
+
+        html_part = MIMEText(html, "html", "utf8")
+        text_part = MIMEText(text, "plain", "utf8")
+
+        multipart_msg = MIMEMultipart("alternative")
+        multipart_msg["Subject"] = subject
+        multipart_msg["From"] = from_string
+        multipart_msg["To"] = email_address
+        multipart_msg["Date"] = email.utils.formatdate()
+        multipart_msg["Message-ID"] = email.utils.make_msgid()
+        multipart_msg.attach(text_part)
+        multipart_msg.attach(html_part)
+
+        logger.info("Sending email to %s" % email_address)
+
+        await make_deferred_yieldable(
+            self._sendmail(
+                self._smtp_host,
+                raw_from,
+                raw_to,
+                multipart_msg.as_string().encode("utf8"),
+                reactor=self._reactor,
+                port=self._smtp_port,
+                requireAuthentication=self._smtp_user is not None,
+                username=self._smtp_user,
+                password=self._smtp_pass,
+                requireTransportSecurity=self._require_transport_security,
+            )
+        )
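With the new handler, callers no longer assemble MIME parts themselves; the account-validity changes above reduce to a call of this shape (a sketch: hs is the Synapse HomeServer, and get_send_email_handler() is the accessor added in synapse/server.py per the diffstat):

async def send_renewal_email(hs, address: str, html_text: str, plain_text: str) -> None:
    await hs.get_send_email_handler().send_email(
        email_address=address,
        subject="Renew your account",
        app_name="Matrix",
        html=html_text,
        text=plain_text,
    )
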
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index e35d91832b..abd9ddecca 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -16,11 +16,16 @@ import itertools import logging import re from collections import deque -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple, cast +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple import attr -from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + Membership, +) from synapse.api.errors import AuthError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 @@ -32,7 +37,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) # number of rooms to return. We'll stop once we hit this limit. -# TODO: allow clients to reduce this with a request param. MAX_ROOMS = 50 # max number of events to return per room. @@ -46,8 +50,7 @@ class SpaceSummaryHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._auth = hs.get_auth() - self._room_list_handler = hs.get_room_list_handler() - self._state_handler = hs.get_state_handler() + self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastore() self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname @@ -112,28 +115,88 @@ class SpaceSummaryHandler: max_children = max_rooms_per_space if processed_rooms else None if is_in_room: - rooms, events = await self._summarize_local_room( - requester, room_id, suggested_only, max_children + room, events = await self._summarize_local_room( + requester, None, room_id, suggested_only, max_children ) + + logger.debug( + "Query of local room %s returned events %s", + room_id, + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], + ) + + if room: + rooms_result.append(room) else: - rooms, events = await self._summarize_remote_room( + fed_rooms, fed_events = await self._summarize_remote_room( queue_entry, suggested_only, max_children, exclude_rooms=processed_rooms, ) - logger.debug( - "Query of %s returned rooms %s, events %s", - queue_entry.room_id, - [room.get("room_id") for room in rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - rooms_result.extend(rooms) - - # any rooms returned don't need visiting again - processed_rooms.update(cast(str, room.get("room_id")) for room in rooms) + # The results over federation might include rooms that the we, + # as the requesting server, are allowed to see, but the requesting + # user is not permitted see. + # + # Filter the returned results to only what is accessible to the user. + room_ids = set() + events = [] + for room in fed_rooms: + fed_room_id = room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue + + # The room should only be included in the summary if: + # a. the user is in the room; + # b. the room is world readable; or + # c. the user is in a space that has been granted access to + # the room. + # + # Note that we know the user is not in the root room (which is + # why the remote call was made in the first place), but the user + # could be in one of the children rooms and we just didn't know + # about the link. + include_room = room.get("world_readable") is True + + # Check if the user is a member of any of the allowed spaces + # from the response. 
+ allowed_spaces = room.get("allowed_spaces") + if ( + not include_room + and allowed_spaces + and isinstance(allowed_spaces, list) + ): + include_room = await self._event_auth_handler.is_user_in_rooms( + allowed_spaces, requester + ) + + # Finally, if this isn't the requested room, check ourselves + # if we can access the room. + if not include_room and fed_room_id != queue_entry.room_id: + include_room = await self._is_room_accessible( + fed_room_id, requester, None + ) + + # The user can see the room, include it! + if include_room: + rooms_result.append(room) + room_ids.add(fed_room_id) + + # All rooms returned don't need visiting again (even if the user + # didn't have access to them). + processed_rooms.add(fed_room_id) + + for event in fed_events: + if event.get("room_id") in room_ids: + events.append(event) + + logger.debug( + "Query of %s returned rooms %s, events %s", + room_id, + [room.get("room_id") for room in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events], + ) # the room we queried may or may not have been returned, but don't process # it again, anyway. @@ -159,10 +222,16 @@ class SpaceSummaryHandler: ) processed_events.add(ev_key) + # Before returning to the client, remove the allowed_spaces key for any + # rooms. + for room in rooms_result: + room.pop("allowed_spaces", None) + return {"rooms": rooms_result, "events": events_result} async def federation_space_summary( self, + origin: str, room_id: str, suggested_only: bool, max_rooms_per_space: Optional[int], @@ -172,6 +241,8 @@ class SpaceSummaryHandler: Implementation of the space summary Federation API Args: + origin: The server requesting the spaces summary. + room_id: room id to start the summary at suggested_only: whether we should only return children with the "suggested" @@ -206,14 +277,15 @@ class SpaceSummaryHandler: logger.debug("Processing room %s", room_id) - rooms, events = await self._summarize_local_room( - None, room_id, suggested_only, max_rooms_per_space + room, events = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_rooms_per_space ) processed_rooms.add(room_id) - rooms_result.extend(rooms) - events_result.extend(events) + if room: + rooms_result.append(room) + events_result.extend(events) # add any children to the queue room_queue.extend(edge_event["state_key"] for edge_event in events) @@ -223,19 +295,27 @@ class SpaceSummaryHandler: async def _summarize_local_room( self, requester: Optional[str], + origin: Optional[str], room_id: str, suggested_only: bool, max_children: Optional[int], - ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]: """ Generate a room entry and a list of event entries for a given room. Args: - requester: The requesting user, or None if this is over federation. + requester: + The user requesting the summary, if it is a local request. None + if this is a federation request. + origin: + The server requesting the summary, if it is a federation request. + None if this is a local request. room_id: The room ID to summarize. suggested_only: True if only suggested children should be returned. Otherwise, all children are returned. - max_children: The maximum number of children to return for this node. + max_children: + The maximum number of children rooms to include. This is capped + to a server-set limit. Returns: A tuple of: @@ -244,8 +324,8 @@ class SpaceSummaryHandler: An iterable of the sorted children events. 
This may be limited to a maximum size or may include all children. """ - if not await self._is_room_accessible(room_id, requester): - return (), () + if not await self._is_room_accessible(room_id, requester, origin): + return None, () room_entry = await self._build_room_entry(room_id) @@ -269,7 +349,7 @@ class SpaceSummaryHandler: event_format=format_event_for_client_v2, ) ) - return (room_entry,), events_result + return room_entry, events_result async def _summarize_remote_room( self, @@ -278,6 +358,26 @@ class SpaceSummaryHandler: max_children: Optional[int], exclude_rooms: Iterable[str], ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + """ + Request room entries and a list of event entries for a given room by querying a remote server. + + Args: + room: The room to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + max_children: + The maximum number of children rooms to include. This is capped + to a server-set limit. + exclude_rooms: + Room IDs which do not need to be summarized. + + Returns: + A tuple of: + An iterable of rooms. + + An iterable of the sorted children events. This may be limited + to a maximum size or may include all children. + """ room_id = room.room_id logger.info("Requesting summary for %s via %s", room_id, room.via) @@ -309,27 +409,93 @@ class SpaceSummaryHandler: or ev.event_type == EventTypes.SpaceChild ) - async def _is_room_accessible(self, room_id: str, requester: Optional[str]) -> bool: - # if we have an authenticated requesting user, first check if they are in the - # room + async def _is_room_accessible( + self, room_id: str, requester: Optional[str], origin: Optional[str] + ) -> bool: + """ + Calculate whether the room should be shown in the spaces summary. + + It should be included if: + + * The requester is joined or invited to the room. + * The requester can join without an invite (per MSC3083). + * The origin server has any user that is joined or invited to the room. + * The history visibility is set to world readable. + + Args: + room_id: The room ID to summarize. + requester: + The user requesting the summary, if it is a local request. None + if this is a federation request. + origin: + The server requesting the summary, if it is a federation request. + None if this is a local request. + + Returns: + True if the room should be included in the spaces summary. + """ + state_ids = await self._store.get_current_state_ids(room_id) + + # If there's no state for the room, it isn't known. + if not state_ids: + logger.info("room %s is unknown, omitting from summary", room_id) + return False + + room_version = await self._store.get_room_version(room_id) + + # if we have an authenticated requesting user, first check if they are able to view + # stripped state in the room. if requester: + member_event_id = state_ids.get((EventTypes.Member, requester), None) + + # If they're in the room they can see info on it. + member_event = None + if member_event_id: + member_event = await self._store.get_event(member_event_id) + if member_event.membership in (Membership.JOIN, Membership.INVITE): + return True + + # Otherwise, check if they should be allowed access via membership in a space. try: - await self._auth.check_user_in_room(room_id, requester) - return True + await self._event_auth_handler.check_restricted_join_rules( + state_ids, room_version, requester, member_event + ) except AuthError: + # The user doesn't have access due to spaces, but might have access + # some other way. Keep trying. 
pass + else: + return True + + # If this is a request over federation, check if the host is in the room or + # is in one of the spaces specified via the join rules. + elif origin: + if await self._auth.check_host_in_room(room_id, origin): + return True + + # Alternately, if the host has a user in any of the spaces specified + # for access, then the host can see this room (and should do filtering + # if the requester cannot see it). + if await self._event_auth_handler.has_restricted_join_rules( + state_ids, room_version + ): + allowed_spaces = ( + await self._event_auth_handler.get_spaces_that_allow_join(state_ids) + ) + for space_id in allowed_spaces: + if await self._auth.check_host_in_room(space_id, origin): + return True # otherwise, check if the room is peekable - hist_vis_ev = await self._state_handler.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) - if hist_vis_ev: + hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None) + if hist_vis_event_id: + hist_vis_ev = await self._store.get_event(hist_vis_event_id) hist_vis = hist_vis_ev.content.get("history_visibility") if hist_vis == HistoryVisibility.WORLD_READABLE: return True logger.info( - "room %s is unpeekable and user %s is not a member, omitting from summary", + "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary", room_id, requester, ) @@ -354,6 +520,15 @@ class SpaceSummaryHandler: if not room_type: room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE) + room_version = await self._store.get_room_version(room_id) + allowed_spaces = None + if await self._event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + allowed_spaces = await self._event_auth_handler.get_spaces_that_allow_join( + current_state_ids + ) + entry = { "room_id": stats["room_id"], "name": stats["name"], @@ -367,6 +542,7 @@ class SpaceSummaryHandler: "guest_can_join": stats["guest_access"] == "can_join", "creation_ts": create_event.origin_server_ts, "room_type": room_type, + "allowed_spaces": allowed_spaces, } # Filter out Nones – rather omit the field altogether @@ -430,8 +606,8 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool: return False -# Order may only contain characters in the range of \x20 (space) to \x7F (~). -_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7F]") +# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive. +_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: diff --git a/synapse/http/client.py b/synapse/http/client.py
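To make the access rules above concrete, here is a hedged, self-contained sketch of the same decision procedure for rooms returned over federation. The function name and its two callable parameters are illustrative stand-ins for the real checks on SpaceSummaryHandler and the event auth handler, not part of this change:

    async def should_include_federated_room(
        room: dict,
        requested_room_id: str,
        is_user_in_any_of,   # hypothetical: async (room_ids) -> bool
        is_room_accessible,  # hypothetical: async (room_id) -> bool
    ) -> bool:
        # World readable rooms are always visible.
        if room.get("world_readable") is True:
            return True

        # The user is in one of the spaces that grant access to the room.
        allowed_spaces = room.get("allowed_spaces")
        if isinstance(allowed_spaces, list) and allowed_spaces:
            if await is_user_in_any_of(allowed_spaces):
                return True

        # For any room other than the one originally queried, fall back to
        # our own accessibility check (joined, invited, etc.).
        room_id = room.get("room_id")
        if room_id and room_id != requested_room_id:
            return await is_room_accessible(room_id)

        return False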
index 5f40f16e24..1ca6624fd5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -813,7 +813,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): if self.deferred.called: return - self.stream.write(data) + try: + self.stream.write(data) + except Exception: + self.deferred.errback() + return + self.length += len(data) # The first time the maximum size is exceeded, error and cancel the # connection. dataReceived might be called again if data was received diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bb837b7b19..1998990a14 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import cgi import codecs import logging @@ -19,13 +20,24 @@ import sys import typing import urllib.parse from io import BytesIO, StringIO -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import ( + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) import attr import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json +from typing_extensions import Literal from twisted.internet import defer from twisted.internet.error import DNSLookupError @@ -48,6 +60,7 @@ from synapse.http.client import ( BlacklistingAgentWrapper, BlacklistingReactorWrapper, BodyExceededMaxSize, + ByteWriteable, encode_query_args, read_body_with_max_size, ) @@ -88,6 +101,27 @@ _next_id = 1 QueryArgs = Dict[str, Union[str, List[str]]] +T = TypeVar("T") + + +class ByteParser(ByteWriteable, Generic[T], abc.ABC): + """A `ByteWriteable` that has an additional `finish` function that returns + the parsed data. + """ + + CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore + """The expected content type of the response, e.g. `application/json`. If + the content type doesn't match we fail the request. + """ + + @abc.abstractmethod + def finish(self) -> T: + """Called when response has finished streaming and the parser should + return the final result (or error). + """ + pass + + @attr.s(slots=True, frozen=True) class MatrixFederationRequest: method = attr.ib(type=str) @@ -148,15 +182,33 @@ class MatrixFederationRequest: return self.json -async def _handle_json_response( +class JsonParser(ByteParser[Union[JsonDict, list]]): + """A parser that buffers the response and tries to parse it as JSON.""" + + CONTENT_TYPE = "application/json" + + def __init__(self): + self._buffer = StringIO() + self._binary_wrapper = BinaryIOWrapper(self._buffer) + + def write(self, data: bytes) -> int: + return self._binary_wrapper.write(data) + + def finish(self) -> Union[JsonDict, list]: + return json_decoder.decode(self._buffer.getvalue()) + + +async def _handle_response( reactor: IReactorTime, timeout_sec: float, request: MatrixFederationRequest, response: IResponse, start_ms: int, -) -> JsonDict: + parser: ByteParser[T], + max_response_size: Optional[int] = None, +) -> T: """ - Reads the JSON body of a response, with a timeout + Reads the body of a response with a timeout and sends it to a parser Args: reactor: twisted reactor, for the timeout @@ -164,23 +216,26 @@ async def _handle_json_response( request: the request that triggered the response response: response to the request start_ms: Timestamp when request was made + parser: The parser for the response + max_response_size: The maximum size to read from the response, if None + uses the default. 
Returns: - The parsed JSON response + The parsed response """ + + if max_response_size is None: + max_response_size = MAX_RESPONSE_SIZE + try: - check_content_type_is_json(response.headers) + check_content_type_is(response.headers, parser.CONTENT_TYPE) - buf = StringIO() - d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE) + d = read_body_with_max_size(response, parser, max_response_size) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - def parse(_len: int): - return json_decoder.decode(buf.getvalue()) + length = await make_deferred_yieldable(d) - d.addCallback(parse) - - body = await make_deferred_yieldable(d) + value = parser.finish() except BodyExceededMaxSize as e: # The response was too big. logger.warning( @@ -193,9 +248,9 @@ async def _handle_json_response( ) raise RequestSendFailed(e, can_retry=False) from e except ValueError as e: - # The JSON content was invalid. + # The content was invalid. logger.warning( - "{%s} [%s] Failed to parse JSON response - %s %s", + "{%s} [%s] Failed to parse response - %s %s", request.txn_id, request.destination, request.method, @@ -225,16 +280,17 @@ async def _handle_json_response( time_taken_secs = reactor.seconds() - start_ms / 1000 logger.info( - "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", + "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), time_taken_secs, + length, request.method, request.uri.decode("ascii"), ) - return body + return value class BinaryIOWrapper: @@ -671,6 +727,7 @@ class MatrixFederationHttpClient: ) return auth_headers + @overload async def put_json( self, destination: str, @@ -683,7 +740,44 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, + max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: + ... + + @overload + async def put_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + max_response_size: Optional[int] = None, + ) -> T: + ... + + async def put_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser] = None, + max_response_size: Optional[int] = None, + ): """Sends the specified json data using PUT Args: @@ -716,6 +810,10 @@ class MatrixFederationHttpClient: of the request. Workaround for #3622 in Synapse <= v0.99.3. This will be attempted before backing off if backing off has been enabled. + parser: The parser to use to decode the response. Defaults to + parsing as JSON. + max_response_size: The maximum size to read from the response, if None + uses the default. Returns: Succeeds when we get a 2xx HTTP response. 
The @@ -756,8 +854,17 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + if parser is None: + parser = JsonParser() + + body = await _handle_response( + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=parser, + max_response_size=max_response_size, ) return body @@ -830,12 +937,8 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, - _sec_timeout, - request, - response, - start_ms, + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -907,8 +1010,8 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -975,8 +1078,8 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -1068,16 +1171,16 @@ def _flatten_response_never_received(e): return repr(e) -def check_content_type_is_json(headers: Headers) -> None: +def check_content_type_is(headers: Headers, expected_content_type: str) -> None: """ Check that a set of HTTP headers have a Content-Type header, and that it - is application/json. + is the expected value. Args: headers: headers to check expected_content_type: the expected value of the Content-Type header Raises: - RequestSendFailed: if the Content-Type header is missing or isn't JSON + RequestSendFailed: if the Content-Type header is missing or doesn't match """ content_type_headers = headers.getRawHeaders(b"Content-Type") @@ -1089,11 +1192,10 @@ def check_content_type_is_json(headers: Headers) -> None: c_type = content_type_headers[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) - if val != "application/json": + if val != expected_content_type: raise RequestSendFailed( RuntimeError( - "Remote server sent Content-Type header of '%s', not 'application/json'" - % c_type, + f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'", ), can_retry=False, ) diff --git a/synapse/http/site.py b/synapse/http/site.py
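A hedged usage sketch of the new parser plumbing: a custom `ByteParser` lets a caller stream a non-JSON body through `put_json` instead of buffering it as JSON. The `LengthParser` class and the endpoint path are illustrative only; the base class and keyword arguments are the ones defined in the hunks above:

    class LengthParser(ByteParser[int]):
        # Hypothetical parser that discards the body and returns its length.
        CONTENT_TYPE = "application/octet-stream"

        def __init__(self):
            self._length = 0

        def write(self, data: bytes) -> int:
            self._length += len(data)
            return len(data)

        def finish(self) -> int:
            return self._length


    async def measure_response(client, destination: str) -> int:
        # `client` is assumed to be a MatrixFederationHttpClient; the
        # @overload declarations above let type checkers infer `int` here.
        return await client.put_json(
            destination,
            path="/_matrix/federation/v1/example",  # illustrative path
            data={},
            parser=LengthParser(),
            max_response_size=10 * 1024 * 1024,
        )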
index 671fd3fbcc..40754b7bea 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -105,8 +105,10 @@ class SynapseRequest(Request): assert self.content, "handleContentChunk() called before gotLength()" if self.content.tell() + len(data) > self._max_request_body_size: logger.warning( - "Aborting connection from %s because the request exceeds maximum size", + "Aborting connection from %s because the request exceeds maximum size: %s %s", self.client, + self.get_method(), + self.get_redacted_uri(), ) self.transport.abortConnection() return diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 12c5ea0815..e562ff693e 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -56,14 +56,6 @@ class ModuleApi: self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._public_room_list_manager = PublicRoomListManager(hs) - # The next time these users sync, they will receive the current presence - # state of all local users. Users are added by send_local_online_presence_to, - # and removed after a successful sync. - # - # We make this a private variable to deter modules from accessing it directly, - # though other classes in Synapse will still do so. - self._send_full_presence_to_local_users = set() - @property def http_client(self): """Allows making outbound HTTP requests to remote resources. @@ -405,39 +397,44 @@ class ModuleApi: Updates to remote users will be sent immediately, whereas local users will receive them on their next sync attempt. - Note that this method can only be run on the main or federation_sender worker - processes. + Note that this method can only be run on the process that is configured to write to the + presence stream. By default this is the main process. """ - if not self._hs.should_send_federation(): + if self._hs._instance_name not in self._hs.config.worker.writers.presence: raise Exception( "send_local_online_presence_to can only be run " - "on processes that send federation", + "on the process that is configured to write to the " + "presence stream (by default this is the main process)", ) + local_users = set() + remote_users = set() for user in users: if self._hs.is_mine_id(user): - # Modify SyncHandler._generate_sync_entry_for_presence to call - # presence_source.get_new_events with an empty `from_key` if - # that user's ID were in a list modified by ModuleApi somewhere. - # That user would then get all presence state on next incremental sync. - - # Force a presence initial_sync for this user next time - self._send_full_presence_to_local_users.add(user) + local_users.add(user) else: - # Retrieve presence state for currently online users that this user - # is considered interested in - presence_events, _ = await self._presence_stream.get_new_events( - UserID.from_string(user), from_key=None, include_offline=False - ) - - # Send to remote destinations. - - # We pull out the presence handler here to break a cyclic - # dependency between the presence router and module API. - presence_handler = self._hs.get_presence_handler() - await presence_handler.maybe_send_presence_to_interested_destinations( - presence_events - ) + remote_users.add(user) + + # We pull out the presence handler here to break a cyclic + # dependency between the presence router and module API. + presence_handler = self._hs.get_presence_handler() + + if local_users: + # Force a presence initial_sync for these users next time they sync. + await presence_handler.send_full_presence_to_users(local_users) + + for user in remote_users: + # Retrieve presence state for currently online users that this user + # is considered interested in. + presence_events, _ = await self._presence_stream.get_new_events( + UserID.from_string(user), from_key=None, include_offline=False + ) + + # Send to remote destinations. + destination = UserID.from_string(user).domain + presence_handler.get_federation_queue().send_presence_to_destinations( + presence_events, destination + ) class PublicRoomListManager: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
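For illustration, a hedged sketch of how a third-party module might drive the reworked API. The module class and user IDs are invented; per the check above, it must run on the process configured to write to the presence stream:

    class ExamplePresenceModule:
        # Illustrative module; only send_local_online_presence_to is real API.
        def __init__(self, config, api):
            self._api = api  # a synapse.module_api.ModuleApi

        async def refresh_presence(self):
            # Local users are flagged to receive a full presence snapshot on
            # their next sync; remote users have presence queued out over
            # federation immediately.
            await self._api.send_local_online_presence_to(
                ["@alice:example.org", "@bob:remote.example.org"]
            )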
index c4b43b0d3f..5f9ea5003a 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import email.mime.multipart -import email.utils import logging import urllib.parse -from email.mime.multipart import MIMEMultipart -from email.mime.text import MIMEText from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar import bleach @@ -27,7 +23,6 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import StoreError from synapse.config.emailconfig import EmailSubjectConfig from synapse.events import EventBase -from synapse.logging.context import make_deferred_yieldable from synapse.push.presentable_names import ( calculate_room_name, descriptor_from_member_events, @@ -108,7 +103,7 @@ class Mailer: self.template_html = template_html self.template_text = template_text - self.sendmail = self.hs.get_sendmail() + self.send_email_handler = hs.get_send_email_handler() self.store = self.hs.get_datastore() self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() @@ -310,17 +305,6 @@ class Mailer: self, email_address: str, subject: str, extra_template_vars: Dict[str, Any] ) -> None: """Send an email with the given information and template text""" - try: - from_string = self.hs.config.email_notif_from % {"app": self.app_name} - except TypeError: - from_string = self.hs.config.email_notif_from - - raw_from = email.utils.parseaddr(from_string)[1] - raw_to = email.utils.parseaddr(email_address)[1] - - if raw_to == "": - raise RuntimeError("Invalid 'to' address") - template_vars = { "app_name": self.app_name, "server_name": self.hs.config.server.server_name, @@ -329,35 +313,14 @@ class Mailer: template_vars.update(extra_template_vars) html_text = self.template_html.render(**template_vars) - html_part = MIMEText(html_text, "html", "utf8") - plain_text = self.template_text.render(**template_vars) - text_part = MIMEText(plain_text, "plain", "utf8") - - multipart_msg = MIMEMultipart("alternative") - multipart_msg["Subject"] = subject - multipart_msg["From"] = from_string - multipart_msg["To"] = email_address - multipart_msg["Date"] = email.utils.formatdate() - multipart_msg["Message-ID"] = email.utils.make_msgid() - multipart_msg.attach(text_part) - multipart_msg.attach(html_part) - - logger.info("Sending email to %s" % email_address) - - await make_deferred_yieldable( - self.sendmail( - self.hs.config.email_smtp_host, - raw_from, - raw_to, - multipart_msg.as_string().encode("utf8"), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security, - ) + + await self.send_email_handler.send_email( + email_address=email_address, + subject=subject, + app_name=self.app_name, + html=html_text, + text=plain_text, ) async def _get_room_vars( diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
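The MIME assembly and SMTP settings now live behind the shared send-email handler; a minimal caller sketch, assuming only the keyword signature visible in the hunk above (subject and body values are illustrative):

    async def send_test_email(hs, address: str) -> None:
        # `hs` is a HomeServer; get_send_email_handler() is registered in
        # server.py later in this commit.
        handler = hs.get_send_email_handler()
        await handler.send_email(
            email_address=address,
            subject="[Matrix] Test notification",
            app_name="Matrix",
            html="<p>Hello from Synapse</p>",
            text="Hello from Synapse",
        )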
index 989523c823..546231bec0 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
@@ -87,6 +87,7 @@ REQUIREMENTS = [ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", + "ijson>=3.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index f25307620d..bb00247953 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py
@@ -73,6 +73,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint): { "state": { ... }, "ignore_status_msg": false, + "force_notify": false } 200 OK @@ -91,17 +92,23 @@ class ReplicationPresenceSetState(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id, state, ignore_status_msg=False): + async def _serialize_payload( + user_id, state, ignore_status_msg=False, force_notify=False + ): return { "state": state, "ignore_status_msg": ignore_status_msg, + "force_notify": force_notify, } async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) await self._presence_handler.set_state( - UserID.from_string(user_id), content["state"], content["ignore_status_msg"] + UserID.from_string(user_id), + content["state"], + content["ignore_status_msg"], + content["force_notify"], ) return ( diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
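A hedged sketch of the handler call this endpoint forwards to. The keyword names mirror the positional call in `_handle_request` above; the state dict and helper are illustrative:

    from synapse.types import UserID

    async def force_presence_broadcast(presence_handler, user_id: str) -> None:
        # force_notify=True notifies listeners and replication even if the
        # new state is identical to the user's current presence state.
        await presence_handler.set_state(
            UserID.from_string(user_id),
            {"presence": "online"},  # illustrative state dict
            ignore_status_msg=False,
            force_notify=True,
        )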
index 8730966380..13ed87adc4 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py
@@ -24,7 +24,7 @@ class SlavedClientIpStore(BaseSlavedStore): super().__init__(database, db_conn, hs) self.client_ip_last_seen = LruCache( - cache_name="client_ip_last_seen", keylen=4, max_size=50000 + cache_name="client_ip_last_seen", max_size=50000 ) # type: LruCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py deleted file mode 100644
index a59e543924..0000000000 --- a/synapse/replication/slave/storage/transactions.py +++ /dev/null
@@ -1,21 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.databases.main.transactions import TransactionStore - -from ._base import BaseSlavedStore - - -class SlavedTransactionStore(TransactionStore, BaseSlavedStore): - pass diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index cc3ab5854b..b5e4c474ef 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py
@@ -54,7 +54,6 @@ class SendServerNoticeServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) - self.snm = hs.get_server_notices_manager() def register(self, json_resource: HttpServer): PATTERN = "/send_server_notice" @@ -77,7 +76,10 @@ class SendServerNoticeServlet(RestServlet): event_type = body.get("type", EventTypes.Message) state_key = body.get("state_key") - if not self.snm.is_enabled(): + # We grab the server notices manager here as its initialisation has a check for worker processes, + # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the + # admin api). + if not self.hs.get_server_notices_manager().is_enabled(): raise SynapseError(400, "Server notices are not enabled on this server") user_id = body["user_id"] @@ -85,7 +87,7 @@ class SendServerNoticeServlet(RestServlet): if not self.hs.is_mine_id(user_id): raise SynapseError(400, "Server notices can only be sent to local users") - event = await self.snm.send_notice( + event = await self.hs.get_server_notices_manager().send_notice( user_id=body["user_id"], type=event_type, state_key=state_key, diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index e8dbe240d8..a5fcd15e3a 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py
@@ -48,11 +48,6 @@ class LocalKey(Resource): "key": # base64 encoded NACL verification key. } }, - "tls_fingerprints": [ # Fingerprints of the TLS certs this server uses. - { - "sha256": # base64 encoded sha256 fingerprint of the X509 cert - }, - ], "signatures": { "this.server.example.com": { "algorithm:version": # NACL signature for this server @@ -89,14 +84,11 @@ class LocalKey(Resource): "expired_ts": key.expired_ts, } - tls_fingerprints = self.config.tls_fingerprints - json_object = { "valid_until_ts": self.valid_until_ts, "server_name": self.config.server_name, "verify_keys": verify_keys, "old_verify_keys": old_verify_keys, - "tls_fingerprints": tls_fingerprints, } for key in self.config.signing_key: json_object = sign_json(json_object, self.config.server_name, key) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f648678b09..aba1734a55 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -73,9 +73,6 @@ class RemoteKey(DirectServeJsonResource): "expired_ts": 0, # when the key stop being used. } } - "tls_fingerprints": [ - { "sha256": # fingerprint } - ] "signatures": { "remote.server.example.com": {...} "this.server.example.com": {...} diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index e8a875b900..21c43c340c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -76,6 +76,8 @@ class MediaRepository: self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels + Thumbnailer.set_limits(self.max_image_pixels) + self.primary_base_path = hs.config.media_store_path # type: str self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 37fe582390..a65e9e1802 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py
@@ -40,6 +40,10 @@ class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} + @staticmethod + def set_limits(max_image_pixels: int): + Image.MAX_IMAGE_PIXELS = max_image_pixels + def __init__(self, input_path: str): try: self.image = Image.open(input_path) @@ -47,6 +51,11 @@ class Thumbnailer: # If an error occurs opening the image, a thumbnail won't be able to # be generated. raise ThumbnailError from e + except Image.DecompressionBombError as e: + # If an image decompression bomb error occurs opening the image, + # then the image exceeds the pixel limit and a thumbnail won't + # be able to be generated. + raise ThumbnailError from e self.width, self.height = self.image.size self.transpose_method = None diff --git a/synapse/server.py b/synapse/server.py
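Taken together, the two hunks above wire Pillow's decompression-bomb guard to the homeserver config; a hedged sketch of the resulting behaviour (the pixel limit and helper function are illustrative):

    from synapse.rest.media.v1.thumbnailer import ThumbnailError, Thumbnailer

    def try_open_thumbnailer(path: str):
        # MediaRepository calls set_limits() once at startup with
        # hs.config.max_image_pixels; the value here is made up.
        Thumbnailer.set_limits(32 * 1024 * 1024)
        try:
            return Thumbnailer(path)
        except ThumbnailError:
            # Raised for unreadable images and, with this change, for images
            # exceeding the configured pixel count.
            return None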
index 2337d2d9b4..fec0024c89 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -104,6 +104,7 @@ from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.search import SearchHandler +from synapse.handlers.send_email import SendEmailHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.space_summary import SpaceSummaryHandler from synapse.handlers.sso import SsoHandler @@ -550,6 +551,10 @@ class HomeServer(metaclass=abc.ABCMeta): return SearchHandler(self) @cache_in_self + def get_send_email_handler(self) -> SendEmailHandler: + return SendEmailHandler(self) + + @cache_in_self def get_set_password_handler(self) -> SetPasswordHandler: return SetPasswordHandler(self) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3d98d3f5f8..0623da9aa1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import random from abc import ABCMeta from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union @@ -44,7 +43,6 @@ class SQLBaseStore(metaclass=ABCMeta): self._clock = hs.get_clock() self.database_engine = database.engine self.db_pool = database - self.rand = random.SystemRandom() def process_replication_rows( self, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 49c7606d51..9cce62ae6c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -67,7 +67,7 @@ from .state import StateStore from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore -from .transactions import TransactionStore +from .transactions import TransactionWorkerStore from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore @@ -83,7 +83,7 @@ class DataStore( StreamStore, ProfileStore, PresenceStore, - TransactionStore, + TransactionWorkerStore, DirectoryStore, KeyStore, StateStore, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index d60010e942..074b077bef 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = LruCache( - cache_name="client_ip_last_seen", keylen=4, max_size=50000 + cache_name="client_ip_last_seen", max_size=50000 ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c9346de316..fd87ba71ab 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore): cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", ) - async def get_device_list_last_stream_id_for_remotes(self, user_ids: str): + async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]): rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. self.device_id_exists_cache = LruCache( - cache_name="device_id_exists", keylen=2, max_size=10000 + cache_name="device_id_exists", max_size=10000 ) async def store_device( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 398d6b6acb..9ba5778a88 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): num_args=1, ) async def _get_bare_e2e_cross_signing_keys_bulk( - self, user_ids: List[str] + self, user_ids: Iterable[str] ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if @@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): def _get_bare_e2e_cross_signing_keys_bulk_txn( self, txn: Connection, - user_ids: List[str], + user_ids: Iterable[str], ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 2c823e09cf..6963bbf7f4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache = LruCache( cache_name="*getEvent*", - keylen=3, max_size=hs.config.caches.event_cache_size, ) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 0e86807834..6990f3ed1d 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore): """ keys = {} - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index db22fab23e..6a2baa7841 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore): db_conn, "presence_stream", "stream_id" ) + self.hs = hs self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( @@ -96,6 +97,15 @@ class PresenceStore(SQLBaseStore): ) txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) + # Delete old rows to stop database from getting really big + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " + + for states in batch_iter(presence_states, 50): + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) + # Actually insert new rows self.db_pool.simple_insert_many_txn( txn, @@ -116,15 +126,6 @@ class PresenceStore(SQLBaseStore): ], ) - # Delete old rows to stop database from getting really big - sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " - - for states in batch_iter(presence_states, 50): - clause, args = make_in_list_sql_clause( - self.database_engine, "user_id", [s.user_id for s in states] - ) - txn.execute(sql + clause, [stream_id] + list(args)) - async def get_all_presence_updates( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, list]], int, bool]: @@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore): return {row["user_id"]: UserPresenceState(**row) for row in rows} + async def should_user_receive_full_presence_with_token( + self, + user_id: str, + from_token: int, + ) -> bool: + """Check whether the given user should receive full presence using the stream token + they're updating from. + + Args: + user_id: The ID of the user to check. + from_token: The stream token included in their /sync token. + + Returns: + True if the user should have full presence sent to them, False otherwise. + """ + + def _should_user_receive_full_presence_with_token_txn(txn): + sql = """ + SELECT 1 FROM users_to_send_full_presence_to + WHERE user_id = ? + AND presence_stream_id >= ? + """ + txn.execute(sql, (user_id, from_token)) + return bool(txn.fetchone()) + + return await self.db_pool.runInteraction( + "should_user_receive_full_presence_with_token", + _should_user_receive_full_presence_with_token_txn, + ) + + async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): + """Adds to the list of users who should receive a full snapshot of presence + upon their next sync. + + Args: + user_ids: An iterable of user IDs. + """ + # Add user entries to the table, updating the presence_stream_id column if the user already + # exists in the table. 
+        await self.db_pool.simple_upsert_many( + table="users_to_send_full_presence_to", + key_names=("user_id",), + key_values=[(user_id,) for user_id in user_ids], + value_names=("presence_stream_id",), + # We save the current presence stream ID token along with the user ID entry so + # that when a user /sync's, even if they sync multiple times across separate + # devices at different times, each device will receive full presence once - when + # the presence stream ID in their sync token is less than or equal to the one in + # the table for their user ID. + value_values=( + (self._presence_id_gen.get_current_token(),) for _ in user_ids + ), + desc="add_users_to_send_full_presence_to", + ) + async def get_presence_for_all_users( self, include_offline: bool = True, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
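The two new store methods are intended to be used as a pair; a hedged sketch of the flow, with the surrounding sync plumbing elided (`store` is a PresenceStore, the helper itself is illustrative):

    async def demo_full_presence_flow(store, user_id: str, from_token: int) -> bool:
        # Flag the user: records the current presence stream token against
        # them (an upsert, so repeated calls just bump the token).
        await store.add_users_to_send_full_presence_to([user_id])

        # Later, on /sync: if the stream position in the client's sync token
        # is at or below the recorded one, send a full presence snapshot.
        return await store.should_user_receive_full_presence_with_token(
            user_id, from_token
        )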
index d36b18a0e9..77e2eb27db 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random import re from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -1077,7 +1078,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts = now_ms + self._account_validity_period if use_delta: - expiration_ts = self.rand.randrange( + expiration_ts = random.randrange( expiration_ts - self._account_validity_startup_job_max_delta, expiration_ts, ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 82335e7a9d..d211c423b2 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -16,13 +16,15 @@ import logging from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict -from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.descriptors import cached db_binary_type = memoryview @@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple( "_TransactionRow", ("response_code", "response_json") ) -SENTINEL = object() +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DestinationRetryTimings: + """The current destination retry timing info for a remote server.""" -class TransactionWorkerStore(SQLBaseStore): + # The first time we tried and failed to reach the remote server, in ms. + failure_ts: int + + # The last time we tried and failed to reach the remote server, in ms. + retry_last_ts: int + + # How long to wait since the last failed attempt before trying to reach + # the remote server again, in ms. + retry_interval: int + + +class TransactionWorkerStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore): "_cleanup_transactions", _cleanup_transactions_txn ) - -class TransactionStore(TransactionWorkerStore): - """A collection of queries for handling PDUs.""" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._destination_retry_cache = ExpiringCache( - cache_name="get_destination_retry_timings", - clock=self._clock, - expiry_ms=5 * 60 * 1000, - ) - async def get_received_txn_response( self, transaction_id: str, origin: str ) -> Optional[Tuple[int, JsonDict]]: @@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore): desc="set_received_txn_response", ) - async def get_destination_retry_timings(self, destination): + @cached(max_entries=10000) + async def get_destination_retry_timings( + self, + destination: str, + ) -> Optional[DestinationRetryTimings]: """Gets the current retry timings (if any) for a given destination. Args: @@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore): Otherwise the current DestinationRetryTimings for the retry scheme """ - result = self._destination_retry_cache.get(destination, SENTINEL) - if result is not SENTINEL: - return result - result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, ) - # We don't hugely care about race conditions between getting and - # invalidating the cache, since we time out fairly quickly anyway. 
- self._destination_retry_cache[destination] = result return result - def _get_destination_retry_timings(self, txn, destination): + def _get_destination_retry_timings( + self, txn, destination: str + ) -> Optional[DestinationRetryTimings]: result = self.db_pool.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, - retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), + retcols=("failure_ts", "retry_last_ts", "retry_interval"), allow_none=True, ) # check we have a row and retry_last_ts is not null or zero # (retry_last_ts can't be negative) if result and result["retry_last_ts"]: - return result + return DestinationRetryTimings(**result) else: return None @@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore): retry_interval: how long until next retry in ms """ - self._destination_retry_cache.pop(destination, None) if self.database_engine.can_native_upsert: return await self.db_pool.runInteraction( "set_destination_retry_timings", @@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore): txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + def _set_destination_retry_timings_emulated( self, txn, destination, failure_ts, retry_last_ts, retry_interval ): @@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore): }, ) + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + async def store_destination_rooms_entries( self, destinations: Iterable[str], diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
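Call sites now read attributes off the frozen attrs class instead of dict keys; a hedged sketch of a consumer-side backoff check (the helper and its use of wall-clock time are illustrative, not part of this change):

    import time
    from typing import Optional

    from synapse.storage.databases.main.transactions import DestinationRetryTimings

    def destination_is_backing_off(timings: Optional[DestinationRetryTimings]) -> bool:
        # No row means the destination is not currently backing off.
        if timings is None:
            return False
        now_ms = int(time.time() * 1000)
        # Skip the destination until the retry interval since the last
        # failed attempt has elapsed.
        return now_ms < timings.retry_last_ts + timings.retry_interval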
index acf6b2fb64..1ecdd40c38 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable + from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList @@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore): return bool(result) @cachedList(cached_method_name="is_user_erased", list_name="user_ids") - async def are_users_erased(self, user_ids): + async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]: """ Checks which users in a list have requested erasure Args: - user_ids (iterable[str]): full user id to check + user_ids: full user ids to check Returns: - dict[str, bool]: - for each user, whether the user has requested erasure. + for each user, whether the user has requested erasure. """ - # this serves the dual purpose of (a) making sure we can do len and - # iterate it multiple times, and (b) avoiding duplicates. - user_ids = tuple(set(user_ids)) - rows = await self.db_pool.simple_select_many_batch( table="erased_users", column="user_id", diff --git a/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql new file mode 100644
index 0000000000..07b0f53ecf --- /dev/null +++ b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql
@@ -0,0 +1,34 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a table that keeps track of a list of users who should, upon their next +-- sync request, receive presence for all currently online users that they are +-- "interested" in. + +-- The motivation for a DB table over an in-memory list is so that this list +-- can be added to and retrieved from by any worker. Specifically, we don't +-- want to duplicate work across multiple sync workers. + +CREATE TABLE IF NOT EXISTS users_to_send_full_presence_to( + -- The user ID to send full presence to. + user_id TEXT PRIMARY KEY, + -- A presence stream ID token - the current presence stream token when the row was last upserted. + -- If a user calls /sync and this token is part of the update they're to receive, we also include + -- full user presence in the response. + -- This allows multiple devices for a user to receive full presence whenever they next call /sync. + presence_stream_id BIGINT, + FOREIGN KEY (user_id) + REFERENCES users (name) +); \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cfafba22c5..c9dce726cb 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py
@@ -540,7 +540,7 @@ class StateGroupStorage: state_filter: The state filter used to fetch state from the database. Returns: - A dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event_id """ state_map = await self.get_state_ids_for_events( [event_id], state_filter or StateFilter.all() diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py new file mode 100644
index 0000000000..44bbb7b1a8 --- /dev/null +++ b/synapse/util/batching_queue.py
@@ -0,0 +1,153 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import ( + Awaitable, + Callable, + Dict, + Generic, + Hashable, + List, + Set, + Tuple, + TypeVar, +) + +from twisted.internet import defer + +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock + +logger = logging.getLogger(__name__) + + +V = TypeVar("V") +R = TypeVar("R") + + +class BatchingQueue(Generic[V, R]): + """A queue that batches up work, calling the provided processing function + with all pending work (for a given key). + + The provided processing function will only be called once at a time for each + key. It will be called the next reactor tick after `add_to_queue` has been + called, and will keep being called until the queue has been drained (for the + given key). + + Note that the return value of `add_to_queue` will be the return value of the + processing function that processed the given item. This means that the + returned value will likely include data for other items that were in the + batch. + """ + + def __init__( + self, + name: str, + clock: Clock, + process_batch_callback: Callable[[List[V]], Awaitable[R]], + ): + self._name = name + self._clock = clock + + # The set of keys currently being processed. + self._processing_keys = set()  # type: Set[Hashable] + + # The currently pending batch of values by key, with a Deferred to call + # with the result of the corresponding `_process_batch_callback` call. + self._next_values = {}  # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]] + + # The function to call with batches of values. + self._process_batch_callback = process_batch_callback + + LaterGauge( + "synapse_util_batching_queue_number_queued", + "The number of items waiting in the queue across all keys", + labels=("name",), + caller=lambda: sum(len(v) for v in self._next_values.values()), + ) + + LaterGauge( + "synapse_util_batching_queue_number_of_keys", + "The number of distinct keys that have items queued", + labels=("name",), + caller=lambda: len(self._next_values), + ) + + async def add_to_queue(self, value: V, key: Hashable = ()) -> R: + """Adds the value to the queue with the given key, returning the result + of the processing function for the batch that included the given value. + + The optional `key` argument allows sharding the queue by some key. The + queues will then be processed in parallel, i.e. the process batch + function may be called in parallel across keys, with each call receiving + the batched values for a single key. + """ + + # First we create a deferred and add it, together with the value, to + # the list of pending items. + d = defer.Deferred() + self._next_values.setdefault(key, []).append((value, d)) + + # If we're not currently processing the key, fire off a background + # process to start processing. 
+ if key not in self._processing_keys: + run_as_background_process(self._name, self._process_queue, key) + + return await make_deferred_yieldable(d) + + async def _process_queue(self, key: Hashable) -> None: + """A background task to repeatedly pull things off the queue for the + given key and call the `self._process_batch_callback` with the values. + """ + + try: + if key in self._processing_keys: + return + + self._processing_keys.add(key) + + while True: + # We purposefully wait a reactor tick to allow us to batch + # together requests that we're about to receive. A common + # pattern is to call `add_to_queue` multiple times at once, and + # deferring to the next reactor tick allows us to batch all of + # those up. + await self._clock.sleep(0) + + next_values = self._next_values.pop(key, []) + if not next_values: + # We've exhausted the queue. + break + + try: + values = [value for value, _ in next_values] + results = await self._process_batch_callback(values) + + for _, deferred in next_values: + with PreserveLoggingContext(): + deferred.callback(results) + + except Exception as e: + for _, deferred in next_values: + if deferred.called: + continue + + with PreserveLoggingContext(): + deferred.errback(e) + + finally: + self._processing_keys.discard(key) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
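A hedged usage sketch of the new queue (the queue name and callback are invented): concurrent `add_to_queue` callers with the same key are coalesced into a single callback invocation, and each receives that invocation's result.

    from synapse.util.batching_queue import BatchingQueue

    def make_example_queue(hs):
        # Everything queued for one key since the previous call arrives in a
        # single list; the return value is handed back to every waiting
        # add_to_queue() caller in that batch.
        async def _process_batch(values):
            return {"processed": len(values)}

        return BatchingQueue("example_queue", hs.get_clock(), _process_batch)

    # Usage (sketch): result = await queue.add_to_queue(some_value, key=room_id)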
index 484097a48a..371e7e4dd0 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py
@@ -70,7 +70,6 @@ class DeferredCache(Generic[KT, VT]): self, name: str, max_entries: int = 1000, - keylen: int = 1, tree: bool = False, iterable: bool = False, apply_cache_factor_from_config: bool = True, @@ -101,7 +100,6 @@ class DeferredCache(Generic[KT, VT]): # a Deferred. self.cache = LruCache( max_size=max_entries, - keylen=keylen, cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d) or 1) if iterable else None, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index ac4a078b26..2ac24a2f25 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -270,7 +270,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, - keylen=self.num_args, tree=self.tree, iterable=self.iterable, ) # type: DeferredCache[CacheKey, Any] @@ -322,8 +321,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. - Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped function. + Given an iterable of keys, it looks in the cache to find any hits, then passes + the tuple of missing keys to the wrapped function. Once wrapped, the function returns a Deferred which resolves to the list of results. @@ -437,7 +436,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): return f args_to_call = dict(arg_dict) - args_to_call[self.list_name] = list(missing) + # copy the missing set before sending it to the callee, to guard against + # modification. + args_to_call[self.list_name] = tuple(missing) cached_defers.append( defer.maybeDeferred( @@ -522,14 +523,14 @@ def cachedList( Used to do batch lookups for an already created cache. A single argument is specified as a list that is iterated through to look up keys in the - original cache. A new list consisting of the keys that weren't in the cache - get passed to the original function, the result of which is stored in the + original cache. A new tuple consisting of the (deduplicated) keys that weren't in + the cache gets passed to the original function, the result of which is stored in the cache. Args: cached_method_name: The name of the single-item lookup method. This is only used to find the cache to use. - list_name: The name of the argument that is the list to use to + list_name: The name of the argument that is the iterable to use to do batch lookups in the cache. num_args: Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
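For reference, a hedged sketch of the @cached / @cachedList pairing that the updated docstrings describe (the store and its widget methods are invented):

    from synapse.util.caches.descriptors import cached, cachedList

    class ExampleStore:
        @cached()
        async def get_widget(self, widget_id: str) -> dict:
            # Single-item lookup; backs the per-key cache.
            return {"id": widget_id}

        @cachedList(cached_method_name="get_widget", list_name="widget_ids")
        async def get_widgets(self, widget_ids):
            # Only invoked with the deduplicated tuple of ids that missed
            # the get_widget cache; must return a dict of id -> result.
            return {widget_id: {"id": widget_id} for widget_id in widget_ids}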
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 1be675e014..54df407ff7 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -34,7 +34,7 @@ from typing_extensions import Literal
 from synapse.config import cache as cache_config
 from synapse.util import caches
 from synapse.util.caches import CacheMetric, register_cache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 
 try:
     from pympler.asizeof import Asizer
@@ -160,7 +160,6 @@
         self,
         max_size: int,
         cache_name: Optional[str] = None,
-        keylen: int = 1,
         cache_type: Type[Union[dict, TreeCache]] = dict,
         size_callback: Optional[Callable] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
@@ -173,9 +172,6 @@
             cache_name: The name of this cache, for the prometheus metrics. If unset,
                 no metrics will be reported on this cache.
 
-            keylen: The length of the tuple used as the cache key. Ignored unless
-                cache_type is `TreeCache`.
-
             cache_type (type):
                 type of underlying cache to be used. Typically one of dict
                 or TreeCache.
@@ -403,7 +399,9 @@
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
+            # for each deleted node, we now need to remove it from the linked list
+            # and run its callbacks.
+            for leaf in iterate_tree_cache_entry(popped):
                 delete_node(leaf)
 
         @synchronized
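`del_multi` above no longer needs `keylen` to know how deep the tree is: it simply pops the subtree and walks whatever `TreeCacheNode` comes back, unlinking each leaf and running its callbacks. A sketch of the caller-visible effect, assuming the behaviour described in this diff:

    from synapse.util.caches.lrucache import LruCache
    from synapse.util.caches.treecache import TreeCache

    cache = LruCache(max_size=10, cache_type=TreeCache)
    cache[("alice", "device1")] = 1
    cache[("alice", "device2")] = 2
    cache[("bob", "device1")] = 3

    # dropping the ("alice",) prefix evicts both of alice's entries
    cache.del_multi(("alice",))
    assert len(cache) == 1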
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index eb4d98f683..73502a8b06 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,18 +1,43 @@
-from typing import Dict
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 SENTINEL = object()
 
 
+class TreeCacheNode(dict):
+    """The type of nodes in our tree.
+
+    Has its own type so we can distinguish it from real dicts that are stored at the
+    leaves.
+    """
+
+    pass
+
+
 class TreeCache:
     """
     Tree-based backing store for LruCache. Allows subtrees of data to be deleted
     efficiently.
     Keys must be tuples.
+
+    The data structure is a chain of TreeCacheNodes:
+        root = {key_1: {key_2: _value}}
     """
 
     def __init__(self):
         self.size = 0
-        self.root = {}  # type: Dict
+        self.root = TreeCacheNode()
 
     def __setitem__(self, key, value):
         return self.set(key, value)
@@ -21,10 +46,23 @@ class TreeCache:
         return self.get(key, SENTINEL) is not SENTINEL
 
     def set(self, key, value):
+        if isinstance(value, TreeCacheNode):
+            # this would mean we couldn't tell where our tree ended and the value
+            # started.
+            raise ValueError("Cannot store TreeCacheNodes in a TreeCache")
+
         node = self.root
         for k in key[:-1]:
-            node = node.setdefault(k, {})
-        node[key[-1]] = _Entry(value)
+            next_node = node.get(k, SENTINEL)
+            if next_node is SENTINEL:
+                next_node = node[k] = TreeCacheNode()
+            elif not isinstance(next_node, TreeCacheNode):
+                # this suggests that the caller is not being consistent with its key
+                # length.
+                raise ValueError("value conflicts with an existing subtree")
+            node = next_node
+
+        node[key[-1]] = value
         self.size += 1
 
     def get(self, key, default=None):
@@ -33,25 +71,41 @@
             node = node.get(k, None)
             if node is None:
                 return default
-        return node.get(key[-1], _Entry(default)).value
+        return node.get(key[-1], default)
 
     def clear(self):
         self.size = 0
-        self.root = {}
+        self.root = TreeCacheNode()
 
     def pop(self, key, default=None):
+        """Remove the given key, or subkey, from the cache
+
+        Args:
+            key: key or subkey to remove.
+            default: value to return if key is not found
+
+        Returns:
+            If the key is not found, 'default'. If the key is complete, the removed
+            value. If the key is partial, the TreeCacheNode corresponding to the part
+            of the tree that was removed.
+        """
+        # a list of the nodes we have touched on the way down the tree
         nodes = []
 
         node = self.root
         for k in key[:-1]:
             node = node.get(k, None)
-            nodes.append(node)  # don't add the root node
             if node is None:
                 return default
+            if not isinstance(node, TreeCacheNode):
+                # we've gone off the end of the tree
+                raise ValueError("pop() key too long")
+            nodes.append(node)  # don't add the root node
 
         popped = node.pop(key[-1], SENTINEL)
         if popped is SENTINEL:
             return default
 
+        # working back up the tree, clear out any nodes that are now empty
         node_and_keys = list(zip(nodes, key))
         node_and_keys.reverse()
         node_and_keys.append((self.root, None))
@@ -61,14 +115,15 @@
             if n:
                 break
+            # found an empty node: remove it from its parent, and loop.
             node_and_keys[i + 1][0].pop(k)
 
-        popped, cnt = _strip_and_count_entires(popped)
+        cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
         self.size -= cnt
         return popped
 
     def values(self):
-        return list(iterate_tree_cache_entry(self.root))
+        return iterate_tree_cache_entry(self.root)
 
     def __len__(self):
         return self.size
@@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d):
     """Helper function to iterate over the leaves of a tree, i.e. a dict of that
     can contain dicts.
     """
-    if isinstance(d, dict):
+    if isinstance(d, TreeCacheNode):
         for value_d in d.values():
             for value in iterate_tree_cache_entry(value_d):
                 yield value
     else:
-        if isinstance(d, _Entry):
-            yield d.value
-        else:
-            yield d
-
-
-class _Entry:
-    __slots__ = ["value"]
-
-    def __init__(self, value):
-        self.value = value
-
-
-def _strip_and_count_entires(d):
-    """Takes an _Entry or dict with leaves of _Entry's, and either returns the
-    value or a dictionary with _Entry's replaced by their values.
-
-    Also returns the count of _Entry's
-    """
-    if isinstance(d, dict):
-        cnt = 0
-        for key, value in d.items():
-            v, n = _strip_and_count_entires(value)
-            d[key] = v
-            cnt += n
-        return d, cnt
-    else:
-        return d.value, 1
+        yield d
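Replacing the `_Entry` wrapper with a `TreeCacheNode` subclass of `dict` is what lets the code above tell interior nodes from stored values, and is also why inconsistent key lengths can now be rejected instead of silently corrupting the tree. Expected behaviour, per the new code:

    from synapse.util.caches.treecache import TreeCache

    t = TreeCache()
    t[("a", "b", "c")] = "value1"
    t[("a", "b", "d")] = "value2"
    assert t.get(("a", "b", "c")) == "value1"
    assert len(t) == 2

    # popping a *partial* key removes the whole subtree and returns the
    # TreeCacheNode that was cut out; the size accounting follows
    t.pop(("a", "b"))
    assert len(t) == 0

    t[("x",)] = 1
    try:
        t[("x", "y")] = 2  # "x" already holds a leaf, not a subtree
    except ValueError:
        pass  # "value conflicts with an existing subtree"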
diff --git a/synapse/util/hash.py b/synapse/util/hash.py
index ba676e1762..7625ca8c2c 100644
--- a/synapse/util/hash.py
+++ b/synapse/util/hash.py
@@ -17,15 +17,15 @@ import hashlib
 import unpaddedbase64
 
 
-def sha256_and_url_safe_base64(input_text):
+def sha256_and_url_safe_base64(input_text: str) -> str:
     """SHA256 hash an input string, encode the digest as url-safe base64, and
     return
 
-    :param input_text: string to hash
-    :type input_text: str
+    Args:
+        input_text: string to hash
 
-    :returns a sha256 hashed and url-safe base64 encoded digest
-    :rtype: str
+    returns:
+        A sha256 hashed and url-safe base64 encoded digest
     """
     digest = hashlib.sha256(input_text.encode()).digest()
     return unpaddedbase64.encode_base64(digest, urlsafe=True)
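Only the annotations and docstring change here; the digest itself is untouched. For reference, a 32-byte SHA-256 digest always encodes to 43 characters of unpadded url-safe base64:

    import hashlib

    import unpaddedbase64

    digest = hashlib.sha256("hello".encode()).digest()
    token = unpaddedbase64.encode_base64(digest, urlsafe=True)
    assert len(token) == 43  # 32 bytes -> 44 base64 chars, minus one '=' pad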
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index abfdc29832..886afa9d19 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -30,12 +30,12 @@ from typing import (
 T = TypeVar("T")
 
 
-def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
+def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
     """batch an iterable up into tuples with a maximum size
 
     Args:
-        iterable (iterable): the iterable to slice
-        size (int): the maximum batch size
+        iterable: the iterable to slice
+        size: the maximum batch size
 
     Returns:
         an iterator over the chunks
@@ -46,10 +46,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
     return iter(lambda: tuple(islice(sourceiter, size)), ())
 
 
-ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
-
-
-def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
+def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
     """Split the given sequence into chunks of the given size
 
     The last chunk may be shorter than the given size.
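The annotation fix above matters because `Tuple[T]` means a 1-tuple in typing terms, while `batch_iter` actually yields tuples of up to `size` elements; `chunk_seq` similarly just needed an ordinary `TypeVar`. Their behaviour, unchanged by this diff:

    from synapse.util.iterutils import batch_iter, chunk_seq

    assert list(batch_iter(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)]
    assert list(chunk_seq("abcdefg", 3)) == ["abc", "def", "g"]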
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 8acbe276e4..cbfbd097f9 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,6 +15,7 @@
 import importlib
 import importlib.util
 import itertools
+from types import ModuleType
 from typing import Any, Iterable, Tuple, Type
 
 import jsonschema
@@ -44,8 +45,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
 
     # We need to import the module, and then pick the class out of
     # that, so we split based on the last dot.
-    module, clz = modulename.rsplit(".", 1)
-    module = importlib.import_module(module)
+    module_name, clz = modulename.rsplit(".", 1)
+    module = importlib.import_module(module_name)
     provider_class = getattr(module, clz)
 
     # Load the module config. If None, pass an empty dictionary instead
@@ -69,11 +70,11 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
     return provider_class, provider_config
 
 
-def load_python_module(location: str):
+def load_python_module(location: str) -> ModuleType:
     """Load a python module, and return a reference to its global namespace
 
     Args:
-        location (str): path to the module
+        location: path to the module
 
     Returns:
         python module object
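Renaming the first variable avoids reusing `module` for both the dotted name and the imported module object, which also confused type checking. The split itself is unchanged; roughly, for a provider block like the following (the module and option names are invented):

    provider = {
        "module": "my_package.auth.MyAuthProvider",
        "config": {"example_option": True},
    }

    module_name, clz = provider["module"].rsplit(".", 1)
    # module_name == "my_package.auth", clz == "MyAuthProvider"
    # then: provider_class = getattr(importlib.import_module(module_name), clz)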
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index bbbdebf264..1046224f15 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -17,19 +17,19 @@ import phonenumbers
 
 from synapse.api.errors import SynapseError
 
 
-def phone_number_to_msisdn(country, number):
+def phone_number_to_msisdn(country: str, number: str) -> str:
     """
     Takes an ISO-3166-1 2 letter country code and phone number and
     returns an msisdn representing the canonical version of that
     phone number.
 
     Args:
-        country (str): ISO-3166-1 2 letter country code
-        number (str): Phone number in a national or international format
+        country: ISO-3166-1 2 letter country code
+        number: Phone number in a national or international format
 
     Returns:
-        (str) The canonical form of the phone number, as an msisdn
+        The canonical form of the phone number, as an msisdn
     Raises:
-        SynapseError if the number could not be parsed.
+        SynapseError if the number could not be parsed.
     """
     try:
         phoneNumber = phonenumbers.parse(number, country)
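An MSISDN is simply the E.164 form of a number with the leading `+` stripped, which is what this helper computes via `phonenumbers`. A sketch of the underlying steps, using a UK number from the fictional drama range:

    import phonenumbers

    num = phonenumbers.parse("020 7946 0000", "GB")
    e164 = phonenumbers.format_number(num, phonenumbers.PhoneNumberFormat.E164)
    assert e164 == "+442079460000"
    msisdn = e164[1:]  # "442079460000"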
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index f9c370a814..129b47cd49 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -82,11 +82,9 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
     retry_timings = await store.get_destination_retry_timings(destination)
 
     if retry_timings:
-        failure_ts = retry_timings["failure_ts"]
-        retry_last_ts, retry_interval = (
-            retry_timings["retry_last_ts"],
-            retry_timings["retry_interval"],
-        )
+        failure_ts = retry_timings.failure_ts
+        retry_last_ts = retry_timings.retry_last_ts
+        retry_interval = retry_timings.retry_interval
 
         now = int(clock.time_msec())
 
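The dict lookups above give way to attribute access because `get_destination_retry_timings` now returns a typed container rather than a raw row dict. Sketched with attrs (the field names come from the diff; the exact class definition lives in the storage layer and may differ in detail):

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class DestinationRetryTimings:
        failure_ts: int      # when the destination started failing, in ms
        retry_last_ts: int   # when we last tried and failed to reach it, in ms
        retry_interval: int  # how long to wait before the next attempt, in ms

    timings = DestinationRetryTimings(1000, 5000, 60_000)
    assert timings.retry_interval == 60_000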
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 4f25cd1d26..f029432191 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import random
 import re
+import secrets
 import string
 from collections.abc import Iterable
 from typing import Optional, Tuple
@@ -35,26 +35,27 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
 #
 MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 
-# random_string and random_string_with_symbols are used for a range of things,
-# some cryptographically important, some less so. We use SystemRandom to make sure
-# we get cryptographically-secure randoms.
-rand = random.SystemRandom()
-
 
 def random_string(length: int) -> str:
-    return "".join(rand.choice(string.ascii_letters) for _ in range(length))
+    """Generate a cryptographically secure string of random letters.
+
+    Drawn from the characters: `a-z` and `A-Z`
+    """
+    return "".join(secrets.choice(string.ascii_letters) for _ in range(length))
 
 
 def random_string_with_symbols(length: int) -> str:
-    return "".join(rand.choice(_string_with_symbols) for _ in range(length))
+    """Generate a cryptographically secure string of random letters/numbers/symbols.
+
+    Drawn from the characters: `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`
+    """
+    return "".join(secrets.choice(_string_with_symbols) for _ in range(length))
 
 
 def is_ascii(s: bytes) -> bool:
     try:
         s.decode("ascii").encode("ascii")
-    except UnicodeDecodeError:
-        return False
-    except UnicodeEncodeError:
+    except UnicodeError:
         return False
     return True
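Two small hardenings close out the section: `secrets` replaces the hand-rolled `SystemRandom` instance (the same CSPRNG, with less ceremony), and the two `except` arms collapse into `UnicodeError`, the common base class of both `UnicodeDecodeError` and `UnicodeEncodeError`, so the behaviour is unchanged. A quick check of both claims:

    import secrets
    import string

    token = "".join(secrets.choice(string.ascii_letters) for _ in range(24))
    # each letter adds log2(52) ~= 5.7 bits of entropy, so 24 chars ~ 136 bits

    assert issubclass(UnicodeDecodeError, UnicodeError)
    assert issubclass(UnicodeEncodeError, UnicodeError)

    def is_ascii(s: bytes) -> bool:  # same body as the new helper above
        try:
            s.decode("ascii").encode("ascii")
        except UnicodeError:
            return False
        return True

    assert is_ascii(b"cafe")
    assert not is_ascii("café".encode("utf-8"))  # 0xC3 0xA9 is not ASCII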