diff --git a/synapse/__init__.py b/synapse/__init__.py
index 7498a6016f..d9843a1708 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.34.0"
+__version__ = "1.35.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index c3c776a9f9..b2e60c6aa7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -87,6 +87,7 @@ class Auth:
)
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
+ self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
@@ -208,6 +209,8 @@ class Auth:
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
+ if user_id in self._force_tracing_for_users:
+ opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
return requester
@@ -260,6 +263,8 @@ class Auth:
opentracing.set_tag("user_id", user_info.user_id)
if device_id:
opentracing.set_tag("device_id", device_id)
+ if user_info.token_owner in self._force_tracing_for_users:
+ opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
return requester
except KeyError:
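
For context, forcing a trace works by setting the standard `sampling.priority` tag: Jaeger treats a priority of 1 as "always keep this trace", overriding the probabilistic sampler. A minimal sketch using the plain opentracing API (the span and user set here are illustrative):

    from opentracing.ext import tags

    def maybe_force_tracing(span, user_id: str, forced_users: set) -> None:
        # sampling.priority=1 tells the sampler to retain the trace even if
        # its probabilistic decision would have dropped it.
        if user_id in forced_users:
            span.set_tag(tags.SAMPLING_PRIORITY, 1)
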
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f730cdbd78..91ad326f19 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -61,7 +61,6 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events, login, presence, room
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -237,7 +236,6 @@ class GenericWorkerSlavedStore(
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
- SlavedTransactionStore,
SlavedProfileStore,
SlavedClientIpStore,
SlavedFilteringStore,
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 7137e3d323..ea692f699d 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -35,9 +35,26 @@ class ExperimentalConfig(Config):
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
# Spaces (MSC1772, MSC2946, MSC3083, etc)
- self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
+ self.spaces_enabled = experimental.get("spaces_enabled", True) # type: bool
if self.spaces_enabled:
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
# MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ # Enable experimental features in Synapse.
+ #
+ # Experimental features might break or be removed without a deprecation
+ # period.
+ #
+ experimental_features:
+ # Support for Spaces (MSC1772), which enables the following:
+ #
+ # * The Spaces Summary API (MSC2946).
+ # * Restricting room membership based on space membership (MSC3083).
+ #
+ # Uncomment to disable support for Spaces.
+ #spaces_enabled: false
+ """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index c23b66c88c..5ae0f55bcc 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -57,7 +57,6 @@ class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
- ExperimentalConfig,
TlsConfig,
FederationConfig,
CacheConfig,
@@ -94,4 +93,5 @@ class HomeServerConfig(RootConfig):
TracerConfig,
WorkerConfig,
RedisConfig,
+ ExperimentalConfig,
]
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 632ef0d796..eecc0478a7 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -477,4 +477,4 @@ class RegistrationConfig(Config):
def read_arguments(self, args):
if args.enable_registration is not None:
- self.enable_registration = bool(strtobool(str(args.enable_registration)))
+ self.enable_registration = strtobool(str(args.enable_registration))
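
The dropped bool() wrapper is safe because strtobool already returns 0 or 1; a quick sketch of its behaviour:

    from distutils.util import strtobool

    # strtobool maps "y"/"yes"/"true"/"on"/"1" to 1 and their negative
    # counterparts to 0, raising ValueError for anything else.
    assert strtobool("true") == 1
    assert strtobool("no") == 0
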
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 3d1218c8d1..05e983625d 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -164,7 +164,13 @@ class SAML2Config(Config):
config_path = saml2_config.get("config_path", None)
if config_path is not None:
mod = load_python_module(config_path)
- _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
+ config = getattr(mod, "CONFIG", None)
+ if config is None:
+ raise ConfigError(
+ "Config path specified by saml2_config.config_path does not "
+ "have a CONFIG property."
+ )
+ _dict_merge(merge_dict=config, into_dict=saml2_config_dict)
import saml2.config
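
For reference, a minimal sketch of a module that `saml2_config.config_path` could point at; the loader now fails with a clear ConfigError unless the module exposes a `CONFIG` attribute (the file name and contents below are illustrative):

    # my_saml2_config.py (hypothetical path set as saml2_config.config_path)
    CONFIG = {
        # A pysaml2 configuration dict; merged into Synapse's generated
        # SP config via _dict_merge above.
        "entityid": "https://example.com/_synapse/client/saml2/metadata.xml",
    }
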
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 7df4e4c3e6..26f1150ca5 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -16,11 +16,8 @@ import logging
import os
import warnings
from datetime import datetime
-from hashlib import sha256
from typing import List, Optional, Pattern
-from unpaddedbase64 import encode_base64
-
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
@@ -83,13 +80,6 @@ class TlsConfig(Config):
"configured."
)
- self._original_tls_fingerprints = config.get("tls_fingerprints", [])
-
- if self._original_tls_fingerprints is None:
- self._original_tls_fingerprints = []
-
- self.tls_fingerprints = list(self._original_tls_fingerprints)
-
# Whether to verify certificates on outbound federation traffic
self.federation_verify_certificates = config.get(
"federation_verify_certificates", True
@@ -248,19 +238,6 @@ class TlsConfig(Config):
e,
)
- self.tls_fingerprints = list(self._original_tls_fingerprints)
-
- if self.tls_certificate:
- # Check that our own certificate is included in the list of fingerprints
- # and include it if it is not.
- x509_certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1, self.tls_certificate
- )
- sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
- sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints}
- if sha256_fingerprint not in sha256_fingerprints:
- self.tls_fingerprints.append({"sha256": sha256_fingerprint})
-
def generate_config_section(
self,
config_dir_path,
@@ -443,33 +420,6 @@ class TlsConfig(Config):
# If unspecified, we will use CONFDIR/client.key.
#
account_key_file: %(default_acme_account_file)s
-
- # List of allowed TLS fingerprints for this server to publish along
- # with the signing keys for this server. Other matrix servers that
- # make HTTPS requests to this server will check that the TLS
- # certificates returned by this server match one of the fingerprints.
- #
- # Synapse automatically adds the fingerprint of its own certificate
- # to the list. So if federation traffic is handled directly by synapse
- # then no modification to the list is required.
- #
- # If synapse is run behind a load balancer that handles the TLS then it
- # will be necessary to add the fingerprints of the certificates used by
- # the loadbalancers to this list if they are different to the one
- # synapse is using.
- #
- # Homeservers are permitted to cache the list of TLS fingerprints
- # returned in the key responses up to the "valid_until_ts" returned in
- # key. It may be necessary to publish the fingerprints of a new
- # certificate and wait until the "valid_until_ts" of the previous key
- # responses have passed before deploying it.
- #
- # You can calculate a fingerprint from a given TLS listener via:
- # openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
- # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
- # or by checking matrix.org/federationtester/api/report?server_name=$host
- #
- #tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
"""
# Lowercase the string representation of boolean values
% {
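
For anyone who still needs a fingerprint after this removal, the deleted logic amounted to an unpadded-base64 SHA-256 of the DER-encoded certificate; a sketch reconstructing it from the removed code:

    from hashlib import sha256

    from OpenSSL import crypto
    from unpaddedbase64 import encode_base64

    def cert_sha256_fingerprint(cert: crypto.X509) -> str:
        # DER-encode the certificate, hash it, and unpadded-base64 the
        # digest, mirroring the removed tls_fingerprints computation.
        der_bytes = crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
        return encode_base64(sha256(der_bytes).digest())
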
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index db22b5b19f..d0ea17261f 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Set
+
from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -32,6 +34,8 @@ class TracerConfig(Config):
{"sampler": {"type": "const", "param": 1}, "logging": False},
)
+ self.force_tracing_for_users: Set[str] = set()
+
if not self.opentracer_enabled:
return
@@ -48,6 +52,19 @@ class TracerConfig(Config):
if not isinstance(self.opentracer_whitelist, list):
raise ConfigError("Tracer homeserver_whitelist config is malformed")
+ force_tracing_for_users = opentracing_config.get("force_tracing_for_users", [])
+ if not isinstance(force_tracing_for_users, list):
+ raise ConfigError(
+ "Expected a list", ("opentracing", "force_tracing_for_users")
+ )
+ for i, u in enumerate(force_tracing_for_users):
+ if not isinstance(u, str):
+ raise ConfigError(
+ "Expected a string",
+ ("opentracing", "force_tracing_for_users", f"index {i}"),
+ )
+ self.force_tracing_for_users.add(u)
+
def generate_config_section(cls, **kwargs):
return """\
## Opentracing ##
@@ -64,7 +81,8 @@ class TracerConfig(Config):
#enabled: true
# The list of homeservers we wish to send span contexts and span baggage to, and receive them from.
- # See docs/opentracing.rst
+ # See docs/opentracing.rst.
+ #
# This is a list of regexes which are matched against the server_name of the
# homeserver.
#
@@ -73,19 +91,26 @@ class TracerConfig(Config):
#homeserver_whitelist:
# - ".*"
+ # A list of the matrix IDs of users whose requests will always be traced,
+ # even if the tracing system would otherwise drop the traces due to
+ # probabilistic sampling.
+ #
+ # By default, the list is empty.
+ #
+ #force_tracing_for_users:
+ # - "@user1:server_name"
+ # - "@user2:server_name"
+
# Jaeger can be configured to sample traces at different rates.
# All configuration options provided by Jaeger can be set here.
- # Jaeger's configuration mostly related to trace sampling which
+ # Jaeger's configuration is mostly related to trace sampling, which
# is documented here:
- # https://www.jaegertracing.io/docs/1.13/sampling/.
+ # https://www.jaegertracing.io/docs/latest/sampling/.
#
#jaeger_config:
# sampler:
# type: const
# param: 1
-
- # Logging whether spans were started and reported
- #
# logging:
# false
"""
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..6fc0712978 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -17,7 +17,7 @@ import abc
import logging
import urllib
from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
import attr
from signedjson.key import (
@@ -42,6 +42,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.key import TrustedKeyServer
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
@@ -69,7 +71,11 @@ class VerifyJsonRequest:
Attributes:
server_name: The name of the server to verify against.
- json_object: The JSON object to verify.
+ get_json_object: A callback to fetch the JSON object to verify.
+ A callback is used to allow deferring the creation of the JSON
+ object to verify until needed, e.g. for events we can defer
+ creating the redacted copy. This reduces the memory usage when
+ there are large numbers of in-flight requests.
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
@@ -88,14 +94,50 @@ class VerifyJsonRequest:
"""
server_name = attr.ib(type=str)
- json_object = attr.ib(type=JsonDict)
+ get_json_object = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib(type=str)
- key_ids = attr.ib(init=False, type=List[str])
+ key_ids = attr.ib(type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
- def __attrs_post_init__(self):
- self.key_ids = signature_ids(self.json_object, self.server_name)
+ @staticmethod
+ def from_json_object(
+ server_name: str,
+ json_object: JsonDict,
+ minimum_valid_until_ms: int,
+ request_name: str,
+ ):
+ """Create a VerifyJsonRequest to verify all signatures on a signed JSON
+ object for the given server.
+ """
+ key_ids = signature_ids(json_object, server_name)
+ return VerifyJsonRequest(
+ server_name,
+ lambda: json_object,
+ minimum_valid_until_ms,
+ request_name=request_name,
+ key_ids=key_ids,
+ )
+
+ @staticmethod
+ def from_event(
+ server_name: str,
+ event: EventBase,
+ minimum_valid_until_ms: int,
+ ):
+ """Create a VerifyJsonRequest to verify all signatures on an event
+ object for the given server.
+ """
+ key_ids = list(event.signatures.get(server_name, []))
+ return VerifyJsonRequest(
+ server_name,
+ # We defer creating the redacted json object, as it uses a lot more
+ # memory than the Event object itself.
+ lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+ minimum_valid_until_ms,
+ request_name=event.event_id,
+ key_ids=key_ids,
+ )
class KeyLookupError(ValueError):
@@ -147,8 +189,13 @@ class Keyring:
Deferred[None]: completes if the object was correctly signed, otherwise
errbacks with an error
"""
- req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
- requests = (req,)
+ request = VerifyJsonRequest.from_json_object(
+ server_name,
+ json_object,
+ validity_time,
+ request_name,
+ )
+ requests = (request,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(
@@ -175,10 +222,41 @@ class Keyring:
logcontext.
"""
return self._verify_objects(
- VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+ VerifyJsonRequest.from_json_object(
+ server_name, json_object, validity_time, request_name
+ )
for server_name, json_object, validity_time, request_name in server_and_json
)
+ def verify_events_for_server(
+ self, server_and_events: Iterable[Tuple[str, EventBase, int]]
+ ) -> List[defer.Deferred]:
+ """Bulk verification of signatures on events.
+
+ Args:
+ server_and_events:
+ Iterable of `(server_name, event, validity_time)` tuples.
+
+ `server_name` is which server we are verifying the signature for
+ on the event.
+
+ `event` is the event that we'll verify the signatures of for
+ the given `server_name`.
+
+ `validity_time` is a timestamp at which the signing key must be
+ valid.
+
+ Returns:
+ List<Deferred[None]>: for each input triplet, a deferred indicating success
+ or failure to verify each event's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
+ """
+ return self._verify_objects(
+ VerifyJsonRequest.from_event(server_name, event, validity_time)
+ for server_name, event, validity_time in server_and_events
+ )
+
def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
@@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
with PreserveLoggingContext():
_, key_id, verify_key = await verify_request.key_ready
- json_object = verify_request.json_object
+ json_object = verify_request.get_json_object()
try:
verify_signed_json(json_object, server_name, verify_key)
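
The point of the callback indirection, sketched with the names above: the redacted JSON for an event is only materialised when the signature check actually runs, rather than being held for every in-flight request.

    # Eager (old): the JSON dict lives as long as the request object.
    request = VerifyJsonRequest.from_json_object(
        "example.org", json_object, validity_time, "request_name"
    )

    # Lazy (new): from_event stores a lambda; prune_event_dict only runs
    # when _handle_key_deferred finally calls get_json_object().
    request = VerifyJsonRequest.from_event("example.org", event, validity_time)
    json_to_verify = request.get_json_object()
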
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase:
return deferreds
-class PduToCheckSig(
- namedtuple(
- "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
- )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
- redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_sender
]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_event_id
]
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 5125441df5..3feb60da2a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -57,6 +57,7 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.transport.client import SendJoinResponse
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
@@ -667,19 +668,10 @@ class FederationClient(FederationBase):
"""
async def send_request(destination) -> Dict[str, Any]:
- content = await self._do_send_join(destination, pdu)
+ response = await self._do_send_join(room_version, destination, pdu)
- logger.debug("Got content: %s", content)
-
- state = [
- event_from_pdu_json(p, room_version, outlier=True)
- for p in content.get("state", [])
- ]
-
- auth_chain = [
- event_from_pdu_json(p, room_version, outlier=True)
- for p in content.get("auth_chain", [])
- ]
+ state = response.state
+ auth_chain = response.auth_events
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
@@ -754,11 +746,14 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request)
- async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
+ async def _do_send_join(
+ self, room_version: RoomVersion, destination: str, pdu: EventBase
+ ) -> SendJoinResponse:
time_now = self._clock.time_msec()
try:
return await self.transport_layer.send_join_v2(
+ room_version=room_version,
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
@@ -773,17 +768,14 @@ class FederationClient(FederationBase):
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
- resp = await self.transport_layer.send_join_v1(
+ return await self.transport_layer.send_join_v1(
+ room_version=room_version,
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- # We expect the v1 API to respond with [200, content], so we only return the
- # content.
- return resp[1]
-
async def send_invite(
self,
destination: str,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index a42e055501..bf5b541deb 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -19,18 +19,29 @@ import logging
import urllib
from typing import Any, Dict, List, Optional
+import attr
+import ijson
+
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
+from synapse.events import EventBase, make_event_from_dict
+from synapse.http.matrixfederationclient import ByteParser
from synapse.logging.utils import log_function
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+# Send join responses can be huge, so we set a separate limit here. The response
+# is parsed in a streaming manner, which helps alleviate the issue of memory
+# usage a bit.
+MAX_RESPONSE_SIZE_SEND_JOIN = 500 * 1024 * 1024
+
class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""
@@ -253,21 +264,38 @@ class TransportLayerClient:
return content
@log_function
- async def send_join_v1(self, destination, room_id, event_id, content):
+ async def send_join_v1(
+ self,
+ room_version,
+ destination,
+ room_id,
+ event_id,
+ content,
+ ) -> "SendJoinResponse":
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
- destination=destination, path=path, data=content
+ destination=destination,
+ path=path,
+ data=content,
+ parser=SendJoinParser(room_version, v1_api=True),
+ max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
return response
@log_function
- async def send_join_v2(self, destination, room_id, event_id, content):
+ async def send_join_v2(
+ self, room_version, destination, room_id, event_id, content
+ ) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
- destination=destination, path=path, data=content
+ destination=destination,
+ path=path,
+ data=content,
+ parser=SendJoinParser(room_version, v1_api=False),
+ max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
return response
@@ -1119,3 +1147,59 @@ def _create_v2_path(path, *args):
str
"""
return _create_path(FEDERATION_V2_PREFIX, path, *args)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class SendJoinResponse:
+ """The parsed response of a `/send_join` request."""
+
+ auth_events: List[EventBase]
+ state: List[EventBase]
+
+
+@ijson.coroutine
+def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
+ """Helper function for use with `ijson.items_coro` to parse an array of
+ events and add them to the given list.
+ """
+
+ while True:
+ obj = yield
+ event = make_event_from_dict(obj, room_version)
+ events.append(event)
+
+
+class SendJoinParser(ByteParser[SendJoinResponse]):
+ """A parser for the response to `/send_join` requests.
+
+ Args:
+ room_version: The version of the room.
+ v1_api: Whether the response is in the v1 format.
+ """
+
+ CONTENT_TYPE = "application/json"
+
+ def __init__(self, room_version: RoomVersion, v1_api: bool):
+ self._response = SendJoinResponse([], [])
+
+ # The V1 API has the shape of `[200, {...}]`, which we handle by
+ # prefixing with `item.*`.
+ prefix = "item." if v1_api else ""
+
+ self._coro_state = ijson.items_coro(
+ _event_list_parser(room_version, self._response.state),
+ prefix + "state.item",
+ )
+ self._coro_auth = ijson.items_coro(
+ _event_list_parser(room_version, self._response.auth_events),
+ prefix + "auth_chain.item",
+ )
+
+ def write(self, data: bytes) -> int:
+ self._coro_state.send(data)
+ self._coro_auth.send(data)
+
+ return len(data)
+
+ def finish(self) -> SendJoinResponse:
+ return self._response
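
A self-contained sketch of the ijson push-parsing pattern SendJoinParser relies on: bytes are fed incrementally with send(), and completed items matching the path prefix are handed to the coroutine as they arrive. The `item.` prefix handles the v1 response shape of `[200, {...}]`, since ijson addresses top-level array elements as `item`.

    import ijson

    @ijson.coroutine
    def collect(out: list):
        # Receives each fully parsed item matching the prefix.
        while True:
            out.append((yield))

    states = []
    coro = ijson.items_coro(collect(states), "item.state.item")

    # Feed a (v1-shaped) response in chunks, as it would arrive off the wire.
    for chunk in (b'[200, {"state": [{"a": 1}', b', {"a": 2}]}]'):
        coro.send(chunk)
    coro.close()
    # states == [{"a": 1}, {"a": 2}]
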
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index dd5ab5160a..086d999d98 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -162,7 +162,7 @@ class Authenticator:
# If we get a valid signed request from the other side, it's probably
# alive
retry_timings = await self.store.get_destination_retry_timings(origin)
- if retry_timings and retry_timings["retry_last_ts"]:
+ if retry_timings and retry_timings.retry_last_ts:
run_in_background(self._reset_retry_timings, origin)
return origin
@@ -1479,7 +1479,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
)
return 200, await self.handler.federation_space_summary(
- room_id, suggested_only, max_rooms_per_space, exclude_rooms
+ origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
)
# TODO When switching to the stable endpoint, remove the POST handler.
@@ -1509,7 +1509,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
)
return 200, await self.handler.federation_space_summary(
- room_id, suggested_only, max_rooms_per_space, exclude_rooms
+ origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
)
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 022789ea5f..640c2e9fd6 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,12 +15,9 @@
import email.mime.multipart
import email.utils
import logging
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import StoreError, SynapseError
-from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -39,9 +36,11 @@ class AccountValidityHandler:
self.hs = hs
self.config = hs.config
self.store = self.hs.get_datastore()
- self.sendmail = self.hs.get_sendmail()
+ self.send_email_handler = self.hs.get_send_email_handler()
self.clock = self.hs.get_clock()
+ self._app_name = self.hs.config.email_app_name
+
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
)
@@ -69,23 +68,10 @@ class AccountValidityHandler:
self._template_text = (
hs.config.account_validity.account_validity_template_text
)
- account_validity_renew_email_subject = (
+ self._renew_email_subject = (
hs.config.account_validity.account_validity_renew_email_subject
)
- try:
- app_name = hs.config.email_app_name
-
- self._subject = account_validity_renew_email_subject % {"app": app_name}
-
- self._from_string = hs.config.email_notif_from % {"app": app_name}
- except Exception:
- # If substitution failed, fall back to the bare strings.
- self._subject = account_validity_renew_email_subject
- self._from_string = hs.config.email_notif_from
-
- self._raw_from = email.utils.parseaddr(self._from_string)[1]
-
# Check the renewal emails to send and send them every 30min.
if hs.config.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
@@ -177,38 +163,17 @@ class AccountValidityHandler:
}
html_text = self._template_html.render(**template_vars)
- html_part = MIMEText(html_text, "html", "utf8")
-
plain_text = self._template_text.render(**template_vars)
- text_part = MIMEText(plain_text, "plain", "utf8")
for address in addresses:
raw_to = email.utils.parseaddr(address)[1]
- multipart_msg = MIMEMultipart("alternative")
- multipart_msg["Subject"] = self._subject
- multipart_msg["From"] = self._from_string
- multipart_msg["To"] = address
- multipart_msg["Date"] = email.utils.formatdate()
- multipart_msg["Message-ID"] = email.utils.make_msgid()
- multipart_msg.attach(text_part)
- multipart_msg.attach(html_part)
-
- logger.info("Sending renewal email to %s", address)
-
- await make_deferred_yieldable(
- self.sendmail(
- self.hs.config.email_smtp_host,
- self._raw_from,
- raw_to,
- multipart_msg.as_string().encode("utf8"),
- reactor=self.hs.get_reactor(),
- port=self.hs.config.email_smtp_port,
- requireAuthentication=self.hs.config.email_smtp_user is not None,
- username=self.hs.config.email_smtp_user,
- password=self.hs.config.email_smtp_pass,
- requireTransportSecurity=self.hs.config.require_transport_security,
- )
+ await self.send_email_handler.send_email(
+ email_address=raw_to,
+ subject=self._renew_email_subject,
+ app_name=self._app_name,
+ html=html_text,
+ text=plain_text,
)
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index eff639f407..a0df16a32f 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Collection, Optional
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
from synapse.types import StateMap
if TYPE_CHECKING:
@@ -29,46 +31,104 @@ class EventAuthHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
- async def can_join_without_invite(
- self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
- ) -> bool:
+ async def check_restricted_join_rules(
+ self,
+ state_ids: StateMap[str],
+ room_version: RoomVersion,
+ user_id: str,
+ prev_member_event: Optional[EventBase],
+ ) -> None:
"""
- Check whether a user can join a room without an invite.
+ Check whether a user can join a room without an invite due to restricted join rules.
When joining a room with restricted join rules (as defined in MSC3083),
- the membership of spaces must be checked during join.
+ the membership of spaces must be checked during a room join.
Args:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
+ prev_member_event: The current membership event for this user.
+
+ Raises:
+ AuthError if the user cannot join the room.
+ """
+ # If the member is invited or currently joined, then nothing to do.
+ if prev_member_event and (
+ prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
+ ):
+ return
+
+ # This is not a room with a restricted join rule, so we don't need to do the
+ # restricted room specific checks.
+ #
+ # Note: We'll be applying the standard join rule checks later, which will
+ # catch the cases of e.g. trying to join private rooms without an invite.
+ if not await self.has_restricted_join_rules(state_ids, room_version):
+ return
+
+ # Get the spaces which allow access to this room and check if the user is
+ # in any of them.
+ allowed_spaces = await self.get_spaces_that_allow_join(state_ids)
+ if not await self.is_user_in_rooms(allowed_spaces, user_id):
+ raise AuthError(
+ 403,
+ "You do not belong to any of the required spaces to join this room.",
+ )
+
+ async def has_restricted_join_rules(
+ self, state_ids: StateMap[str], room_version: RoomVersion
+ ) -> bool:
+ """
+ Return whether the room has the proper join rules set for access via spaces.
+
+ Args:
+ state_ids: The state of the room as it currently is.
+ room_version: The room version of the room to query.
Returns:
- True if the user can join the room, false otherwise.
+ True if the proper room version and join rules are set for restricted access.
"""
# This only applies to room versions which support the new join rule.
if not room_version.msc3083_join_rules:
- return True
+ return False
# If there's no join rule, then it defaults to invite (so this doesn't apply).
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
- return True
+ return False
+
+ # If the join rule is not restricted, this doesn't apply.
+ join_rules_event = await self._store.get_event(join_rules_event_id)
+ return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED
+
+ async def get_spaces_that_allow_join(
+ self, state_ids: StateMap[str]
+ ) -> Collection[str]:
+ """
+ Generate a list of spaces which allow access to a room.
+
+ Args:
+ state_ids: The state of the room as it currently is.
+
+ Returns:
+ A collection of spaces which provide membership to the room.
+ """
+ # If there's no join rule, then it defaults to invite (so this doesn't apply).
+ join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ if not join_rules_event_id:
+ return ()
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self._store.get_event(join_rules_event_id)
- if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
- return True
# If allowed is of the wrong form, then only allow invited users.
allowed_spaces = join_rules_event.content.get("allow", [])
if not isinstance(allowed_spaces, list):
- return False
-
- # Get the list of joined rooms and see if there's an overlap.
- joined_rooms = await self._store.get_rooms_for_user(user_id)
+ return ()
# Pull out the other room IDs; invalid data gets filtered.
+ result = []
for space in allowed_spaces:
if not isinstance(space, dict):
continue
@@ -77,10 +137,31 @@ class EventAuthHandler:
if not isinstance(space_id, str):
continue
- # The user was joined to one of the spaces specified, they can join
- # this room!
- if space_id in joined_rooms:
+ result.append(space_id)
+
+ return result
+
+ async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
+ """
+ Check whether a user is a member of any of the provided rooms.
+
+ Args:
+ room_ids: The rooms to check for membership.
+ user_id: The user to check.
+
+ Returns:
+ True if the user is in any of the rooms, false otherwise.
+ """
+ if not room_ids:
+ return False
+
+ # Get the list of joined rooms and see if there's an overlap.
+ joined_rooms = await self._store.get_rooms_for_user(user_id)
+
+ # Check each room and see if the user is in it.
+ for room_id in room_ids:
+ if room_id in joined_rooms:
return True
- # The user was not in any of the required spaces.
+ # The user was not in any of the rooms.
return False
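
The refactor splits the old can_join_without_invite into reusable pieces; is_user_in_rooms, for instance, is reused by the space summary handler below against the allowed_spaces returned over federation. A sketch of the recombined flow, using only the methods defined above:

    # Inside check_restricted_join_rules, once the invite/join
    # short-circuits have not fired:
    if await self.has_restricted_join_rules(state_ids, room_version):
        allowed_spaces = await self.get_spaces_that_allow_join(state_ids)
        if not await self.is_user_in_rooms(allowed_spaces, user_id):
            raise AuthError(
                403,
                "You do not belong to any of the required spaces to join this room.",
            )
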
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6a5c33f212..36652289a4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -92,6 +92,7 @@ from synapse.types import (
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
@@ -1741,28 +1742,17 @@ class FederationHandler(BaseHandler):
# Check if the user is already in the room or invited to the room.
user_id = event.state_key
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- newly_joined = True
- user_is_invited = False
+ prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
- newly_joined = prev_member_event.membership != Membership.JOIN
- user_is_invited = prev_member_event.membership == Membership.INVITE
-
- # If the member is not already in the room, and not invited, check if
- # they should be allowed access via membership in a space.
- if (
- newly_joined
- and not user_is_invited
- and not await self._event_auth_handler.can_join_without_invite(
- prev_state_ids,
- event.room_version,
- user_id,
- )
- ):
- raise AuthError(
- 403,
- "You do not belong to any of the required spaces to join this room.",
- )
+
+ # Check if the member should be allowed access via membership in a space.
+ await self._event_auth_handler.check_restricted_join_rules(
+ prev_state_ids,
+ event.room_version,
+ user_id,
+ prev_member_event,
+ )
# Persist the event.
await self._auth_and_persist_event(origin, event, context)
@@ -3258,13 +3248,15 @@ class FederationHandler(BaseHandler):
"""
instance = self.config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:
- result = await self._send_events(
- instance_name=instance,
- store=self.store,
- room_id=room_id,
- event_and_contexts=event_and_contexts,
- backfilled=backfilled,
- )
+ # Limit the number of events sent over federation.
+ for batch in batch_iter(event_and_contexts, 1000):
+ result = await self._send_events(
+ instance_name=instance,
+ store=self.store,
+ room_id=room_id,
+ event_and_contexts=batch,
+ backfilled=backfilled,
+ )
return result["max_stream_id"]
else:
assert self.storage.persistence
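
batch_iter (imported above) chunks an iterable into fixed-size tuples, so each replication call now carries at most 1000 events; a quick sketch of its behaviour:

    from synapse.util.iterutils import batch_iter

    # Yields tuples of up to `size` items until the iterable is exhausted.
    assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]
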
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6fd1f34289..f5a049d754 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -222,9 +222,21 @@ class BasePresenceHandler(abc.ABC):
@abc.abstractmethod
async def set_state(
- self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+ self,
+ target_user: UserID,
+ state: JsonDict,
+ ignore_status_msg: bool = False,
+ force_notify: bool = False,
) -> None:
- """Set the presence state of the user. """
+ """Set the presence state of the user.
+
+ Args:
+ target_user: The ID of the user to set the presence state of.
+ state: The presence state as a JSON dictionary.
+ ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+ If False, the user's current status will be updated.
+ force_notify: Whether to force notification of the update to clients.
+ """
@abc.abstractmethod
async def bump_presence_active_time(self, user: UserID):
@@ -296,6 +308,51 @@ class BasePresenceHandler(abc.ABC):
for destinations, states in hosts_and_states:
self._federation.send_presence_to_destinations(states, destinations)
+ async def send_full_presence_to_users(self, user_ids: Collection[str]):
+ """
+ Adds to the list of users who should receive a full snapshot of presence
+ upon their next sync. Note that this only works for local users.
+
+ Then, grabs the current presence state for a given set of users and adds it
+ to the top of the presence stream.
+
+ Args:
+ user_ids: The IDs of the local users to send full presence to.
+ """
+ # Retrieve one of the users from the given set
+ if not user_ids:
+ raise Exception(
+ "send_full_presence_to_users must be called with at least one user"
+ )
+ user_id = next(iter(user_ids))
+
+ # Mark all users as receiving full presence on their next sync
+ await self.store.add_users_to_send_full_presence_to(user_ids)
+
+ # Add a new entry to the presence stream. Since we use stream tokens to determine whether a
+ # local user should receive a full snapshot of presence when they sync, we need to bump the
+ # presence stream so that subsequent syncs with no presence activity in between won't result
+ # in the client receiving multiple full snapshots of presence.
+ #
+ # If we bump the stream ID, then the user will get a higher stream token next sync, and thus
+ # correctly won't receive a second snapshot.
+
+ # Get the current presence state for one of the users (defaults to offline if not found)
+ current_presence_state = await self.get_state(UserID.from_string(user_id))
+
+ # Convert the UserPresenceState object into a serializable dict
+ state = {
+ "presence": current_presence_state.state,
+ "status_message": current_presence_state.status_msg,
+ }
+
+ # Copy the presence state to the tip of the presence stream.
+
+ # We set force_notify=True here so that this presence update is guaranteed to
+ # increment the presence stream ID (which resending the current user's presence
+ # otherwise would not do).
+ await self.set_state(UserID.from_string(user_id), state, force_notify=True)
+
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
@@ -480,8 +537,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
target_user: UserID,
state: JsonDict,
ignore_status_msg: bool = False,
+ force_notify: bool = False,
) -> None:
- """Set the presence state of the user."""
+ """Set the presence state of the user.
+
+ Args:
+ target_user: The ID of the user to set the presence state of.
+ state: The presence state as a JSON dictionary.
+ ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+ If False, the user's current status will be updated.
+ force_notify: Whether to force notification of the update to clients.
+ """
presence = state["presence"]
valid_presence = (
@@ -508,6 +574,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
user_id=user_id,
state=state,
ignore_status_msg=ignore_status_msg,
+ force_notify=force_notify,
)
async def bump_presence_active_time(self, user: UserID) -> None:
@@ -677,13 +744,19 @@ class PresenceHandler(BasePresenceHandler):
[self.user_to_current_state[user_id] for user_id in unpersisted]
)
- async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
+ async def _update_states(
+ self, new_states: Iterable[UserPresenceState], force_notify: bool = False
+ ) -> None:
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
should be sent to clients/servers.
Args:
new_states: The new user presence state updates to process.
+ force_notify: Whether to force notifying clients of this presence state update,
+ even if it doesn't change the state of a user's presence (e.g. online -> online).
+ This is currently used to bump the max presence stream ID without changing any
+ user's presence (see PresenceHandler.send_full_presence_to_users).
"""
now = self.clock.time_msec()
@@ -720,6 +793,9 @@ class PresenceHandler(BasePresenceHandler):
now=now,
)
+ if force_notify:
+ should_notify = True
+
self.user_to_current_state[user_id] = new_state
if should_notify:
@@ -1058,9 +1134,21 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(updates)
async def set_state(
- self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+ self,
+ target_user: UserID,
+ state: JsonDict,
+ ignore_status_msg: bool = False,
+ force_notify: bool = False,
) -> None:
- """Set the presence state of the user."""
+ """Set the presence state of the user.
+
+ Args:
+ target_user: The ID of the user to set the presence state of.
+ state: The presence state as a JSON dictionary.
+ ignore_status_msg: True to ignore the "status_msg" field of the `state` dict.
+ If False, the user's current status will be updated.
+ force_notify: Whether to force notification of the update to clients.
+ """
status_msg = state.get("status_msg", None)
presence = state["presence"]
@@ -1091,7 +1179,9 @@ class PresenceHandler(BasePresenceHandler):
):
new_fields["last_active_ts"] = self.clock.time_msec()
- await self._update_states([prev_state.copy_and_replace(**new_fields)])
+ await self._update_states(
+ [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify
+ )
async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
"""Returns whether a user can see another user's presence."""
@@ -1389,11 +1479,10 @@ class PresenceEventSource:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
#
- # Same with get_module_api, get_presence_router
+ # Same with get_presence_router:
#
# AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
self.get_presence_handler = hs.get_presence_handler
- self.get_module_api = hs.get_module_api
self.get_presence_router = hs.get_presence_router
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -1424,16 +1513,21 @@ class PresenceEventSource:
stream_change_cache = self.store.presence_stream_cache
with Measure(self.clock, "presence.get_new_events"):
- if user_id in self.get_module_api()._send_full_presence_to_local_users:
- # This user has been specified by a module to receive all current, online
- # user presence. Removing from_key and setting include_offline to false
- # will do effectively this.
- from_key = None
- include_offline = False
-
if from_key is not None:
from_key = int(from_key)
+ # Check if this user should receive all current, online user presence. We only
+ # bother to do this if from_key is set, as otherwise the user will receive all
+ # user presence anyway.
+ if await self.store.should_user_receive_full_presence_with_token(
+ user_id, from_key
+ ):
+ # This user has been specified by a module to receive all current, online
+ # user presence. Removing from_key and setting include_offline to false
+ # will effectively do this.
+ from_key = None
+ include_offline = False
+
max_token = self.store.get_current_presence_token()
if from_key == max_token:
# This is necessary as due to the way stream ID generators work
@@ -1467,12 +1561,6 @@ class PresenceEventSource:
user_id, include_offline, from_key
)
- # Remove the user from the list of users to receive all presence
- if user_id in self.get_module_api()._send_full_presence_to_local_users:
- self.get_module_api()._send_full_presence_to_local_users.remove(
- user_id
- )
-
return presence_updates, max_token
# Make mypy happy. users_interested_in should now be a set
@@ -1522,10 +1610,6 @@ class PresenceEventSource:
)
presence_updates = list(users_to_state.values())
- # Remove the user from the list of users to receive all presence
- if user_id in self.get_module_api()._send_full_presence_to_local_users:
- self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
-
if not include_offline:
# Filter out offline presence states
presence_updates = self._filter_offline_presence_state(presence_updates)
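
End to end, the new full-presence flow looks like this (a sketch using the methods above, inside an async context; the user ID is illustrative):

    # A module (or admin action) flags a user for a full presence snapshot:
    await presence_handler.send_full_presence_to_users({"@alice:example.org"})

    # That marks the user in the database and bumps the presence stream via
    # set_state(..., force_notify=True). On @alice's next incremental sync,
    # should_user_receive_full_presence_with_token(user_id, from_key) returns
    # True, so PresenceEventSource drops from_key and excludes offline users,
    # returning a full snapshot exactly once.
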
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 835d5862c1..61900d87df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -307,25 +307,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN:
newly_joined = True
- user_is_invited = False
+ prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
- user_is_invited = prev_member_event.membership == Membership.INVITE
- # If the member is not already in the room and is not accepting an invite,
- # check if they should be allowed access via membership in a space.
- if (
- newly_joined
- and not user_is_invited
- and not await self.event_auth_handler.can_join_without_invite(
- prev_state_ids, event.room_version, user_id
- )
- ):
- raise AuthError(
- 403,
- "You do not belong to any of the required spaces to join this room.",
- )
+ # Check if the member should be allowed access via membership in a space.
+ await self.event_auth_handler.check_restricted_join_rules(
+ prev_state_ids, event.room_version, user_id, prev_member_event
+ )
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
new file mode 100644
index 0000000000..e9f6aef06f
--- /dev/null
+++ b/synapse/handlers/send_email.py
@@ -0,0 +1,98 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+import logging
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import TYPE_CHECKING
+
+from synapse.logging.context import make_deferred_yieldable
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SendEmailHandler:
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+
+ self._sendmail = hs.get_sendmail()
+ self._reactor = hs.get_reactor()
+
+ self._from = hs.config.email.email_notif_from
+ self._smtp_host = hs.config.email.email_smtp_host
+ self._smtp_port = hs.config.email.email_smtp_port
+ self._smtp_user = hs.config.email.email_smtp_user
+ self._smtp_pass = hs.config.email.email_smtp_pass
+ self._require_transport_security = hs.config.email.require_transport_security
+
+ async def send_email(
+ self,
+ email_address: str,
+ subject: str,
+ app_name: str,
+ html: str,
+ text: str,
+ ) -> None:
+ """Send a multipart email with the given information.
+
+ Args:
+ email_address: The address to send the email to.
+ subject: The email's subject.
+ app_name: The app name to include in the From header.
+ html: The HTML content to include in the email.
+ text: The plain text content to include in the email.
+ """
+ try:
+ from_string = self._from % {"app": app_name}
+ except (KeyError, TypeError):
+ from_string = self._from
+
+ raw_from = email.utils.parseaddr(from_string)[1]
+ raw_to = email.utils.parseaddr(email_address)[1]
+
+ if raw_to == "":
+ raise RuntimeError("Invalid 'to' address")
+
+ html_part = MIMEText(html, "html", "utf8")
+ text_part = MIMEText(text, "plain", "utf8")
+
+ multipart_msg = MIMEMultipart("alternative")
+ multipart_msg["Subject"] = subject
+ multipart_msg["From"] = from_string
+ multipart_msg["To"] = email_address
+ multipart_msg["Date"] = email.utils.formatdate()
+ multipart_msg["Message-ID"] = email.utils.make_msgid()
+ multipart_msg.attach(text_part)
+ multipart_msg.attach(html_part)
+
+ logger.info("Sending email to %s" % email_address)
+
+ await make_deferred_yieldable(
+ self._sendmail(
+ self._smtp_host,
+ raw_from,
+ raw_to,
+ multipart_msg.as_string().encode("utf8"),
+ reactor=self._reactor,
+ port=self._smtp_port,
+ requireAuthentication=self._smtp_user is not None,
+ username=self._smtp_user,
+ password=self._smtp_pass,
+ requireTransportSecurity=self._require_transport_security,
+ )
+ )
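
Callers obtain the handler from the HomeServer and pass fully rendered bodies, as account_validity.py above now does; an illustrative call (the subject variable is hypothetical):

    send_email_handler = hs.get_send_email_handler()
    await send_email_handler.send_email(
        email_address="user@example.org",  # illustrative address
        subject=renew_email_subject,
        app_name=hs.config.email_app_name,
        html=html_text,   # rendered HTML body
        text=plain_text,  # rendered plain-text fallback
    )
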
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index e35d91832b..abd9ddecca 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -16,11 +16,16 @@ import itertools
import logging
import re
from collections import deque
-from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple
import attr
-from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ HistoryVisibility,
+ Membership,
+)
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
@@ -32,7 +37,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# number of rooms to return. We'll stop once we hit this limit.
-# TODO: allow clients to reduce this with a request param.
MAX_ROOMS = 50
# max number of events to return per room.
@@ -46,8 +50,7 @@ class SpaceSummaryHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._auth = hs.get_auth()
- self._room_list_handler = hs.get_room_list_handler()
- self._state_handler = hs.get_state_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
self._store = hs.get_datastore()
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
@@ -112,28 +115,88 @@ class SpaceSummaryHandler:
max_children = max_rooms_per_space if processed_rooms else None
if is_in_room:
- rooms, events = await self._summarize_local_room(
- requester, room_id, suggested_only, max_children
+ room, events = await self._summarize_local_room(
+ requester, None, room_id, suggested_only, max_children
)
+
+ logger.debug(
+ "Query of local room %s returned events %s",
+ room_id,
+ ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
+ )
+
+ if room:
+ rooms_result.append(room)
else:
- rooms, events = await self._summarize_remote_room(
+ fed_rooms, fed_events = await self._summarize_remote_room(
queue_entry,
suggested_only,
max_children,
exclude_rooms=processed_rooms,
)
- logger.debug(
- "Query of %s returned rooms %s, events %s",
- queue_entry.room_id,
- [room.get("room_id") for room in rooms],
- ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
- )
-
- rooms_result.extend(rooms)
-
- # any rooms returned don't need visiting again
- processed_rooms.update(cast(str, room.get("room_id")) for room in rooms)
+ # The results over federation might include rooms that we, as the
+ # requesting server, are allowed to see, but the requesting user is
+ # not permitted to see.
+ #
+ # Filter the returned results to only what is accessible to the user.
+ room_ids = set()
+ events = []
+ for room in fed_rooms:
+ fed_room_id = room.get("room_id")
+ if not fed_room_id or not isinstance(fed_room_id, str):
+ continue
+
+ # The room should only be included in the summary if:
+ # a. the user is in the room;
+ # b. the room is world readable; or
+ # c. the user is in a space that has been granted access to
+ # the room.
+ #
+ # Note that we know the user is not in the root room (which is
+ # why the remote call was made in the first place), but the user
+ # could be in one of the children rooms and we just didn't know
+ # about the link.
+ include_room = room.get("world_readable") is True
+
+ # Check if the user is a member of any of the allowed spaces
+ # from the response.
+ allowed_spaces = room.get("allowed_spaces")
+ if (
+ not include_room
+ and allowed_spaces
+ and isinstance(allowed_spaces, list)
+ ):
+ include_room = await self._event_auth_handler.is_user_in_rooms(
+ allowed_spaces, requester
+ )
+
+ # Finally, if this isn't the requested room, check ourselves
+ # if we can access the room.
+ if not include_room and fed_room_id != queue_entry.room_id:
+ include_room = await self._is_room_accessible(
+ fed_room_id, requester, None
+ )
+
+ # The user can see the room, include it!
+ if include_room:
+ rooms_result.append(room)
+ room_ids.add(fed_room_id)
+
+ # All rooms returned don't need visiting again (even if the user
+ # didn't have access to them).
+ processed_rooms.add(fed_room_id)
+
+ for event in fed_events:
+ if event.get("room_id") in room_ids:
+ events.append(event)
+
+ logger.debug(
+ "Query of %s returned rooms %s, events %s",
+ room_id,
+ [room.get("room_id") for room in fed_rooms],
+ ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events],
+ )
# the room we queried may or may not have been returned, but don't process
# it again, anyway.
@@ -159,10 +222,16 @@ class SpaceSummaryHandler:
)
processed_events.add(ev_key)
+ # Before returning to the client, remove the allowed_spaces key from all
+ # rooms.
+ for room in rooms_result:
+ room.pop("allowed_spaces", None)
+
return {"rooms": rooms_result, "events": events_result}
async def federation_space_summary(
self,
+ origin: str,
room_id: str,
suggested_only: bool,
max_rooms_per_space: Optional[int],
@@ -172,6 +241,8 @@ class SpaceSummaryHandler:
Implementation of the space summary Federation API
Args:
+ origin: The server requesting the spaces summary.
+
room_id: room id to start the summary at
suggested_only: whether we should only return children with the "suggested"
@@ -206,14 +277,15 @@ class SpaceSummaryHandler:
logger.debug("Processing room %s", room_id)
- rooms, events = await self._summarize_local_room(
- None, room_id, suggested_only, max_rooms_per_space
+ room, events = await self._summarize_local_room(
+ None, origin, room_id, suggested_only, max_rooms_per_space
)
processed_rooms.add(room_id)
- rooms_result.extend(rooms)
- events_result.extend(events)
+ if room:
+ rooms_result.append(room)
+ events_result.extend(events)
# add any children to the queue
room_queue.extend(edge_event["state_key"] for edge_event in events)
@@ -223,19 +295,27 @@ class SpaceSummaryHandler:
async def _summarize_local_room(
self,
requester: Optional[str],
+ origin: Optional[str],
room_id: str,
suggested_only: bool,
max_children: Optional[int],
- ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
+ ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]:
"""
Generate a room entry and a list of event entries for a given room.
Args:
- requester: The requesting user, or None if this is over federation.
+ requester:
+ The user requesting the summary, if it is a local request. None
+ if this is a federation request.
+ origin:
+ The server requesting the summary, if it is a federation request.
+ None if this is a local request.
room_id: The room ID to summarize.
suggested_only: True if only suggested children should be returned.
Otherwise, all children are returned.
- max_children: The maximum number of children to return for this node.
+ max_children:
+ The maximum number of child rooms to include. This is capped
+ to a server-set limit.
Returns:
A tuple of:
@@ -244,8 +324,8 @@ class SpaceSummaryHandler:
An iterable of the sorted children events. This may be limited
to a maximum size or may include all children.
"""
- if not await self._is_room_accessible(room_id, requester):
- return (), ()
+ if not await self._is_room_accessible(room_id, requester, origin):
+ return None, ()
room_entry = await self._build_room_entry(room_id)
@@ -269,7 +349,7 @@ class SpaceSummaryHandler:
event_format=format_event_for_client_v2,
)
)
- return (room_entry,), events_result
+ return room_entry, events_result
async def _summarize_remote_room(
self,
@@ -278,6 +358,26 @@ class SpaceSummaryHandler:
max_children: Optional[int],
exclude_rooms: Iterable[str],
) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]:
+ """
+ Request room entries and a list of event entries for a given room by querying a remote server.
+
+ Args:
+ room: The room to summarize.
+ suggested_only: True if only suggested children should be returned.
+ Otherwise, all children are returned.
+ max_children:
+ The maximum number of child rooms to include. This is capped
+ to a server-set limit.
+ exclude_rooms:
+ Room IDs which do not need to be summarized.
+
+ Returns:
+ A tuple of:
+ An iterable of rooms.
+
+ An iterable of the sorted children events. This may be limited
+ to a maximum size or may include all children.
+ """
room_id = room.room_id
logger.info("Requesting summary for %s via %s", room_id, room.via)
@@ -309,27 +409,93 @@ class SpaceSummaryHandler:
or ev.event_type == EventTypes.SpaceChild
)
- async def _is_room_accessible(self, room_id: str, requester: Optional[str]) -> bool:
- # if we have an authenticated requesting user, first check if they are in the
- # room
+ async def _is_room_accessible(
+ self, room_id: str, requester: Optional[str], origin: Optional[str]
+ ) -> bool:
+ """
+ Calculate whether the room should be shown in the spaces summary.
+
+ It should be included if:
+
+ * The requester is joined or invited to the room.
+ * The requester can join without an invite (per MSC3083).
+ * The origin server has any user that is joined or invited to the room.
+ * The history visibility is set to world readable.
+
+ Args:
+ room_id: The room ID to summarize.
+ requester:
+ The user requesting the summary, if it is a local request. None
+ if this is a federation request.
+ origin:
+ The server requesting the summary, if it is a federation request.
+ None if this is a local request.
+
+ Returns:
+ True if the room should be included in the spaces summary.
+ """
+ state_ids = await self._store.get_current_state_ids(room_id)
+
+ # If there's no state for the room, it isn't known.
+ if not state_ids:
+ logger.info("room %s is unknown, omitting from summary", room_id)
+ return False
+
+ room_version = await self._store.get_room_version(room_id)
+
+ # if we have an authenticated requesting user, first check if they are able to view
+ # stripped state in the room.
if requester:
+ member_event_id = state_ids.get((EventTypes.Member, requester), None)
+
+ # If they're in the room they can see info on it.
+ member_event = None
+ if member_event_id:
+ member_event = await self._store.get_event(member_event_id)
+ if member_event.membership in (Membership.JOIN, Membership.INVITE):
+ return True
+
+ # Otherwise, check if they should be allowed access via membership in a space.
try:
- await self._auth.check_user_in_room(room_id, requester)
- return True
+ await self._event_auth_handler.check_restricted_join_rules(
+ state_ids, room_version, requester, member_event
+ )
except AuthError:
+ # The user wasn't granted access via space membership, but might have
+ # access some other way. Keep trying.
pass
+ else:
+ return True
+
+ # If this is a request over federation, check if the host is in the room or
+ # is in one of the spaces specified via the join rules.
+ elif origin:
+ if await self._auth.check_host_in_room(room_id, origin):
+ return True
+
+ # Alternately, if the host has a user in any of the spaces specified
+ # for access, then the host can see this room (and should do filtering
+ # if the requester cannot see it).
+ if await self._event_auth_handler.has_restricted_join_rules(
+ state_ids, room_version
+ ):
+ allowed_spaces = (
+ await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
+ )
+ for space_id in allowed_spaces:
+ if await self._auth.check_host_in_room(space_id, origin):
+ return True
# otherwise, check if the room is peekable
- hist_vis_ev = await self._state_handler.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
- if hist_vis_ev:
+ hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None)
+ if hist_vis_event_id:
+ hist_vis_ev = await self._store.get_event(hist_vis_event_id)
hist_vis = hist_vis_ev.content.get("history_visibility")
if hist_vis == HistoryVisibility.WORLD_READABLE:
return True
logger.info(
- "room %s is unpeekable and user %s is not a member, omitting from summary",
+ "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary",
room_id,
requester,
)
@@ -354,6 +520,15 @@ class SpaceSummaryHandler:
if not room_type:
room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
+ room_version = await self._store.get_room_version(room_id)
+ allowed_spaces = None
+ if await self._event_auth_handler.has_restricted_join_rules(
+ current_state_ids, room_version
+ ):
+ allowed_spaces = await self._event_auth_handler.get_spaces_that_allow_join(
+ current_state_ids
+ )
+
entry = {
"room_id": stats["room_id"],
"name": stats["name"],
@@ -367,6 +542,7 @@ class SpaceSummaryHandler:
"guest_can_join": stats["guest_access"] == "can_join",
"creation_ts": create_event.origin_server_ts,
"room_type": room_type,
+ "allowed_spaces": allowed_spaces,
}
# Filter out Nones – rather omit the field altogether
@@ -430,8 +606,8 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool:
return False
-# Order may only contain characters in the range of \x20 (space) to \x7F (~).
-_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7F]")
+# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive.
+_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]")
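A quick illustration of the off-by-one being fixed here (a standalone sketch, not part of the module): \x7F is DEL, which the old character class wrongly treated as valid.

import re

old = re.compile(r"[^\x20-\x7F]")  # accepted DEL (\x7F) as valid
new = re.compile(r"[^\x20-\x7E]")  # rejects it

assert old.search("\x7f") is None      # DEL slipped through before
assert new.search("\x7f") is not None  # now stripped as invalid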
def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5f40f16e24..1ca6624fd5 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -813,7 +813,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
if self.deferred.called:
return
- self.stream.write(data)
+ try:
+ self.stream.write(data)
+ except Exception:
+ self.deferred.errback()
+ return
+
self.length += len(data)
# The first time the maximum size is exceeded, error and cancel the
# connection. dataReceived might be called again if data was received
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bb837b7b19..1998990a14 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import cgi
import codecs
import logging
@@ -19,13 +20,24 @@ import sys
import typing
import urllib.parse
from io import BytesIO, StringIO
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+ Callable,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
import attr
import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
+from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
@@ -48,6 +60,7 @@ from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
+ ByteWriteable,
encode_query_args,
read_body_with_max_size,
)
@@ -88,6 +101,27 @@ _next_id = 1
QueryArgs = Dict[str, Union[str, List[str]]]
+T = TypeVar("T")
+
+
+class ByteParser(ByteWriteable, Generic[T], abc.ABC):
+ """A `ByteWriteable` that has an additional `finish` function that returns
+ the parsed data.
+ """
+
+ CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore
+ """The expected content type of the response, e.g. `application/json`. If
+ the content type doesn't match we fail the request.
+ """
+
+ @abc.abstractmethod
+ def finish(self) -> T:
+ """Called when response has finished streaming and the parser should
+ return the final result (or error).
+ """
+ pass
+
+
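As a sketch of the extension point this introduces, a parser that simply buffers the raw body might look like the following; `RawBytesParser` is illustrative only and not part of this change.

class RawBytesParser(ByteParser[bytes]):
    """Buffers the raw response body and returns it unparsed."""

    CONTENT_TYPE = "application/octet-stream"

    def __init__(self):
        self._buffer = BytesIO()

    def write(self, data: bytes) -> int:
        # Called repeatedly as chunks of the body arrive.
        return self._buffer.write(data)

    def finish(self) -> bytes:
        # Called once the body has been fully streamed.
        return self._buffer.getvalue()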
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
method = attr.ib(type=str)
@@ -148,15 +182,33 @@ class MatrixFederationRequest:
return self.json
-async def _handle_json_response(
+class JsonParser(ByteParser[Union[JsonDict, list]]):
+ """A parser that buffers the response and tries to parse it as JSON."""
+
+ CONTENT_TYPE = "application/json"
+
+ def __init__(self):
+ self._buffer = StringIO()
+ self._binary_wrapper = BinaryIOWrapper(self._buffer)
+
+ def write(self, data: bytes) -> int:
+ return self._binary_wrapper.write(data)
+
+ def finish(self) -> Union[JsonDict, list]:
+ return json_decoder.decode(self._buffer.getvalue())
+
+
+async def _handle_response(
reactor: IReactorTime,
timeout_sec: float,
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
-) -> JsonDict:
+ parser: ByteParser[T],
+ max_response_size: Optional[int] = None,
+) -> T:
"""
- Reads the JSON body of a response, with a timeout
+ Reads the body of a response with a timeout and sends it to a parser
Args:
reactor: twisted reactor, for the timeout
@@ -164,23 +216,26 @@ async def _handle_json_response(
request: the request that triggered the response
response: response to the request
start_ms: Timestamp when request was made
+ parser: The parser for the response
+ max_response_size: The maximum size to read from the response; if None,
+ the default is used.
Returns:
- The parsed JSON response
+ The parsed response
"""
+
+ if max_response_size is None:
+ max_response_size = MAX_RESPONSE_SIZE
+
try:
- check_content_type_is_json(response.headers)
+ check_content_type_is(response.headers, parser.CONTENT_TYPE)
- buf = StringIO()
- d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
+ d = read_body_with_max_size(response, parser, max_response_size)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
- def parse(_len: int):
- return json_decoder.decode(buf.getvalue())
+ length = await make_deferred_yieldable(d)
- d.addCallback(parse)
-
- body = await make_deferred_yieldable(d)
+ value = parser.finish()
except BodyExceededMaxSize as e:
# The response was too big.
logger.warning(
@@ -193,9 +248,9 @@ async def _handle_json_response(
)
raise RequestSendFailed(e, can_retry=False) from e
except ValueError as e:
- # The JSON content was invalid.
+ # The content was invalid.
logger.warning(
- "{%s} [%s] Failed to parse JSON response - %s %s",
+ "{%s} [%s] Failed to parse response - %s %s",
request.txn_id,
request.destination,
request.method,
@@ -225,16 +280,17 @@ async def _handle_json_response(
time_taken_secs = reactor.seconds() - start_ms / 1000
logger.info(
- "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+ "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
time_taken_secs,
+ length,
request.method,
request.uri.decode("ascii"),
)
- return body
+ return value
class BinaryIOWrapper:
@@ -671,6 +727,7 @@ class MatrixFederationHttpClient:
)
return auth_headers
+ @overload
async def put_json(
self,
destination: str,
@@ -683,7 +740,44 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
+ parser: Literal[None] = None,
+ max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
+ ...
+
+ @overload
+ async def put_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser[T]] = None,
+ max_response_size: Optional[int] = None,
+ ) -> T:
+ ...
+
+ async def put_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser] = None,
+ max_response_size: Optional[int] = None,
+ ):
"""Sends the specified json data using PUT
Args:
@@ -716,6 +810,10 @@ class MatrixFederationHttpClient:
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
+ parser: The parser to use to decode the response. Defaults to
+ parsing as JSON.
+ max_response_size: The maximum size to read from the response; if None,
+ the default is used.
Returns:
Succeeds when we get a 2xx HTTP response. The
@@ -756,8 +854,17 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms
+ if parser is None:
+ parser = JsonParser()
+
+ body = await _handle_response(
+ self.reactor,
+ _sec_timeout,
+ request,
+ response,
+ start_ms,
+ parser=parser,
+ max_response_size=max_response_size,
)
return body
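A usage sketch, assuming `client` is a MatrixFederationHttpClient and reusing the hypothetical RawBytesParser from the sketch above; with `parser=None` the behaviour is unchanged and a JsonParser is used.

body = await client.put_json(
    destination="remote.example.com",             # hypothetical server
    path="/_matrix/federation/v1/some/endpoint",  # hypothetical endpoint
    data={"foo": "bar"},
    parser=RawBytesParser(),             # any ByteParser subclass
    max_response_size=10 * 1024 * 1024,  # override the default cap
)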
@@ -830,12 +937,8 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor,
- _sec_timeout,
- request,
- response,
- start_ms,
+ body = await _handle_response(
+ self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
)
return body
@@ -907,8 +1010,8 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms
+ body = await _handle_response(
+ self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
)
return body
@@ -975,8 +1078,8 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms
+ body = await _handle_response(
+ self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
)
return body
@@ -1068,16 +1171,16 @@ def _flatten_response_never_received(e):
return repr(e)
-def check_content_type_is_json(headers: Headers) -> None:
+def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
- is application/json.
+ is the expected value.
Args:
headers: headers to check
Raises:
- RequestSendFailed: if the Content-Type header is missing or isn't JSON
+ RequestSendFailed: if the Content-Type header is missing or doesn't match
"""
content_type_headers = headers.getRawHeaders(b"Content-Type")
@@ -1089,11 +1192,10 @@ def check_content_type_is_json(headers: Headers) -> None:
c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
- if val != "application/json":
+ if val != expected_content_type:
raise RequestSendFailed(
RuntimeError(
- "Remote server sent Content-Type header of '%s', not 'application/json'"
- % c_type,
+ f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
),
can_retry=False,
)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 671fd3fbcc..40754b7bea 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -105,8 +105,10 @@ class SynapseRequest(Request):
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
logger.warning(
- "Aborting connection from %s because the request exceeds maximum size",
+ "Aborting connection from %s because the request exceeds maximum size: %s %s",
self.client,
+ self.get_method(),
+ self.get_redacted_uri(),
)
self.transport.abortConnection()
return
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 12c5ea0815..e562ff693e 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -56,14 +56,6 @@ class ModuleApi:
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
self._public_room_list_manager = PublicRoomListManager(hs)
- # The next time these users sync, they will receive the current presence
- # state of all local users. Users are added by send_local_online_presence_to,
- # and removed after a successful sync.
- #
- # We make this a private variable to deter modules from accessing it directly,
- # though other classes in Synapse will still do so.
- self._send_full_presence_to_local_users = set()
-
@property
def http_client(self):
"""Allows making outbound HTTP requests to remote resources.
@@ -405,39 +397,44 @@ class ModuleApi:
Updates to remote users will be sent immediately, whereas local users will receive
them on their next sync attempt.
- Note that this method can only be run on the main or federation_sender worker
- processes.
+ Note that this method can only be run on the process that is configured to write to the
+ presence stream. By default this is the main process.
"""
- if not self._hs.should_send_federation():
+ if self._hs._instance_name not in self._hs.config.worker.writers.presence:
raise Exception(
"send_local_online_presence_to can only be run "
- "on processes that send federation",
+ "on the process that is configured to write to the "
+ "presence stream (by default this is the main process)",
)
+ local_users = set()
+ remote_users = set()
for user in users:
if self._hs.is_mine_id(user):
- # Modify SyncHandler._generate_sync_entry_for_presence to call
- # presence_source.get_new_events with an empty `from_key` if
- # that user's ID were in a list modified by ModuleApi somewhere.
- # That user would then get all presence state on next incremental sync.
-
- # Force a presence initial_sync for this user next time
- self._send_full_presence_to_local_users.add(user)
+ local_users.add(user)
else:
- # Retrieve presence state for currently online users that this user
- # is considered interested in
- presence_events, _ = await self._presence_stream.get_new_events(
- UserID.from_string(user), from_key=None, include_offline=False
- )
-
- # Send to remote destinations.
-
- # We pull out the presence handler here to break a cyclic
- # dependency between the presence router and module API.
- presence_handler = self._hs.get_presence_handler()
- await presence_handler.maybe_send_presence_to_interested_destinations(
- presence_events
- )
+ remote_users.add(user)
+
+ # We pull out the presence handler here to break a cyclic
+ # dependency between the presence router and module API.
+ presence_handler = self._hs.get_presence_handler()
+
+ if local_users:
+ # Force a presence initial_sync for these users next time they sync.
+ await presence_handler.send_full_presence_to_users(local_users)
+
+ for user in remote_users:
+ # Retrieve presence state for currently online users that this user
+ # is considered interested in.
+ presence_events, _ = await self._presence_stream.get_new_events(
+ UserID.from_string(user), from_key=None, include_offline=False
+ )
+
+ # Send to remote destinations.
+ # send_presence_to_destinations expects a collection of destination
+ # server names, so wrap the single domain in a list.
+ destination = UserID.from_string(user).domain
+ presence_handler.get_federation_queue().send_presence_to_destinations(
+ presence_events, [destination]
+ )
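A usage sketch for a module calling the updated method (the `module_api` object and user IDs are hypothetical); note that it must now run on the process configured to write to the presence stream.

# Local users are queued for a full presence snapshot on their next sync;
# remote users have the relevant presence state sent over federation now.
await module_api.send_local_online_presence_to(
    ["@alice:example.com", "@bob:remote.example.com"]
)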
class PublicRoomListManager:
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index c4b43b0d3f..5f9ea5003a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import email.mime.multipart
-import email.utils
import logging
import urllib.parse
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach
@@ -27,7 +23,6 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
from synapse.events import EventBase
-from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
@@ -108,7 +103,7 @@ class Mailer:
self.template_html = template_html
self.template_text = template_text
- self.sendmail = self.hs.get_sendmail()
+ self.send_email_handler = hs.get_send_email_handler()
self.store = self.hs.get_datastore()
self.state_store = self.hs.get_storage().state
self.macaroon_gen = self.hs.get_macaroon_generator()
@@ -310,17 +305,6 @@ class Mailer:
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
) -> None:
"""Send an email with the given information and template text"""
- try:
- from_string = self.hs.config.email_notif_from % {"app": self.app_name}
- except TypeError:
- from_string = self.hs.config.email_notif_from
-
- raw_from = email.utils.parseaddr(from_string)[1]
- raw_to = email.utils.parseaddr(email_address)[1]
-
- if raw_to == "":
- raise RuntimeError("Invalid 'to' address")
-
template_vars = {
"app_name": self.app_name,
"server_name": self.hs.config.server.server_name,
@@ -329,35 +313,14 @@ class Mailer:
template_vars.update(extra_template_vars)
html_text = self.template_html.render(**template_vars)
- html_part = MIMEText(html_text, "html", "utf8")
-
plain_text = self.template_text.render(**template_vars)
- text_part = MIMEText(plain_text, "plain", "utf8")
-
- multipart_msg = MIMEMultipart("alternative")
- multipart_msg["Subject"] = subject
- multipart_msg["From"] = from_string
- multipart_msg["To"] = email_address
- multipart_msg["Date"] = email.utils.formatdate()
- multipart_msg["Message-ID"] = email.utils.make_msgid()
- multipart_msg.attach(text_part)
- multipart_msg.attach(html_part)
-
- logger.info("Sending email to %s" % email_address)
-
- await make_deferred_yieldable(
- self.sendmail(
- self.hs.config.email_smtp_host,
- raw_from,
- raw_to,
- multipart_msg.as_string().encode("utf8"),
- reactor=self.hs.get_reactor(),
- port=self.hs.config.email_smtp_port,
- requireAuthentication=self.hs.config.email_smtp_user is not None,
- username=self.hs.config.email_smtp_user,
- password=self.hs.config.email_smtp_pass,
- requireTransportSecurity=self.hs.config.require_transport_security,
- )
+
+ await self.send_email_handler.send_email(
+ email_address=email_address,
+ subject=subject,
+ app_name=self.app_name,
+ html=html_text,
+ text=plain_text,
)
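A condensed sketch of the handler this code now delegates to, reconstructed from the logic removed above. The real implementation lives in synapse/handlers/send_email.py and may differ in detail.

import email.utils
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

from synapse.logging.context import make_deferred_yieldable


class SendEmailHandler:
    def __init__(self, hs):
        self.hs = hs
        self._sendmail = hs.get_sendmail()

    async def send_email(self, email_address, subject, app_name, html, text):
        try:
            from_string = self.hs.config.email_notif_from % {"app": app_name}
        except TypeError:
            from_string = self.hs.config.email_notif_from

        raw_from = email.utils.parseaddr(from_string)[1]
        raw_to = email.utils.parseaddr(email_address)[1]
        if raw_to == "":
            raise RuntimeError("Invalid 'to' address")

        # Assemble a multipart/alternative message from the two renderings.
        multipart_msg = MIMEMultipart("alternative")
        multipart_msg["Subject"] = subject
        multipart_msg["From"] = from_string
        multipart_msg["To"] = email_address
        multipart_msg["Date"] = email.utils.formatdate()
        multipart_msg["Message-ID"] = email.utils.make_msgid()
        multipart_msg.attach(MIMEText(text, "plain", "utf8"))
        multipart_msg.attach(MIMEText(html, "html", "utf8"))

        await make_deferred_yieldable(
            self._sendmail(
                self.hs.config.email_smtp_host,
                raw_from,
                raw_to,
                multipart_msg.as_string().encode("utf8"),
                reactor=self.hs.get_reactor(),
                port=self.hs.config.email_smtp_port,
                requireAuthentication=self.hs.config.email_smtp_user is not None,
                username=self.hs.config.email_smtp_user,
                password=self.hs.config.email_smtp_pass,
                requireTransportSecurity=self.hs.config.require_transport_security,
            )
        )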
async def _get_room_vars(
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 989523c823..546231bec0 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -87,6 +87,7 @@ REQUIREMENTS = [
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
"cryptography>=3.4.7",
+ "ijson>=3.0",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index f25307620d..bb00247953 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -73,6 +73,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
{
"state": { ... },
"ignore_status_msg": false,
+ "force_notify": false
}
200 OK
@@ -91,17 +92,23 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- async def _serialize_payload(user_id, state, ignore_status_msg=False):
+ async def _serialize_payload(
+ user_id, state, ignore_status_msg=False, force_notify=False
+ ):
return {
"state": state,
"ignore_status_msg": ignore_status_msg,
+ "force_notify": force_notify,
}
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
await self._presence_handler.set_state(
- UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
+ UserID.from_string(user_id),
+ content["state"],
+ content["ignore_status_msg"],
+ content["force_notify"],
)
return (
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 8730966380..13ed87adc4 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -24,7 +24,7 @@ class SlavedClientIpStore(BaseSlavedStore):
super().__init__(database, db_conn, hs)
self.client_ip_last_seen = LruCache(
- cache_name="client_ip_last_seen", keylen=4, max_size=50000
+ cache_name="client_ip_last_seen", max_size=50000
) # type: LruCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
deleted file mode 100644
index a59e543924..0000000000
--- a/synapse/replication/slave/storage/transactions.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.transactions import TransactionStore
-
-from ._base import BaseSlavedStore
-
-
-class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
- pass
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index cc3ab5854b..b5e4c474ef 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -54,7 +54,6 @@ class SendServerNoticeServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
- self.snm = hs.get_server_notices_manager()
def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
@@ -77,7 +76,10 @@ class SendServerNoticeServlet(RestServlet):
event_type = body.get("type", EventTypes.Message)
state_key = body.get("state_key")
- if not self.snm.is_enabled():
+ # We grab the server notices manager here as its initialisation has a check for worker processes,
+ # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
+ # admin api).
+ if not self.hs.get_server_notices_manager().is_enabled():
raise SynapseError(400, "Server notices are not enabled on this server")
user_id = body["user_id"]
@@ -85,7 +87,7 @@ class SendServerNoticeServlet(RestServlet):
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Server notices can only be sent to local users")
- event = await self.snm.send_notice(
+ event = await self.hs.get_server_notices_manager().send_notice(
user_id=body["user_id"],
type=event_type,
state_key=state_key,
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index e8dbe240d8..a5fcd15e3a 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -48,11 +48,6 @@ class LocalKey(Resource):
"key": # base64 encoded NACL verification key.
}
},
- "tls_fingerprints": [ # Fingerprints of the TLS certs this server uses.
- {
- "sha256": # base64 encoded sha256 fingerprint of the X509 cert
- },
- ],
"signatures": {
"this.server.example.com": {
"algorithm:version": # NACL signature for this server
@@ -89,14 +84,11 @@ class LocalKey(Resource):
"expired_ts": key.expired_ts,
}
- tls_fingerprints = self.config.tls_fingerprints
-
json_object = {
"valid_until_ts": self.valid_until_ts,
"server_name": self.config.server_name,
"verify_keys": verify_keys,
"old_verify_keys": old_verify_keys,
- "tls_fingerprints": tls_fingerprints,
}
for key in self.config.signing_key:
json_object = sign_json(json_object, self.config.server_name, key)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f648678b09..aba1734a55 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -73,9 +73,6 @@ class RemoteKey(DirectServeJsonResource):
"expired_ts": 0, # when the key stop being used.
}
}
- "tls_fingerprints": [
- { "sha256": # fingerprint }
- ]
"signatures": {
"remote.server.example.com": {...}
"this.server.example.com": {...}
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index e8a875b900..21c43c340c 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -76,6 +76,8 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
+ Thumbnailer.set_limits(self.max_image_pixels)
+
self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 37fe582390..a65e9e1802 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -40,6 +40,10 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
+ @staticmethod
+ def set_limits(max_image_pixels: int):
+ Image.MAX_IMAGE_PIXELS = max_image_pixels
+
def __init__(self, input_path: str):
try:
self.image = Image.open(input_path)
@@ -47,6 +51,11 @@ class Thumbnailer:
# If an error occurs opening the image, a thumbnail won't be able to
# be generated.
raise ThumbnailError from e
+ except Image.DecompressionBombError as e:
+ # If an image decompression bomb error occurs opening the image,
+ # then the image exceeds the pixel limit and a thumbnail cannot
+ # be generated.
+ raise ThumbnailError from e
self.width, self.height = self.image.size
self.transpose_method = None
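The effect of the new limit, assuming standard Pillow semantics: once Image.MAX_IMAGE_PIXELS is set, Image.open warns when an image exceeds the limit and raises DecompressionBombError when it exceeds twice the limit. A usage sketch (the path is hypothetical):

Thumbnailer.set_limits(max_image_pixels=32 * 1024 * 1024)

try:
    thumbnailer = Thumbnailer("/media_store/local_content/ab/cd/ef")
except ThumbnailError:
    # Corrupt, unsupported or excessively large image: skip thumbnailing
    # rather than letting Pillow allocate unbounded memory.
    pass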
diff --git a/synapse/server.py b/synapse/server.py
index 2337d2d9b4..fec0024c89 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -104,6 +104,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
+from synapse.handlers.send_email import SendEmailHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.space_summary import SpaceSummaryHandler
from synapse.handlers.sso import SsoHandler
@@ -550,6 +551,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return SearchHandler(self)
@cache_in_self
+ def get_send_email_handler(self) -> SendEmailHandler:
+ return SendEmailHandler(self)
+
+ @cache_in_self
def get_set_password_handler(self) -> SetPasswordHandler:
return SetPasswordHandler(self)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3d98d3f5f8..0623da9aa1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import random
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
@@ -44,7 +43,6 @@ class SQLBaseStore(metaclass=ABCMeta):
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
- self.rand = random.SystemRandom()
def process_replication_rows(
self,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 49c7606d51..9cce62ae6c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -67,7 +67,7 @@ from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
-from .transactions import TransactionStore
+from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@@ -83,7 +83,7 @@ class DataStore(
StreamStore,
ProfileStore,
PresenceStore,
- TransactionStore,
+ TransactionWorkerStore,
DirectoryStore,
KeyStore,
StateStore,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index d60010e942..074b077bef 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = LruCache(
- cache_name="client_ip_last_seen", keylen=4, max_size=50000
+ cache_name="client_ip_last_seen", max_size=50000
)
super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c9346de316..fd87ba71ab 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
- async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+ async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
@@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = LruCache(
- cache_name="device_id_exists", keylen=2, max_size=10000
+ cache_name="device_id_exists", max_size=10000
)
async def store_device(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 398d6b6acb..9ba5778a88 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
num_args=1,
)
async def _get_bare_e2e_cross_signing_keys_bulk(
- self, user_ids: List[str]
+ self, user_ids: Iterable[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
- user_ids: List[str],
+ user_ids: Iterable[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 2c823e09cf..6963bbf7f4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache = LruCache(
cache_name="*getEvent*",
- keylen=3,
max_size=hs.config.caches.event_cache_size,
)
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 0e86807834..6990f3ed1d 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore):
"""
keys = {}
- def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
+ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index db22fab23e..6a2baa7841 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
db_conn, "presence_stream", "stream_id"
)
+ self.hs = hs
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
@@ -96,6 +97,15 @@ class PresenceStore(SQLBaseStore):
)
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
+ # Delete old rows to stop the database from getting really big
+ sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
+
+ for states in batch_iter(presence_states, 50):
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", [s.user_id for s in states]
+ )
+ txn.execute(sql + clause, [stream_id] + list(args))
+
# Actually insert new rows
self.db_pool.simple_insert_many_txn(
txn,
@@ -116,15 +126,6 @@ class PresenceStore(SQLBaseStore):
],
)
- # Delete old rows to stop database from getting really big
- sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
-
- for states in batch_iter(presence_states, 50):
- clause, args = make_in_list_sql_clause(
- self.database_engine, "user_id", [s.user_id for s in states]
- )
- txn.execute(sql + clause, [stream_id] + list(args))
-
async def get_all_presence_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
@@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore):
return {row["user_id"]: UserPresenceState(**row) for row in rows}
+ async def should_user_receive_full_presence_with_token(
+ self,
+ user_id: str,
+ from_token: int,
+ ) -> bool:
+ """Check whether the given user should receive full presence using the stream token
+ they're updating from.
+
+ Args:
+ user_id: The ID of the user to check.
+ from_token: The stream token included in their /sync token.
+
+ Returns:
+ True if the user should have full presence sent to them, False otherwise.
+ """
+
+ def _should_user_receive_full_presence_with_token_txn(txn):
+ sql = """
+ SELECT 1 FROM users_to_send_full_presence_to
+ WHERE user_id = ?
+ AND presence_stream_id >= ?
+ """
+ txn.execute(sql, (user_id, from_token))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "should_user_receive_full_presence_with_token",
+ _should_user_receive_full_presence_with_token_txn,
+ )
+
+ async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+ """Adds to the list of users who should receive a full snapshot of presence
+ upon their next sync.
+
+ Args:
+ user_ids: An iterable of user IDs.
+ """
+ # Add user entries to the table, updating the presence_stream_id column if the user already
+ # exists in the table.
+ await self.db_pool.simple_upsert_many(
+ table="users_to_send_full_presence_to",
+ key_names=("user_id",),
+ key_values=[(user_id,) for user_id in user_ids],
+ value_names=("presence_stream_id",),
+ # We save the current presence stream ID token along with the user ID entry so
+ # that when a user /sync's, even if they sync multiple times across separate
+ # devices at different times, each device will receive full presence once: when
+ # the presence stream ID in their sync token is less than the one in the table
+ # for their user ID.
+ value_values=(
+ (self._presence_id_gen.get_current_token(),) for _ in user_ids
+ ),
+ desc="add_users_to_send_full_presence_to",
+ )
+
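A flow sketch tying the two methods together (the `store` reference, user ID and token value are hypothetical): one worker marks users, and the sync path later checks the stored stream ID against the presence position in the client's sync token.

await store.add_users_to_send_full_presence_to(["@alice:example.com"])

# On a subsequent /sync, full presence is included while the stored
# presence_stream_id is still >= the presence position in the sync token,
# so each of the user's devices catches up.
send_full = await store.should_user_receive_full_presence_with_token(
    "@alice:example.com", from_token=1234
)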
async def get_presence_for_all_users(
self,
include_offline: bool = True,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d36b18a0e9..77e2eb27db 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@@ -1077,7 +1078,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
- expiration_ts = self.rand.randrange(
+ expiration_ts = random.randrange(
expiration_ts - self._account_validity_startup_job_max_delta,
expiration_ts,
)
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 82335e7a9d..d211c423b2 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -16,13 +16,15 @@ import logging
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
-from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.descriptors import cached
db_binary_type = memoryview
@@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple(
"_TransactionRow", ("response_code", "response_json")
)
-SENTINEL = object()
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DestinationRetryTimings:
+ """The current destination retry timing info for a remote server."""
-class TransactionWorkerStore(SQLBaseStore):
+ # The first time we tried and failed to reach the remote server, in ms.
+ failure_ts: int
+
+ # The last time we tried and failed to reach the remote server, in ms.
+ retry_last_ts: int
+
+ # How long to wait, after the last failed attempt, before trying to
+ # reach the remote server again, in ms.
+ retry_interval: int
+
+
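A sketch of how these fields drive backoff decisions, mirroring the checks in synapse.util.retryutils (the timestamps are hypothetical):

timings = DestinationRetryTimings(
    failure_ts=1_620_000_000_000,
    retry_last_ts=1_620_000_600_000,
    retry_interval=60_000,
)

now_ms = 1_620_000_700_000
# The destination may be retried once the interval has elapsed since the
# last failed attempt.
may_retry = now_ms >= timings.retry_last_ts + timings.retry_interval  # True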
+class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore):
"_cleanup_transactions", _cleanup_transactions_txn
)
-
-class TransactionStore(TransactionWorkerStore):
- """A collection of queries for handling PDUs."""
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self._destination_retry_cache = ExpiringCache(
- cache_name="get_destination_retry_timings",
- clock=self._clock,
- expiry_ms=5 * 60 * 1000,
- )
-
async def get_received_txn_response(
self, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
@@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore):
desc="set_received_txn_response",
)
- async def get_destination_retry_timings(self, destination):
+ @cached(max_entries=10000)
+ async def get_destination_retry_timings(
+ self,
+ destination: str,
+ ) -> Optional[DestinationRetryTimings]:
"""Gets the current retry timings (if any) for a given destination.
Args:
@@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore):
Otherwise the current retry timings for the destination.
"""
- result = self._destination_retry_cache.get(destination, SENTINEL)
- if result is not SENTINEL:
- return result
-
result = await self.db_pool.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
)
- # We don't hugely care about race conditions between getting and
- # invalidating the cache, since we time out fairly quickly anyway.
- self._destination_retry_cache[destination] = result
return result
- def _get_destination_retry_timings(self, txn, destination):
+ def _get_destination_retry_timings(
+ self, txn, destination: str
+ ) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
+ retcols=("failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
if result and result["retry_last_ts"]:
- return result
+ return DestinationRetryTimings(**result)
else:
return None
@@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore):
retry_interval: how long until next retry in ms
"""
- self._destination_retry_cache.pop(destination, None)
if self.database_engine.can_native_upsert:
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
@@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore):
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+ self._invalidate_cache_and_stream(
+ txn, self.get_destination_retry_timings, (destination,)
+ )
+
def _set_destination_retry_timings_emulated(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
@@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore):
},
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_destination_retry_timings, (destination,)
+ )
+
async def store_destination_rooms_entries(
self,
destinations: Iterable[str],
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index acf6b2fb64..1ecdd40c38 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable
+
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
- async def are_users_erased(self, user_ids):
+ async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
"""
Checks which users in a list have requested erasure
Args:
- user_ids (iterable[str]): full user id to check
+ user_ids: full user ids to check
Returns:
- dict[str, bool]:
- for each user, whether the user has requested erasure.
+ for each user, whether the user has requested erasure.
"""
- # this serves the dual purpose of (a) making sure we can do len and
- # iterate it multiple times, and (b) avoiding duplicates.
- user_ids = tuple(set(user_ids))
-
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
diff --git a/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql
new file mode 100644
index 0000000000..07b0f53ecf
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a table that keeps track of a list of users who should, upon their next
+-- sync request, receive presence for all currently online users that they are
+-- "interested" in.
+
+-- The motivation for a DB table over an in-memory list is so that this list
+-- can be added to and retrieved from by any worker. Specifically, we don't
+-- want to duplicate work across multiple sync workers.
+
+CREATE TABLE IF NOT EXISTS users_to_send_full_presence_to(
+ -- The user ID to send full presence to.
+ user_id TEXT PRIMARY KEY,
+ -- A presence stream ID token - the current presence stream token when the row was last upserted.
+ -- If a user calls /sync and this token is part of the update they're to receive, we also include
+ -- full user presence in the response.
+ -- This allows multiple devices for a user to receive full presence whenever they next call /sync.
+ presence_stream_id BIGINT,
+ FOREIGN KEY (user_id)
+ REFERENCES users (name)
+);
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cfafba22c5..c9dce726cb 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -540,7 +540,7 @@ class StateGroupStorage:
state_filter: The state filter used to fetch state from the database.
Returns:
- A dict from (type, state_key) -> state_event
+ A dict from (type, state_key) -> state_event_id
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py
new file mode 100644
index 0000000000..44bbb7b1a8
--- /dev/null
+++ b/synapse/util/batching_queue.py
@@ -0,0 +1,153 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import (
+ Awaitable,
+ Callable,
+ Dict,
+ Generic,
+ Hashable,
+ List,
+ Set,
+ Tuple,
+ TypeVar,
+)
+
+from twisted.internet import defer
+
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
+
+logger = logging.getLogger(__name__)
+
+
+V = TypeVar("V")
+R = TypeVar("R")
+
+
+class BatchingQueue(Generic[V, R]):
+ """A queue that batches up work, calling the provided processing function
+ with all pending work (for a given key).
+
+ The provided processing function will only be called once at a time for each
+ key. It will be called the next reactor tick after `add_to_queue` has been
+ called, and will keep being called until the queue has been drained (for the
+ given key).
+
+ Note that the return value of `add_to_queue` will be the return value of the
+ processing function that processed the given item. This means that the
+ returned value will likely include data for other items that were in the
+ batch.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ clock: Clock,
+ process_batch_callback: Callable[[List[V]], Awaitable[R]],
+ ):
+ self._name = name
+ self._clock = clock
+
+ # The set of keys currently being processed.
+ self._processing_keys = set() # type: Set[Hashable]
+
+ # The currently pending batch of values by key, with a Deferred to call
+ # with the result of the corresponding `_process_batch_callback` call.
+ self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
+
+ # The function to call with batches of values.
+ self._process_batch_callback = process_batch_callback
+
+ LaterGauge(
+ "synapse_util_batching_queue_number_queued",
+ "The number of items waiting in the queue across all keys",
+ labels=("name",),
+ caller=lambda: sum(len(v) for v in self._next_values.values()),
+ )
+
+ LaterGauge(
+ "synapse_util_batching_queue_number_of_keys",
+ "The number of distinct keys that have items queued",
+ labels=("name",),
+ caller=lambda: len(self._next_values),
+ )
+
+ async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
+ """Adds the value to the queue with the given key, returning the result
+ of the processing function for the batch that included the given value.
+
+ The optional `key` argument allows sharding the queue by some key. The
+ queues for different keys are then processed in parallel, i.e. the
+ process batch function may be called concurrently for different keys,
+ with each call receiving batched values for a single key.
+ """
+
+ # First we create a defer and add it and the value to the list of
+ # pending items.
+ d = defer.Deferred()
+ self._next_values.setdefault(key, []).append((value, d))
+
+ # If we're not currently processing the key fire off a background
+ # process to start processing.
+ if key not in self._processing_keys:
+ run_as_background_process(self._name, self._process_queue, key)
+
+ return await make_deferred_yieldable(d)
+
+ async def _process_queue(self, key: Hashable) -> None:
+ """A background task to repeatedly pull things off the queue for the
+ given key and call the `self._process_batch_callback` with the values.
+ """
+
+ try:
+ if key in self._processing_keys:
+ return
+
+ self._processing_keys.add(key)
+
+ while True:
+ # We purposefully wait a reactor tick to allow us to batch
+ # together requests that we're about to receive. A common
+ # pattern is to call `add_to_queue` multiple times at once, and
+ # deferring to the next reactor tick allows us to batch all of
+ # those up.
+ await self._clock.sleep(0)
+
+ next_values = self._next_values.pop(key, [])
+ if not next_values:
+ # We've exhausted the queue.
+ break
+
+ try:
+ values = [value for value, _ in next_values]
+ results = await self._process_batch_callback(values)
+
+ for _, deferred in next_values:
+ with PreserveLoggingContext():
+ deferred.callback(results)
+
+ except Exception as e:
+ for _, deferred in next_values:
+ if deferred.called:
+ continue
+
+ with PreserveLoggingContext():
+ deferred.errback(e)
+
+ finally:
+ self._processing_keys.discard(key)
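A usage sketch (the `clock` comes from the homeserver; the callback, key and values are hypothetical):

async def _process(values: List[str]) -> int:
    # e.g. a single bulk database query covering the whole batch
    return len(values)

queue = BatchingQueue("example_queue", clock, process_batch_callback=_process)

# Concurrent callers using the same key may be coalesced into one
# `_process(["a", "b"])` call; both callers then receive 2 as their result.
result = await queue.add_to_queue("a", key="example_shard")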
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 484097a48a..371e7e4dd0 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -70,7 +70,6 @@ class DeferredCache(Generic[KT, VT]):
self,
name: str,
max_entries: int = 1000,
- keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
@@ -101,7 +100,6 @@ class DeferredCache(Generic[KT, VT]):
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
- keylen=keylen,
cache_name=name,
cache_type=cache_type,
size_callback=(lambda d: len(d) or 1) if iterable else None,
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index ac4a078b26..2ac24a2f25 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -270,7 +270,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
- keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any]
@@ -322,8 +321,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
- Given a list of keys it looks in the cache to find any hits, then passes
- the list of missing keys to the wrapped function.
+ Given an iterable of keys it looks in the cache to find any hits, then passes
+ the tuple of missing keys to the wrapped function.
Once wrapped, the function returns a Deferred which resolves to the list
of results.
@@ -437,7 +436,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
return f
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = list(missing)
+ # copy the missing set before sending it to the callee, to guard against
+ # modification.
+ args_to_call[self.list_name] = tuple(missing)
cached_defers.append(
defer.maybeDeferred(
@@ -522,14 +523,14 @@ def cachedList(
Used to do batch lookups for an already created cache. A single argument
is specified as a list that is iterated through to look up keys in the
- original cache. A new list consisting of the keys that weren't in the cache
- get passed to the original function, the result of which is stored in the
+ original cache. A new tuple consisting of the (deduplicated) keys that weren't in
+ the cache gets passed to the original function, the result of which is stored in the
cache.
Args:
cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use.
- list_name: The name of the argument that is the list to use to
+ list_name: The name of the argument that is the iterable to use to
do batch lookups in the cache.
num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
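
A sketch of the pairing the updated docstring describes, using the `@cached`/`@cachedList` decorators from this module (`db_fetch` is a hypothetical bulk-fetch helper):

```python
from synapse.util.caches.descriptors import cached, cachedList

class UserStore:
    @cached()
    async def get_user(self, user_id):
        ...  # single-item lookup backing the cache

    @cachedList(cached_method_name="get_user", list_name="user_ids")
    async def get_users(self, user_ids):
        # With this change, `user_ids` arrives as a deduplicated tuple of
        # the cache misses rather than a list, so the callee can no longer
        # mutate the caller's collection by accident.
        rows = await self.db_fetch(user_ids)  # hypothetical bulk fetch
        return {row.user_id: row for row in rows}
```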
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 1be675e014..54df407ff7 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -34,7 +34,7 @@ from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.util import caches
from synapse.util.caches import CacheMetric, register_cache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
try:
from pympler.asizeof import Asizer
@@ -160,7 +160,6 @@ class LruCache(Generic[KT, VT]):
self,
max_size: int,
cache_name: Optional[str] = None,
- keylen: int = 1,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
@@ -173,9 +172,6 @@ class LruCache(Generic[KT, VT]):
cache_name: The name of this cache, for the prometheus metrics. If unset,
no metrics will be reported on this cache.
- keylen: The length of the tuple used as the cache key. Ignored unless
- cache_type is `TreeCache`.
-
cache_type (type):
type of underlying cache to be used. Typically one of dict
or TreeCache.
@@ -403,7 +399,9 @@ class LruCache(Generic[KT, VT]):
popped = cache.pop(key)
if popped is None:
return
- for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
+ # for each deleted node, we now need to remove it from the linked list
+ # and run its callbacks.
+ for leaf in iterate_tree_cache_entry(popped):
delete_node(leaf)
@synchronized
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index eb4d98f683..73502a8b06 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,18 +1,43 @@
-from typing import Dict
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
SENTINEL = object()
+class TreeCacheNode(dict):
+ """The type of nodes in our tree.
+
+ Has its own type so we can distinguish it from real dicts that are stored at the
+ leaves.
+ """
+
+ pass
+
+
class TreeCache:
"""
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
efficiently.
Keys must be tuples.
+
+ The data structure is a chain of TreeCacheNodes:
+ root = {key_1: {key_2: _value}}
"""
def __init__(self):
self.size = 0
- self.root = {} # type: Dict
+ self.root = TreeCacheNode()
def __setitem__(self, key, value):
return self.set(key, value)
@@ -21,10 +46,23 @@ class TreeCache:
return self.get(key, SENTINEL) is not SENTINEL
def set(self, key, value):
+ if isinstance(value, TreeCacheNode):
+ # this would mean we couldn't tell where our tree ended and the value
+ # started.
+ raise ValueError("Cannot store TreeCacheNodes in a TreeCache")
+
node = self.root
for k in key[:-1]:
- node = node.setdefault(k, {})
- node[key[-1]] = _Entry(value)
+ next_node = node.get(k, SENTINEL)
+ if next_node is SENTINEL:
+ next_node = node[k] = TreeCacheNode()
+ elif not isinstance(next_node, TreeCacheNode):
+ # this suggests that the caller is not being consistent with its key
+ # length.
+ raise ValueError("value conflicts with an existing subtree")
+ node = next_node
+
+ node[key[-1]] = value
self.size += 1
def get(self, key, default=None):
@@ -33,25 +71,41 @@ class TreeCache:
node = node.get(k, None)
if node is None:
return default
- return node.get(key[-1], _Entry(default)).value
+ return node.get(key[-1], default)
def clear(self):
self.size = 0
- self.root = {}
+ self.root = TreeCacheNode()
def pop(self, key, default=None):
+ """Remove the given key, or subkey, from the cache
+
+ Args:
+ key: key or subkey to remove.
+ default: value to return if key is not found
+
+ Returns:
+ If the key is not found, 'default'. If the key is complete, the removed
+ value. If the key is partial, the TreeCacheNode corresponding to the part
+ of the tree that was removed.
+ """
+ # a list of the nodes we have touched on the way down the tree
nodes = []
node = self.root
for k in key[:-1]:
node = node.get(k, None)
- nodes.append(node) # don't add the root node
if node is None:
return default
+ if not isinstance(node, TreeCacheNode):
+ # we've gone off the end of the tree
+ raise ValueError("pop() key too long")
+ nodes.append(node) # don't add the root node
popped = node.pop(key[-1], SENTINEL)
if popped is SENTINEL:
return default
+ # working back up the tree, clear out any nodes that are now empty
node_and_keys = list(zip(nodes, key))
node_and_keys.reverse()
node_and_keys.append((self.root, None))
@@ -61,14 +115,15 @@ class TreeCache:
if n:
break
+ # found an empty node: remove it from its parent, and loop.
node_and_keys[i + 1][0].pop(k)
- popped, cnt = _strip_and_count_entires(popped)
+ cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
self.size -= cnt
return popped
def values(self):
- return list(iterate_tree_cache_entry(self.root))
+ return iterate_tree_cache_entry(self.root)
def __len__(self):
return self.size
@@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d):
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
- if isinstance(d, dict):
+ if isinstance(d, TreeCacheNode):
for value_d in d.values():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
- if isinstance(d, _Entry):
- yield d.value
- else:
- yield d
-
-
-class _Entry:
- __slots__ = ["value"]
-
- def __init__(self, value):
- self.value = value
-
-
-def _strip_and_count_entires(d):
- """Takes an _Entry or dict with leaves of _Entry's, and either returns the
- value or a dictionary with _Entry's replaced by their values.
-
- Also returns the count of _Entry's
- """
- if isinstance(d, dict):
- cnt = 0
- for key, value in d.items():
- v, n = _strip_and_count_entires(value)
- d[key] = v
- cnt += n
- return d, cnt
- else:
- return d.value, 1
+ yield d
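
The new node type and key-length checks are easiest to see in a small example (a sketch against the code above):

```python
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry

cache = TreeCache()
cache[("room1", "alice")] = 1
cache[("room1", "bob")] = 2
cache[("room2", "carol")] = 3

# Popping a partial key detaches the whole subtree and returns the
# TreeCacheNode; iterate_tree_cache_entry walks its leaves.
node = cache.pop(("room1",))
assert sorted(iterate_tree_cache_entry(node)) == [1, 2]
assert len(cache) == 1

# Over-long keys now fail loudly instead of corrupting the tree:
try:
    cache[("room2", "carol", "extra")] = 99
except ValueError:
    pass  # "value conflicts with an existing subtree"
```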
diff --git a/synapse/util/hash.py b/synapse/util/hash.py
index ba676e1762..7625ca8c2c 100644
--- a/synapse/util/hash.py
+++ b/synapse/util/hash.py
@@ -17,15 +17,15 @@ import hashlib
import unpaddedbase64
-def sha256_and_url_safe_base64(input_text):
+def sha256_and_url_safe_base64(input_text: str) -> str:
"""SHA256 hash an input string, encode the digest as url-safe base64, and
return
- :param input_text: string to hash
- :type input_text: str
+ Args:
+ input_text: string to hash
- :returns a sha256 hashed and url-safe base64 encoded digest
- :rtype: str
+ Returns:
+ A sha256 hashed and url-safe base64 encoded digest
"""
digest = hashlib.sha256(input_text.encode()).digest()
return unpaddedbase64.encode_base64(digest, urlsafe=True)
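
Usage is unchanged; only the annotations and docstring style move to the house format. For example:

```python
from synapse.util.hash import sha256_and_url_safe_base64

digest = sha256_and_url_safe_base64("hello")
# SHA-256 is 32 bytes, so the unpadded base64 digest is always 43 chars,
# and the URL-safe alphabet never contains "+" or "/".
assert len(digest) == 43 and "+" not in digest and "/" not in digest
```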
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index abfdc29832..886afa9d19 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -30,12 +30,12 @@ from typing import (
T = TypeVar("T")
-def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
+def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
"""batch an iterable up into tuples with a maximum size
Args:
- iterable (iterable): the iterable to slice
- size (int): the maximum batch size
+ iterable: the iterable to slice
+ size: the maximum batch size
Returns:
an iterator over the chunks
@@ -46,10 +46,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
return iter(lambda: tuple(islice(sourceiter, size)), ())
-ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
-
-
-def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
+def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
"""Split the given sequence into chunks of the given size
The last chunk may be shorter than the given size.
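
The corrected annotations match the existing runtime behaviour, which a quick sketch shows:

```python
from synapse.util.iterutils import batch_iter, chunk_seq

# batch_iter accepts any iterable and yields tuples of up to `size`
# items, hence Iterator[Tuple[T, ...]] rather than Iterator[Tuple[T]]:
assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]

# chunk_seq slices its input, so each chunk keeps the input's own type;
# dropping the covariant TypeVar loses that refinement for type checkers
# but changes nothing at runtime:
assert list(chunk_seq("abcdef", 4)) == ["abcd", "ef"]
```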
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 8acbe276e4..cbfbd097f9 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,6 +15,7 @@
import importlib
import importlib.util
import itertools
+from types import ModuleType
from typing import Any, Iterable, Tuple, Type
import jsonschema
@@ -44,8 +45,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = modulename.rsplit(".", 1)
- module = importlib.import_module(module)
+ module_name, clz = modulename.rsplit(".", 1)
+ module = importlib.import_module(module_name)
provider_class = getattr(module, clz)
# Load the module config. If None, pass an empty dictionary instead
@@ -69,11 +70,11 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
return provider_class, provider_config
-def load_python_module(location: str):
+def load_python_module(location: str) -> ModuleType:
"""Load a python module, and return a reference to its global namespace
Args:
- location (str): path to the module
+ location: path to the module
Returns:
python module object
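
A sketch of the call shape, with a hypothetical plugin path (the rename simply stops `module` from holding a `str` and then a `ModuleType`, which the new annotation formalises):

```python
from synapse.util.module_loader import load_module

provider_class, provider_config = load_module(
    {
        "module": "my_plugin.MyAuthProvider",  # hypothetical dotted path
        "config": {"level": 3},
    },
    ("password_providers",),  # where this block sits in the config, for errors
)
provider = provider_class(provider_config)
```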
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index bbbdebf264..1046224f15 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -17,19 +17,19 @@ import phonenumbers
from synapse.api.errors import SynapseError
-def phone_number_to_msisdn(country, number):
+def phone_number_to_msisdn(country: str, number: str) -> str:
"""
Takes an ISO-3166-1 2 letter country code and phone number and
returns an msisdn representing the canonical version of that
phone number.
Args:
- country (str): ISO-3166-1 2 letter country code
- number (str): Phone number in a national or international format
+ country: ISO-3166-1 2 letter country code
+ number: Phone number in a national or international format
Returns:
- (str) The canonical form of the phone number, as an msisdn
+ The canonical form of the phone number, as an msisdn
Raises:
- SynapseError if the number could not be parsed.
+ SynapseError if the number could not be parsed.
"""
try:
phoneNumber = phonenumbers.parse(number, country)
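
Behaviour is unchanged by the annotation move; for illustration (the numbers below are reserved/fictional examples):

```python
from synapse.util.msisdn import phone_number_to_msisdn

# A national-format number gains its country prefix:
assert phone_number_to_msisdn("US", "(202) 555-0104") == "12025550104"

# An already-international number keeps the prefix it declares:
assert phone_number_to_msisdn("GB", "+33 1 70 18 99 00") == "33170189900"
```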
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index f9c370a814..129b47cd49 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -82,11 +82,9 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
retry_timings = await store.get_destination_retry_timings(destination)
if retry_timings:
- failure_ts = retry_timings["failure_ts"]
- retry_last_ts, retry_interval = (
- retry_timings["retry_last_ts"],
- retry_timings["retry_interval"],
- )
+ failure_ts = retry_timings.failure_ts
+ retry_last_ts = retry_timings.retry_last_ts
+ retry_interval = retry_timings.retry_interval
now = int(clock.time_msec())
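
The switch from dict indexing to attribute access implies that `get_destination_retry_timings` now returns a structured value. A minimal sketch of the shape this hunk relies on (the real class elsewhere in the tree may carry more than this):

```python
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
    """Sketch of the return type the attribute access above assumes."""

    failure_ts: int      # ms timestamp of the first failure in this run
    retry_last_ts: int   # ms timestamp of the most recent attempt
    retry_interval: int  # ms to wait before retrying the destination
```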
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 4f25cd1d26..f029432191 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
-import random
import re
+import secrets
import string
from collections.abc import Iterable
from typing import Optional, Tuple
@@ -35,26 +35,27 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
#
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
-# random_string and random_string_with_symbols are used for a range of things,
-# some cryptographically important, some less so. We use SystemRandom to make sure
-# we get cryptographically-secure randoms.
-rand = random.SystemRandom()
-
def random_string(length: int) -> str:
- return "".join(rand.choice(string.ascii_letters) for _ in range(length))
+ """Generate a cryptographically secure string of random letters.
+
+ Drawn from the characters: `a-z` and `A-Z`
+ """
+ return "".join(secrets.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length: int) -> str:
- return "".join(rand.choice(_string_with_symbols) for _ in range(length))
+ """Generate a cryptographically secure string of random letters/numbers/symbols.
+
+ Drawn from the characters: `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`
+ """
+ return "".join(secrets.choice(_string_with_symbols) for _ in range(length))
def is_ascii(s: bytes) -> bool:
try:
s.decode("ascii").encode("ascii")
- except UnicodeDecodeError:
- return False
- except UnicodeEncodeError:
+ except UnicodeError:
return False
return True
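
Two small notes on this hunk: `secrets.choice` draws from the OS CSPRNG (as `SystemRandom` did, minus the module-level state), and `UnicodeError` is the shared base class of both exceptions the old code caught separately. For example:

```python
import string
import secrets

# Equivalent to random_string(16): sixteen CSPRNG-chosen letters.
token = "".join(secrets.choice(string.ascii_letters) for _ in range(16))

# One except clause now covers both directions of codec failure:
assert issubclass(UnicodeDecodeError, UnicodeError)
assert issubclass(UnicodeEncodeError, UnicodeError)
```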