diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index c6f8733e60..61da585ad0 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -26,13 +26,14 @@ from synapse.api.errors import (
Codes,
InvalidClientTokenError,
MissingClientTokenError,
+ UnstableSpecAuthError,
)
from synapse.appservice import ApplicationService
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import (
- active_span,
+from synapse.logging.tracing import (
force_tracing,
+ get_active_span,
start_active_span,
trace,
)
@@ -111,8 +112,11 @@ class Auth:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return membership, member_event_id
-
- raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
+ raise UnstableSpecAuthError(
+ 403,
+ "User %s not in room %s" % (user_id, room_id),
+ errcode=Codes.NOT_JOINED,
+ )
async def get_user_by_req(
self,
@@ -137,7 +141,7 @@ class Auth:
is invalid.
AuthError if access is denied for the user in the access token
"""
- parent_span = active_span()
+ parent_span = get_active_span()
with start_active_span("get_user_by_req"):
requester = await self._wrapped_get_user_by_req(
request, allow_guest, allow_expired
@@ -147,19 +151,18 @@ class Auth:
if requester.authenticated_entity in self._force_tracing_for_users:
              # request tracing is enabled for this user, so we need to force
# tracing on for the parent span (which will be the servlet span).
- #
+ force_tracing(parent_span)
# It's too late for the get_user_by_req span to inherit the setting,
# so we also force it on for that.
force_tracing()
- force_tracing(parent_span)
- parent_span.set_tag(
+ parent_span.set_attribute(
"authenticated_entity", requester.authenticated_entity
)
- parent_span.set_tag("user_id", requester.user.to_string())
+ parent_span.set_attribute("user_id", requester.user.to_string())
if requester.device_id is not None:
- parent_span.set_tag("device_id", requester.device_id)
+ parent_span.set_attribute("device_id", requester.device_id)
if requester.app_service is not None:
- parent_span.set_tag("appservice_id", requester.app_service.id)
+ parent_span.set_attribute("appservice_id", requester.app_service.id)
return requester
async def _wrapped_get_user_by_req(
@@ -170,7 +173,7 @@ class Auth:
) -> Requester:
"""Helper for get_user_by_req
- Once get_user_by_req has set up the opentracing span, this does the actual work.
+ Once get_user_by_req has set up the tracing span, this does the actual work.
"""
try:
ip_addr = request.getClientAddress().host
@@ -606,8 +609,9 @@ class Auth:
== HistoryVisibility.WORLD_READABLE
):
return Membership.JOIN, None
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
"User %s not in room %s, and room previews are disabled"
% (user_id, room_id),
+ errcode=Codes.NOT_JOINED,
)
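
The move from `synapse.logging.opentracing` to `synapse.logging.tracing` swaps the OpenTracing API for OpenTelemetry: `active_span` becomes `get_active_span` and `set_tag` becomes `set_attribute`. A rough sketch of what the annotated parent span boils down to in raw OpenTelemetry terms (assuming the `opentelemetry-api` package; Synapse's wrappers presumably add no-op handling for when tracing is disabled):

    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    # Hypothetical handler mirroring get_user_by_req above: grab the parent
    # (servlet) span before opening a child span, then annotate the parent
    # with the authenticated user's details.
    parent_span = trace.get_current_span()
    with tracer.start_as_current_span("get_user_by_req"):
        parent_span.set_attribute("authenticated_entity", "@alice:example.com")
        parent_span.set_attribute("user_id", "@alice:example.com")
        parent_span.set_attribute("device_id", "ABCDEFGHIJ")
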
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 2653764119..fc04e4d4bd 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -193,6 +193,9 @@ class LimitBlockingTypes:
class EventContentFields:
"""Fields found in events' content, regardless of type."""
+ # Synapse internal content field for tracing
+ TRACING_CONTEXT: Final = "org.matrix.tracing_context"
+
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
LABELS: Final = "org.matrix.labels"
@@ -268,4 +271,4 @@ class PublicRoomsFilterFields:
"""
GENERIC_SEARCH_TERM: Final = "generic_search_term"
- ROOM_TYPES: Final = "org.matrix.msc3827.room_types"
+ ROOM_TYPES: Final = "room_types"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 1c74e131f2..e6dea89c6d 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -26,6 +26,7 @@ from twisted.web import http
from synapse.util import json_decoder
if typing.TYPE_CHECKING:
+ from synapse.config.homeserver import HomeServerConfig
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -80,6 +81,12 @@ class Codes(str, Enum):
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
+ # Part of MSC3848
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3848
+ ALREADY_JOINED = "ORG.MATRIX.MSC3848.ALREADY_JOINED"
+ NOT_JOINED = "ORG.MATRIX.MSC3848.NOT_JOINED"
+ INSUFFICIENT_POWER = "ORG.MATRIX.MSC3848.INSUFFICIENT_POWER"
+
# The account has been suspended on the server.
# By opposition to `USER_DEACTIVATED`, this is a reversible measure
# that can possibly be appealed and reverted.
@@ -167,7 +174,7 @@ class SynapseError(CodeMessageException):
else:
self._additional_fields = dict(additional_fields)
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields)
@@ -213,7 +220,7 @@ class ConsentNotGivenError(SynapseError):
)
self._consent_uri = consent_uri
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
@@ -307,6 +314,37 @@ class AuthError(SynapseError):
super().__init__(code, msg, errcode, additional_fields)
+class UnstableSpecAuthError(AuthError):
+ """An error raised when a new error code is being proposed to replace a previous one.
+ This error will return a "org.matrix.unstable.errcode" property with the new error code,
+ with the previous error code still being defined in the "errcode" property.
+
+ This error will include `org.matrix.msc3848.unstable.errcode` in the C-S error body.
+ """
+
+ def __init__(
+ self,
+ code: int,
+ msg: str,
+ errcode: str,
+ previous_errcode: str = Codes.FORBIDDEN,
+ additional_fields: Optional[dict] = None,
+ ):
+ self.previous_errcode = previous_errcode
+ super().__init__(code, msg, errcode, additional_fields)
+
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+ fields = {}
+ if config is not None and config.experimental.msc3848_enabled:
+ fields["org.matrix.msc3848.unstable.errcode"] = self.errcode
+ return cs_error(
+ self.msg,
+ self.previous_errcode,
+ **fields,
+ **self._additional_fields,
+ )
+
+
class InvalidClientCredentialsError(SynapseError):
"""An error raised when there was a problem with the authorisation credentials
in a client request.
@@ -338,8 +376,8 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout
- def error_dict(self) -> "JsonDict":
- d = super().error_dict()
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+ d = super().error_dict(config)
d["soft_logout"] = self._soft_logout
return d
@@ -362,7 +400,7 @@ class ResourceLimitError(SynapseError):
self.limit_type = limit_type
super().__init__(code, msg, errcode=errcode)
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(
self.msg,
self.errcode,
@@ -397,7 +435,7 @@ class InvalidCaptchaError(SynapseError):
super().__init__(code, msg, errcode)
self.error_url = error_url
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, error_url=self.error_url)
@@ -414,7 +452,7 @@ class LimitExceededError(SynapseError):
super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@@ -429,7 +467,7 @@ class RoomKeysVersionError(SynapseError):
super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION)
self.current_version = current_version
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, current_version=self.current_version)
@@ -469,7 +507,7 @@ class IncompatibleRoomVersionError(SynapseError):
self._room_version = room_version
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, room_version=self._room_version)
@@ -515,7 +553,7 @@ class UnredactedContentDeletedError(SynapseError):
)
self.content_keep_ms = content_keep_ms
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
extra = {}
if self.content_keep_ms is not None:
extra = {"fi.mau.msc2815.content_keep_ms": self.content_keep_ms}
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 923891ae0d..5763352b29 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -62,7 +62,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
-from synapse.logging.opentracing import init_tracer
+from synapse.logging.tracing import init_tracer
from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6bafa7d3f3..745e704141 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -44,6 +44,7 @@ from synapse.app._base import (
register_start,
)
from synapse.config._base import ConfigError, format_config_error
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
@@ -201,7 +202,7 @@ class SynapseHomeServer(HomeServer):
}
)
- if self.config.email.can_verify_email:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
from synapse.rest.synapse.client.password_reset import (
PasswordResetSubmitTokenResource,
)
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 3ead80d985..7765c5b454 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -18,6 +18,7 @@
import email.utils
import logging
import os
+from enum import Enum
from typing import Any
import attr
@@ -85,14 +86,19 @@ class EmailConfig(Config):
if email_config is None:
email_config = {}
+ self.force_tls = email_config.get("force_tls", False)
self.email_smtp_host = email_config.get("smtp_host", "localhost")
- self.email_smtp_port = email_config.get("smtp_port", 25)
+ self.email_smtp_port = email_config.get(
+ "smtp_port", 465 if self.force_tls else 25
+ )
self.email_smtp_user = email_config.get("smtp_user", None)
self.email_smtp_pass = email_config.get("smtp_pass", None)
self.require_transport_security = email_config.get(
"require_transport_security", False
)
self.enable_smtp_tls = email_config.get("enable_tls", True)
+ if self.force_tls and not self.enable_smtp_tls:
+ raise ConfigError("email.force_tls requires email.enable_tls to be true")
if self.require_transport_security and not self.enable_smtp_tls:
raise ConfigError(
"email.require_transport_security requires email.enable_tls to be true"
@@ -130,22 +136,40 @@ class EmailConfig(Config):
self.email_enable_notifs = email_config.get("enable_notifs", False)
+ self.threepid_behaviour_email = (
+ # Have Synapse handle the email sending if account_threepid_delegates.email
+ # is not defined
+ # msisdn is currently always remote while Synapse does not support any method of
+ # sending SMS messages
+ ThreepidBehaviour.REMOTE
+ if self.root.registration.account_threepid_delegate_email
+ else ThreepidBehaviour.LOCAL
+ )
+
if config.get("trust_identity_server_for_password_resets"):
raise ConfigError(
- 'The config option "trust_identity_server_for_password_resets" '
- "is no longer supported. Please remove it from the config file."
+                'The config option "trust_identity_server_for_password_resets" has been removed. '
+ "Please consult the configuration manual at docs/usage/configuration/config_documentation.md for "
+ "details and update your config file."
)
- # If we have email config settings, assume that we can verify ownership of
- # email addresses.
- self.can_verify_email = email_config != {}
+ self.local_threepid_handling_disabled_due_to_email_config = False
+ if (
+ self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ and email_config == {}
+ ):
+            # We cannot warn the user this has happened here.
+            # Instead, do so when a user attempts to reset their password.
+ self.local_threepid_handling_disabled_due_to_email_config = True
+
+ self.threepid_behaviour_email = ThreepidBehaviour.OFF
# Get lifetime of a validation token in milliseconds
self.email_validation_token_lifetime = self.parse_duration(
email_config.get("validation_token_lifetime", "1h")
)
- if self.can_verify_email:
+ if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
missing = []
if not self.email_notif_from:
missing.append("email.notif_from")
@@ -336,3 +360,18 @@ class EmailConfig(Config):
"Config option email.invite_client_location must be a http or https URL",
path=("email", "invite_client_location"),
)
+
+
+class ThreepidBehaviour(Enum):
+ """
+    Enum to define the behaviour of Synapse with regard to when it contacts an identity
+    server for 3pid registration and password resets.
+
+ REMOTE = use an external server to send tokens
+ LOCAL = send tokens ourselves
+ OFF = disable registration via 3pid and password resets
+ """
+
+ REMOTE = "remote"
+ LOCAL = "local"
+ OFF = "off"
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index ee443cea00..c2ecd977cd 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -88,5 +88,5 @@ class ExperimentalConfig(Config):
# MSC3715: dir param on /relations.
self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
- # MSC3827: Filtering of /publicRooms by room type
- self.msc3827_enabled: bool = experimental.get("msc3827_enabled", False)
+ # MSC3848: Introduce errcodes for specific event sending failures
+ self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 685a0423c5..01fb0331bc 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
+import logging
from typing import Any, Optional
from synapse.api.constants import RoomCreationPreset
@@ -20,11 +21,15 @@ from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols, strtobool
-NO_EMAIL_DELEGATE_ERROR = """\
-Delegation of email verification to an identity server is no longer supported. To
+logger = logging.getLogger(__name__)
+
+LEGACY_EMAIL_DELEGATE_WARNING = """\
+Delegation of email verification to an identity server is now deprecated. To
continue to allow users to add email addresses to their accounts, and use them for
password resets, configure Synapse with an SMTP server via the `email` setting, and
remove `account_threepid_delegates.email`.
+
+This will be an error in a future version.
"""
@@ -59,8 +64,9 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
if "email" in account_threepid_delegates:
- raise ConfigError(NO_EMAIL_DELEGATE_ERROR)
- # self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ logger.warning(LEGACY_EMAIL_DELEGATE_WARNING)
+
+ self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index c19270c6c5..d67498f50d 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -24,41 +24,50 @@ class TracerConfig(Config):
section = "tracing"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
- opentracing_config = config.get("opentracing")
- if opentracing_config is None:
- opentracing_config = {}
+ tracing_config = config.get("tracing")
+ if tracing_config is None:
+ tracing_config = {}
- self.opentracer_enabled = opentracing_config.get("enabled", False)
+ self.tracing_enabled = tracing_config.get("enabled", False)
- self.jaeger_config = opentracing_config.get(
- "jaeger_config",
- {"sampler": {"type": "const", "param": 1}, "logging": False},
+ self.jaeger_exporter_config = tracing_config.get(
+ "jaeger_exporter_config",
+ {},
)
self.force_tracing_for_users: Set[str] = set()
- if not self.opentracer_enabled:
+ if not self.tracing_enabled:
return
- check_requirements("opentracing")
+ check_requirements("opentelemetry")
# The tracer is enabled so sanitize the config
- self.opentracer_whitelist: List[str] = opentracing_config.get(
+        # Default to always sample. Range: [0.0, 1.0]
+ self.sample_rate: float = float(tracing_config.get("sample_rate", 1))
+ if self.sample_rate < 0.0 or self.sample_rate > 1.0:
+ raise ConfigError(
+ "Tracing sample_rate must be in range [0.0, 1.0].",
+ ("tracing", "sample_rate"),
+ )
+
+ self.homeserver_whitelist: List[str] = tracing_config.get(
"homeserver_whitelist", []
)
- if not isinstance(self.opentracer_whitelist, list):
- raise ConfigError("Tracer homeserver_whitelist config is malformed")
-
- force_tracing_for_users = opentracing_config.get("force_tracing_for_users", [])
- if not isinstance(force_tracing_for_users, list):
+ if not isinstance(self.homeserver_whitelist, list):
raise ConfigError(
- "Expected a list", ("opentracing", "force_tracing_for_users")
+ "Tracing homeserver_whitelist config is malformed",
+ ("tracing", "homeserver_whitelist"),
)
+
+ force_tracing_for_users = tracing_config.get("force_tracing_for_users", [])
+ if not isinstance(force_tracing_for_users, list):
+ raise ConfigError("Expected a list", ("tracing", "force_tracing_for_users"))
for i, u in enumerate(force_tracing_for_users):
if not isinstance(u, str):
raise ConfigError(
"Expected a string",
- ("opentracing", "force_tracing_for_users", f"index {i}"),
+ ("tracing", "force_tracing_for_users", f"index {i}"),
)
self.force_tracing_for_users.add(u)
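
The new `sample_rate` option is a plain probability. A minimal sketch of the validation above (`read_sample_rate` is a hypothetical standalone helper; the accepted value would presumably feed a ratio-based OpenTelemetry sampler such as `TraceIdRatioBased`):

    def read_sample_rate(tracing_config: dict) -> float:
        # Default to sampling everything; reject values outside [0.0, 1.0].
        sample_rate = float(tracing_config.get("sample_rate", 1))
        if not 0.0 <= sample_rate <= 1.0:
            raise ValueError("Tracing sample_rate must be in range [0.0, 1.0].")
        return sample_rate

    assert read_sample_rate({}) == 1.0
    assert read_sample_rate({"sample_rate": 0.25}) == 0.25
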
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 965cb265da..389b0c5d53 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -30,7 +30,13 @@ from synapse.api.constants import (
JoinRules,
Membership,
)
-from synapse.api.errors import AuthError, EventSizeError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ EventSizeError,
+ SynapseError,
+ UnstableSpecAuthError,
+)
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -291,7 +297,11 @@ def check_state_dependent_auth_rules(
invite_level = get_named_level(auth_dict, "invite", 0)
if user_level < invite_level:
- raise AuthError(403, "You don't have permission to invite users")
+ raise UnstableSpecAuthError(
+ 403,
+ "You don't have permission to invite users",
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
else:
logger.debug("Allowing! %s", event)
return
@@ -474,7 +484,11 @@ def _is_membership_change_allowed(
return
if not caller_in_room: # caller isn't joined
- raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
+ raise UnstableSpecAuthError(
+ 403,
+ "%s not in room %s." % (event.user_id, event.room_id),
+ errcode=Codes.NOT_JOINED,
+ )
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
@@ -484,10 +498,18 @@ def _is_membership_change_allowed(
if target_banned:
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
elif target_in_room: # the target is already in the room.
- raise AuthError(403, "%s is already in the room." % target_user_id)
+ raise UnstableSpecAuthError(
+ 403,
+ "%s is already in the room." % target_user_id,
+ errcode=Codes.ALREADY_JOINED,
+ )
else:
if user_level < invite_level:
- raise AuthError(403, "You don't have permission to invite users")
+ raise UnstableSpecAuthError(
+ 403,
+ "You don't have permission to invite users",
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and:
# * They are not banned.
@@ -549,15 +571,27 @@ def _is_membership_change_allowed(
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
- raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
+ raise UnstableSpecAuthError(
+ 403,
+ "You cannot unban user %s." % (target_user_id,),
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif target_user_id != event.user_id:
kick_level = get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
- raise AuthError(403, "You cannot kick user %s." % target_user_id)
+ raise UnstableSpecAuthError(
+ 403,
+ "You cannot kick user %s." % target_user_id,
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
- raise AuthError(403, "You don't have permission to ban")
+ raise UnstableSpecAuthError(
+ 403,
+ "You don't have permission to ban",
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
if join_rule != JoinRules.KNOCK and (
not room_version.msc3787_knock_restricted_join_rule
@@ -567,7 +601,11 @@ def _is_membership_change_allowed(
elif target_user_id != event.user_id:
raise AuthError(403, "You cannot knock for other users")
elif target_in_room:
- raise AuthError(403, "You cannot knock on a room you are already in")
+ raise UnstableSpecAuthError(
+ 403,
+ "You cannot knock on a room you are already in",
+ errcode=Codes.ALREADY_JOINED,
+ )
elif caller_invited:
raise AuthError(403, "You are already invited to this room")
elif target_banned:
@@ -638,10 +676,11 @@ def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> b
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
"You don't have permission to post that to the room. "
+ "user_level (%d) < send_level (%d)" % (user_level, send_level),
+ errcode=Codes.INSUFFICIENT_POWER,
)
# Check state_key
@@ -716,9 +755,10 @@ def check_historical(
historical_level = get_named_level(auth_events, "historical", 100)
if user_level < historical_level:
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
            'You don\'t have permission to send historical related events ("insertion", "batch", and "marker")',
+ errcode=Codes.INSUFFICIENT_POWER,
)
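
All of these raise sites share one shape: the power-level rule itself is untouched, and only the exception type, and hence the MSC3848 errcode on the wire, becomes more specific. A self-contained sketch of the `_can_send_event` branch, with `PermissionError` as a stand-in for `UnstableSpecAuthError(403, ...)` so the snippet runs without Synapse:

    def check_send_level(user_level: int, send_level: int) -> None:
        # Same comparison as _can_send_event above; only the error type (and
        # the errcode it maps to on the wire) changed in the patch.
        if user_level < send_level:
            raise PermissionError(
                "You don't have permission to post that to the room. "
                + "user_level (%d) < send_level (%d)" % (user_level, send_level)
            )

    check_send_level(user_level=50, send_level=0)  # passes silently
    # check_send_level(user_level=0, send_level=50)  # would raise
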
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c1f96328b4..54ffbd8170 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -405,9 +405,9 @@ class FederationClient(FederationBase):
# Prime the cache
self._get_pdu_cache[event.event_id] = event
- # FIXME: We should add a `break` here to avoid calling every
- # destination after we already found a PDU (will follow-up
- # in a separate PR)
+ # Now that we have an event, we can break out of this
+ # loop and stop asking other destinations.
+ break
except SynapseError as e:
logger.info(
@@ -727,6 +727,12 @@ class FederationClient(FederationBase):
if failover_errcodes is None:
failover_errcodes = ()
+ if not destinations:
+ # Give a bit of a clearer message if no servers were specified at all.
+ raise SynapseError(
+ 502, f"Failed to {description} via any server: No servers specified."
+ )
+
for destination in destinations:
if destination == self.server_name:
continue
@@ -776,7 +782,7 @@ class FederationClient(FederationBase):
"Failed to %s via %s", description, destination, exc_info=True
)
- raise SynapseError(502, "Failed to %s via any server" % (description,))
+ raise SynapseError(502, f"Failed to {description} via any server")
async def make_membership_event(
self,
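
The two changes to the destination-failover loop distil to the same control flow. A hypothetical standalone version (`try_destinations` and `fetch` are illustrative names; `RuntimeError` stands in for `SynapseError(502, ...)`):

    from typing import Callable, Iterable, Optional, TypeVar

    T = TypeVar("T")

    def try_destinations(
        destinations: Iterable[str],
        fetch: Callable[[str], Optional[T]],
        description: str,
    ) -> T:
        # Fail fast with a clearer message when no servers were given at all,
        # and stop asking further destinations once one yields a result.
        destinations = list(destinations)
        if not destinations:
            raise RuntimeError(
                f"Failed to {description} via any server: No servers specified."
            )
        for destination in destinations:
            result = fetch(destination)
            if result is not None:
                return result
        raise RuntimeError(f"Failed to {description} via any server")
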
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ae550d3f4d..c9cd02ebeb 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -61,7 +61,7 @@ from synapse.logging.context import (
nested_logging_context,
run_in_background,
)
-from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
+from synapse.logging.tracing import log_kv, start_active_span_from_edu, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
@@ -469,7 +469,7 @@ class FederationServer(FederationBase):
)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
- pdu_results[event_id] = e.error_dict()
+ pdu_results[event_id] = e.error_dict(self.hs.config)
return
for pdu in pdus_by_room[room_id]:
@@ -1399,7 +1399,7 @@ class FederationHandlerRegistry:
# Check if we have a handler on this instance
handler = self.edu_handlers.get(edu_type)
if handler:
- with start_active_span_from_edu(content, "handle_edu"):
+ with start_active_span_from_edu("handle_edu", edu_content=content):
try:
await handler(origin, content)
except SynapseError as e:
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 41d8b937af..72bc935452 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -32,7 +32,7 @@ from synapse.events import EventBase
from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state
from synapse.logging import issue9533_logger
-from synapse.logging.opentracing import SynapseTags, set_tag
+from synapse.logging.tracing import SynapseTags, set_attribute
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ReadReceipt
@@ -596,7 +596,7 @@ class PerDestinationQueue:
if not message_id:
continue
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+ set_attribute(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
edus = [
Edu(
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 75081810fd..3f2c8bcfa1 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -21,11 +21,13 @@ from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
-from synapse.logging.opentracing import (
+from synapse.logging.tracing import (
+ Link,
+ StatusCode,
extract_text_map,
- set_tag,
- start_active_span_follows_from,
- tags,
+ get_span_context_from_context,
+ set_status,
+ start_active_span,
whitelisted_homeserver,
)
from synapse.types import JsonDict
@@ -79,7 +81,7 @@ class TransactionManager:
edus: List of EDUs to send
"""
- # Make a transaction-sending opentracing span. This span follows on from
+ # Make a transaction-sending tracing span. This span follows on from
# all the edus in that transaction. This needs to be done since there is
# no active span here, so if the edus were not received by the remote the
# span would have no causality and it would be forgotten.
@@ -88,13 +90,20 @@ class TransactionManager:
keep_destination = whitelisted_homeserver(destination)
for edu in edus:
- context = edu.get_context()
- if context:
- span_contexts.append(extract_text_map(json_decoder.decode(context)))
+ tracing_context_json = edu.get_tracing_context_json()
+ if tracing_context_json:
+ context = extract_text_map(json_decoder.decode(tracing_context_json))
+ if context:
+ span_context = get_span_context_from_context(context)
+ if span_context:
+ span_contexts.append(span_context)
if keep_destination:
- edu.strip_context()
+ edu.strip_tracing_context()
- with start_active_span_follows_from("send_transaction", span_contexts):
+ with start_active_span(
+ "send_transaction",
+ links=[Link(span_context) for span_context in span_contexts],
+ ):
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
@@ -166,7 +175,7 @@ class TransactionManager:
except HttpResponseException as e:
code = e.code
- set_tag(tags.ERROR, True)
+ set_status(StatusCode.ERROR, e)
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
raise
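
Under OpenTelemetry there is no follows-from reference and no `tags.ERROR`: the EDUs' remote span contexts become links on the transaction span, and failures are recorded as a span status. A sketch in raw OpenTelemetry terms (assuming `opentelemetry-api`; `do_send` is a placeholder for the actual transaction send):

    from opentelemetry import trace
    from opentelemetry.trace import Link, Status, StatusCode

    tracer = trace.get_tracer(__name__)

    def send_transaction(edu_span_contexts, do_send):
        # A span can only have one parent, so the remote EDU contexts are
        # attached as links rather than follows-from references.
        with tracer.start_as_current_span(
            "send_transaction",
            links=[Link(ctx) for ctx in edu_span_contexts],
        ) as span:
            try:
                do_send()
            except Exception:
                # Mirrors set_status(StatusCode.ERROR, e) in the patch above.
                span.set_status(Status(StatusCode.ERROR))
                raise
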
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index bb0f8d6b7b..9425106a3f 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -15,7 +15,6 @@
import functools
import logging
import re
-import time
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
@@ -25,12 +24,15 @@ from synapse.http.server import HttpServer, ServletCallback, is_method_cancellab
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import run_in_background
-from synapse.logging.opentracing import (
- active_span,
- set_tag,
- span_context_from_request,
+from synapse.logging.tracing import (
+ Link,
+ context_from_request,
+ create_non_recording_span,
+ get_active_span,
+ set_attribute,
start_active_span,
- start_active_span_follows_from,
+ start_span,
+ use_span,
whitelisted_homeserver,
)
from synapse.types import JsonDict
@@ -308,60 +310,70 @@ class BaseFederationServlet:
logger.warning("authenticate_request failed: %s", e)
raise
- # update the active opentracing span with the authenticated entity
- set_tag("authenticated_entity", str(origin))
+ # update the active tracing span with the authenticated entity
+ set_attribute("authenticated_entity", str(origin))
# if the origin is authenticated and whitelisted, use its span context
# as the parent.
- context = None
+ origin_context = None
if origin and whitelisted_homeserver(origin):
- context = span_context_from_request(request)
-
- if context:
- servlet_span = active_span()
- # a scope which uses the origin's context as a parent
- processing_start_time = time.time()
- scope = start_active_span_follows_from(
+ origin_context = context_from_request(request)
+
+ if origin_context:
+ local_servlet_span = get_active_span()
+ # Create a span which uses the `origin_context` as a parent
+ # so we can see how the incoming payload was processed while
+ # we're looking at the outgoing trace. Since the parent is set
+ # to a remote span (from the origin), it won't show up in the
+ # local trace which is why we create another span below for the
+ # local trace. A span can only have one parent so we have to
+ # create two separate ones.
+ remote_parent_span = start_span(
"incoming-federation-request",
- child_of=context,
- contexts=(servlet_span,),
- start_time=processing_start_time,
+ context=origin_context,
+ # Cross-link back to the local trace so we can jump
+ # to the incoming side from the remote origin trace.
+ links=[Link(local_servlet_span.get_span_context())]
+ if local_servlet_span
+ else None,
)
+ # Create a local span to appear in the local trace
+ local_parent_span_cm = start_active_span(
+ "process-federation-request",
+ # Cross-link back to the remote outgoing trace so we can
+ # jump over there.
+ links=[Link(remote_parent_span.get_span_context())],
+ )
else:
- # just use our context as a parent
- scope = start_active_span(
- "incoming-federation-request",
+ # Otherwise just use our local active servlet context as a parent
+ local_parent_span_cm = start_active_span(
+ "process-federation-request",
)
- try:
- with scope:
- if origin and self.RATELIMIT:
- with ratelimiter.ratelimit(origin) as d:
- await d
- if request._disconnected:
- logger.warning(
- "client disconnected before we started processing "
- "request"
- )
- return None
- response = await func(
- origin, content, request.args, *args, **kwargs
+ # Don't need to record anything for the remote because no remote
+ # trace context given.
+ remote_parent_span = create_non_recording_span()
+
+ remote_parent_span_cm = use_span(remote_parent_span, end_on_exit=True)
+
+ with remote_parent_span_cm, local_parent_span_cm:
+ if origin and self.RATELIMIT:
+ with ratelimiter.ratelimit(origin) as d:
+ await d
+ if request._disconnected:
+ logger.warning(
+ "client disconnected before we started processing "
+ "request"
)
- else:
+ return None
response = await func(
origin, content, request.args, *args, **kwargs
)
- finally:
- # if we used the origin's context as the parent, add a new span using
- # the servlet span as a parent, so that we have a link
- if context:
- scope2 = start_active_span_follows_from(
- "process-federation_request",
- contexts=(scope.span,),
- start_time=processing_start_time,
+ else:
+ response = await func(
+ origin, content, request.args, *args, **kwargs
)
- scope2.close()
return response
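
A hypothetical condensation of the dual-span scheme above, in raw OpenTelemetry terms (assuming `opentelemetry-api`): one span is parented on the remote origin's context so it appears in the sender's trace, a second is started for the local trace, and each carries a link to the other so you can jump between the two traces.

    from opentelemetry import trace
    from opentelemetry.trace import (
        INVALID_SPAN_CONTEXT,
        Link,
        NonRecordingSpan,
        use_span,
    )

    tracer = trace.get_tracer(__name__)

    def serve_federation_request(origin_context, handle):
        if origin_context is not None:
            local_servlet_span = trace.get_current_span()
            # Parented on the remote context; cross-linked to the local trace.
            remote_parent_span = tracer.start_span(
                "incoming-federation-request",
                context=origin_context,
                links=[Link(local_servlet_span.get_span_context())],
            )
        else:
            # No remote trace context was supplied: record nothing remotely.
            remote_parent_span = NonRecordingSpan(INVALID_SPAN_CONTEXT)

        with use_span(remote_parent_span, end_on_exit=True):
            # The local span carries a link back to the remote-parented span.
            with tracer.start_as_current_span(
                "process-federation-request",
                links=[Link(remote_parent_span.get_span_context())],
            ):
                return handle()

    serve_federation_request(None, lambda: "ok")
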
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b9b12fbea5..a6b590269e 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -21,6 +21,7 @@ from typing import List, Optional
import attr
+from synapse.api.constants import EventContentFields
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -54,11 +55,13 @@ class Edu:
"destination": self.destination,
}
- def get_context(self) -> str:
- return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
+ def get_tracing_context_json(self) -> str:
+ return getattr(self, "content", {}).get(
+ EventContentFields.TRACING_CONTEXT, "{}"
+ )
- def strip_context(self) -> None:
- getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
+ def strip_tracing_context(self) -> None:
+ getattr(self, "content", {})[EventContentFields.TRACING_CONTEXT] = "{}"
def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
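
For reference, the renamed helpers simply read and blank a namespaced key inside the EDU content. A runnable sketch (the `traceparent` value is an illustrative W3C trace-context string, not necessarily what Synapse's propagator emits):

    import json

    TRACING_CONTEXT = "org.matrix.tracing_context"  # EventContentFields.TRACING_CONTEXT

    def get_tracing_context_json(edu_content: dict) -> str:
        # The serialized trace context rides inside the EDU content itself,
        # as a JSON string under a namespaced key ("{}" when absent).
        return edu_content.get(TRACING_CONTEXT, "{}")

    def strip_tracing_context(edu_content: dict) -> None:
        # Blank (rather than delete) the context before sending, as above.
        edu_content[TRACING_CONTEXT] = "{}"

    edu_content = {
        "messages": {},
        TRACING_CONTEXT: json.dumps(
            {"traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"}
        ),
    }
    assert json.loads(get_tracing_context_json(edu_content))
    strip_tracing_context(edu_content)
    assert get_tracing_context_json(edu_content) == "{}"
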
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3d83236b0c..bfa5535044 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -565,7 +565,7 @@ class AuthHandler:
except LoginError as e:
# this step failed. Merge the error dict into the response
# so that the client can have another go.
- errordict = e.error_dict()
+ errordict = e.error_dict(self.hs.config)
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 1a8379854c..659ee0ef5e 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -36,7 +36,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.tracing import log_kv, set_attribute, trace
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -86,7 +86,7 @@ class DeviceWorkerHandler:
info on each device
"""
- set_tag("user_id", user_id)
+ set_attribute("user_id", user_id)
device_map = await self.store.get_devices_by_user(user_id)
ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)
@@ -118,8 +118,8 @@ class DeviceWorkerHandler:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
- set_tag("device", str(device))
- set_tag("ips", str(ips))
+ set_attribute("device", str(device))
+ set_attribute("ips", str(ips))
return device
@@ -169,8 +169,8 @@ class DeviceWorkerHandler:
joined a room, that `user_id` may be interested in.
"""
- set_tag("user_id", user_id)
- set_tag("from_token", str(from_token))
+ set_attribute("user_id", user_id)
+ set_attribute("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -461,8 +461,8 @@ class DeviceHandler(DeviceWorkerHandler):
except errors.StoreError as e:
if e.code == 404:
# no match
- set_tag("error", True)
- set_tag("reason", "User doesn't have that device id.")
+ set_attribute("error", True)
+ set_attribute("reason", "User doesn't have that device id.")
else:
raise
@@ -688,7 +688,7 @@ class DeviceHandler(DeviceWorkerHandler):
else:
return
- for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+ for user_id, device_id, room_id, stream_id, tracing_context in rows:
hosts = set()
# Ignore any users that aren't ours
@@ -707,7 +707,7 @@ class DeviceHandler(DeviceWorkerHandler):
room_id=room_id,
stream_id=stream_id,
hosts=hosts,
- context=opentracing_context,
+ context=tracing_context,
)
# Notify replication that we've updated the device list stream.
@@ -794,8 +794,8 @@ class DeviceListUpdater:
for parsing the EDU and adding to pending updates list.
"""
- set_tag("origin", origin)
- set_tag("edu_content", str(edu_content))
+ set_attribute("origin", origin)
+ set_attribute("edu_content", str(edu_content))
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
@@ -815,7 +815,7 @@ class DeviceListUpdater:
origin,
)
- set_tag("error", True)
+ set_attribute("error", True)
log_kv(
{
"message": "Got a device list update edu from a user and "
@@ -830,7 +830,7 @@ class DeviceListUpdater:
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
- set_tag("error", True)
+ set_attribute("error", True)
log_kv(
{
"message": "Got an update from a user for which "
@@ -1027,12 +1027,12 @@ class DeviceListUpdater:
# eventually become consistent.
return None
except FederationDeniedError as e:
- set_tag("error", True)
+ set_attribute("error", True)
log_kv({"reason": "FederationDeniedError"})
logger.info(e)
return None
except Exception as e:
- set_tag("error", True)
+ set_attribute("error", True)
log_kv(
{"message": "Exception raised by federation request", "exception": e}
)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 444c08bc2e..9c9da0cb63 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,15 +15,15 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
-from synapse.api.constants import EduTypes, ToDeviceEventTypes
+from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
-from synapse.logging.opentracing import (
+from synapse.logging.tracing import (
SynapseTags,
get_active_span_text_map,
log_kv,
- set_tag,
+ set_attribute,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
@@ -217,10 +217,10 @@ class DeviceMessageHandler:
sender_user_id = requester.user.to_string()
message_id = random_string(16)
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+ set_attribute(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
log_kv({"number_of_to_device_messages": len(messages)})
- set_tag("sender", sender_user_id)
+ set_attribute("sender", sender_user_id)
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
@@ -273,7 +273,7 @@ class DeviceMessageHandler:
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
- "org.matrix.opentracing_context": json_encoder.encode(context),
+ EventContentFields.TRACING_CONTEXT: json_encoder.encode(context),
}
# Add messages to the database.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c938339ddd..a3692f00d9 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -28,7 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.logging.tracing import log_kv, set_attribute, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
@@ -138,8 +138,8 @@ class E2eKeysHandler:
else:
remote_queries[user_id] = device_ids
- set_tag("local_key_query", str(local_query))
- set_tag("remote_key_query", str(remote_queries))
+ set_attribute("local_key_query", str(local_query))
+ set_attribute("remote_key_query", str(remote_queries))
# First get local devices.
# A map of destination -> failure response.
@@ -342,8 +342,8 @@ class E2eKeysHandler:
except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
- set_tag("error", True)
- set_tag("reason", str(failure))
+ set_attribute("error", True)
+ set_attribute("reason", str(failure))
return
@@ -405,7 +405,7 @@ class E2eKeysHandler:
Returns:
A map from user_id -> device_id -> device details
"""
- set_tag("local_query", str(query))
+ set_attribute("local_query", str(query))
local_query: List[Tuple[str, Optional[str]]] = []
result_dict: Dict[str, Dict[str, dict]] = {}
@@ -420,7 +420,7 @@ class E2eKeysHandler:
"user_id": user_id,
}
)
- set_tag("error", True)
+ set_attribute("error", True)
raise SynapseError(400, "Not a user here")
if not device_ids:
@@ -477,8 +477,8 @@ class E2eKeysHandler:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
- set_tag("local_key_query", str(local_query))
- set_tag("remote_key_query", str(remote_queries))
+ set_attribute("local_key_query", str(local_query))
+ set_attribute("remote_key_query", str(remote_queries))
results = await self.store.claim_e2e_one_time_keys(local_query)
@@ -494,7 +494,7 @@ class E2eKeysHandler:
@trace
async def claim_client_keys(destination: str) -> None:
- set_tag("destination", destination)
+ set_attribute("destination", destination)
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
@@ -507,8 +507,8 @@ class E2eKeysHandler:
except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
- set_tag("error", True)
- set_tag("reason", str(failure))
+ set_attribute("error", True)
+ set_attribute("reason", str(failure))
await make_deferred_yieldable(
defer.gatherResults(
@@ -611,7 +611,7 @@ class E2eKeysHandler:
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
- set_tag("one_time_key_counts", str(result))
+ set_attribute("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 28dc08c22a..8786534e54 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -25,7 +25,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.logging.opentracing import log_kv, trace
+from synapse.logging.tracing import log_kv, trace
from synapse.storage.databases.main.e2e_room_keys import RoomKey
from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 59580ef93e..30f1585a85 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -548,9 +548,9 @@ class FederationHandler:
)
if ret.partial_state:
- # TODO(faster_joins): roll this back if we don't manage to start the
- # background resync (eg process_remote_join fails)
- # https://github.com/matrix-org/synapse/issues/12998
+ # Mark the room as having partial state.
+ # The background process is responsible for unmarking this flag,
+ # even if the join fails.
await self.store.store_partial_state_room(room_id, ret.servers_in_room)
try:
@@ -576,17 +576,21 @@ class FederationHandler:
room_id,
)
raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
-
- if ret.partial_state:
- # Kick off the process of asynchronously fetching the state for this
- # room.
- run_as_background_process(
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
- initial_destination=origin,
- other_destinations=ret.servers_in_room,
- room_id=room_id,
- )
+ finally:
+ # Always kick off the background process that asynchronously fetches
+ # state for the room.
+ # If the join failed, the background process is responsible for
+ # cleaning up — including unmarking the room as a partial state room.
+ if ret.partial_state:
+ # Kick off the process of asynchronously fetching the state for this
+ # room.
+ run_as_background_process(
+ desc="sync_partial_state_room",
+ func=self._sync_partial_state_room,
+ initial_destination=origin,
+ other_destinations=ret.servers_in_room,
+ room_id=room_id,
+ )
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
@@ -1541,15 +1545,16 @@ class FederationHandler:
# Make an infinite iterator of destinations to try. Once we find a working
# destination, we'll stick with it until it flakes.
+ destinations: Collection[str]
if initial_destination is not None:
# Move `initial_destination` to the front of the list.
destinations = list(other_destinations)
if initial_destination in destinations:
destinations.remove(initial_destination)
destinations = [initial_destination] + destinations
- destination_iter = itertools.cycle(destinations)
else:
- destination_iter = itertools.cycle(other_destinations)
+ destinations = other_destinations
+ destination_iter = itertools.cycle(destinations)
# `destination` is the current remote homeserver we're pulling from.
destination = next(destination_iter)
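
The simplification above just hoists the `itertools.cycle` call out of the branch while keeping the "preferred server first" ordering. A runnable sketch (`destination_cycle` is a hypothetical name):

    import itertools
    from typing import Collection, Iterator, List, Optional

    def destination_cycle(
        initial_destination: Optional[str], other_destinations: Collection[str]
    ) -> Iterator[str]:
        # Build one list, moving the preferred server to the front when given,
        # then cycle over it forever.
        destinations: List[str] = list(other_destinations)
        if initial_destination is not None:
            if initial_destination in destinations:
                destinations.remove(initial_destination)
            destinations = [initial_destination] + destinations
        return itertools.cycle(destinations)

    it = destination_cycle("hs1", ["hs2", "hs1", "hs3"])
    assert [next(it) for _ in range(4)] == ["hs1", "hs2", "hs3", "hs1"]
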
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 4429319265..8968b705d4 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -279,7 +279,8 @@ class FederationEventHandler:
)
try:
- await self._process_received_pdu(origin, pdu, state_ids=None)
+ context = await self._state_handler.compute_event_context(pdu)
+ await self._process_received_pdu(origin, pdu, context)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time.
@@ -287,7 +288,8 @@ class FederationEventHandler:
"Room %s was un-partial stated while processing the PDU, trying again.",
room_id,
)
- await self._process_received_pdu(origin, pdu, state_ids=None)
+ context = await self._state_handler.compute_event_context(pdu)
+ await self._process_received_pdu(origin, pdu, context)
async def on_send_membership_event(
self, origin: str, event: EventBase
@@ -317,6 +319,7 @@ class FederationEventHandler:
The event and context of the event after inserting it into the room graph.
Raises:
+ RuntimeError if any prev_events are missing
SynapseError if the event is not accepted into the room
PartialStateConflictError if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should
@@ -377,7 +380,7 @@ class FederationEventHandler:
# need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
- await self._check_for_soft_fail(event, None, origin=origin)
+ await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context)
return event, context
@@ -535,27 +538,30 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
- state_ids = await self._resolve_state_at_missing_prevs(destination, event)
-
- # build a new state group for it if need be
- context = await self._state_handler.compute_event_context(
- event,
- state_ids_before_event=state_ids,
+ context = await self._compute_event_context_with_maybe_missing_prevs(
+ destination, event
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
- # partial state - ie, an event has an earlier stream_ordering than one
- # or more of its prev_events, so we de-partial-state it before its
- # prev_events.
+ # partial state. We were careful to only pick events from the db without
+ # partial-state prev events, so that implies that a prev event has
+ # been persisted (with partial state) since we did the query.
#
- # TODO(faster_joins): we probably need to be more intelligent, and
- # exclude partial-state prev_events from consideration
- # https://github.com/matrix-org/synapse/issues/13001
+ # So, let's just ignore `event` for now; when we re-run the db query
+ # we should instead get its partial-state prev event, which we will
+ # de-partial-state, and then come back to event.
logger.warning(
- "%s still has partial state: can't de-partial-state it yet",
+ "%s still has prev_events with partial state: can't de-partial-state it yet",
event.event_id,
)
return
+
+ # since the state at this event has changed, we should now re-evaluate
+ # whether it should have been rejected. We must already have all of the
+ # auth events (from last time we went round this path), so there is no
+ # need to pass the origin.
+ await self._check_event_auth(None, event, context)
+
await self._store.update_state_for_partial_state_event(event, context)
self._state_storage_controller.notify_event_un_partial_stated(
event.event_id
@@ -811,29 +817,55 @@ class FederationEventHandler:
return
try:
- state_ids = await self._resolve_state_at_missing_prevs(origin, event)
- # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
- # not return partial state
- # https://github.com/matrix-org/synapse/issues/13002
+ try:
+ context = await self._compute_event_context_with_maybe_missing_prevs(
+ origin, event
+ )
+ await self._process_received_pdu(
+ origin,
+ event,
+ context,
+ backfilled=backfilled,
+ )
+ except PartialStateConflictError:
+ # The room was un-partial stated while we were processing the event.
+ # Try once more, with full state this time.
+ context = await self._compute_event_context_with_maybe_missing_prevs(
+ origin, event
+ )
- await self._process_received_pdu(
- origin, event, state_ids=state_ids, backfilled=backfilled
- )
+ # We ought to have full state now, barring some unlikely race where we left and
+            # rejoined the room in the background.
+ if context.partial_state:
+ raise AssertionError(
+ f"Event {event.event_id} still has a partial resolved state "
+ f"after room {event.room_id} was un-partial stated"
+ )
+
+ await self._process_received_pdu(
+ origin,
+ event,
+ context,
+ backfilled=backfilled,
+ )
except FederationError as e:
if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id)
else:
raise
- async def _resolve_state_at_missing_prevs(
+ async def _compute_event_context_with_maybe_missing_prevs(
self, dest: str, event: EventBase
- ) -> Optional[StateMap[str]]:
- """Calculate the state at an event with missing prev_events.
+ ) -> EventContext:
+ """Build an EventContext structure for a non-outlier event whose prev_events may
+ be missing.
- This is used when we have pulled a batch of events from a remote server, and
- still don't have all the prev_events.
+ This is used when we have pulled a batch of events from a remote server, and may
+ not have all the prev_events.
- If we already have all the prev_events for `event`, this method does nothing.
+ To build an EventContext, we need to calculate the state before the event. If we
+ already have all the prev_events for `event`, we can simply use the state after
+ the prev_events to calculate the state before `event`.
Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`,
@@ -854,8 +886,7 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
- if we already had all the prev events, `None`. Otherwise, returns
- the event ids of the state at `event`.
+ The event context.
Raises:
FederationError if we fail to get the state from the remote server after any
@@ -869,7 +900,7 @@ class FederationEventHandler:
missing_prevs = prevs - seen
if not missing_prevs:
- return None
+ return await self._state_handler.compute_event_context(event)
logger.info(
"Event %s is missing prev_events %s: calculating state for a "
@@ -881,9 +912,15 @@ class FederationEventHandler:
# resolve them to find the correct state at the current event.
try:
+            # Determine whether we may be about to retrieve partial state.
+ # Events may be un-partial stated right after we compute the partial state
+ # flag, but that's okay, as long as the flag errs on the conservative side.
+ partial_state_flags = await self._store.get_partial_state_events(seen)
+ partial_state = any(partial_state_flags.values())
+
# Get the state of the events we know about
ours = await self._state_storage_controller.get_state_groups_ids(
- room_id, seen
+ room_id, seen, await_full_state=False
)
# state_maps is a list of mappings from (type, state_key) to event_id
@@ -929,7 +966,9 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
- return state_map
+ return await self._state_handler.compute_event_context(
+ event, state_ids_before_event=state_map, partial_state=partial_state
+ )
async def _get_state_ids_after_missing_prev_event(
self,
@@ -1098,7 +1137,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
- state_ids: Optional[StateMap[str]],
+ context: EventContext,
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@@ -1120,24 +1159,18 @@ class FederationEventHandler:
event: event to be persisted
- state_ids: Normally None, but if we are handling a gap in the graph
- (ie, we are missing one or more prev_events), the resolved state at the
- event. Must not be partial state.
+ context: The `EventContext` to persist the event with.
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
PartialStateConflictError: if the room was un-partial stated in between
- computing the state at the event and persisting it. The caller should retry
- exactly once in this case. Will never be raised if `state_ids` is provided.
+ computing the state at the event and persisting it. The caller should
+ recompute `context` and retry exactly once when this happens.
"""
logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier
- context = await self._state_handler.compute_event_context(
- event,
- state_ids_before_event=state_ids,
- )
try:
await self._check_event_auth(origin, event, context)
except AuthError as e:
@@ -1149,7 +1182,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- await self._check_for_soft_fail(event, state_ids, origin=origin)
+ await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@@ -1561,13 +1594,15 @@ class FederationEventHandler:
)
async def _check_event_auth(
- self, origin: str, event: EventBase, context: EventContext
+ self, origin: Optional[str], event: EventBase, context: EventContext
) -> None:
"""
Checks whether an event should be rejected (for failing auth checks).
Args:
- origin: The host the event originates from.
+ origin: The host the event originates from. This is used to fetch
+ any missing auth events. It can be set to None, but only if we are
+ sure that we already have all the auth events.
event: The event itself.
context:
The event context.
@@ -1710,7 +1745,7 @@ class FederationEventHandler:
async def _check_for_soft_fail(
self,
event: EventBase,
- state_ids: Optional[StateMap[str]],
+ context: EventContext,
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@@ -1721,7 +1756,7 @@ class FederationEventHandler:
Args:
event
- state_ids: The state at the event if we don't have all the event's prev events
+ context: The `EventContext` which we are about to persist the event with.
origin: The host the event originates from.
"""
if await self._store.is_partial_state_room(event.room_id):
@@ -1747,11 +1782,15 @@ class FederationEventHandler:
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state".
- if state_ids is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
+ seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
+ has_missing_prevs = bool(prev_event_ids - seen_event_ids)
+ if has_missing_prevs:
+ # We don't have all the prev_events of this event, which means we have a
+ # gap in the graph, and the new event is going to become a new backwards
+ # extremity.
+ #
+ # In this case we want to be a little careful as we might have been
+ # down for a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to
# maliciously manufacture.
#
@@ -1764,6 +1803,7 @@ class FederationEventHandler:
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
+ state_ids = await context.get_prev_state_ids()
state_sets.append(state_ids)
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
@@ -1813,7 +1853,7 @@ class FederationEventHandler:
event.internal_metadata.soft_failed = True
async def _load_or_fetch_auth_events_for_event(
- self, destination: str, event: EventBase
+ self, destination: Optional[str], event: EventBase
) -> Collection[EventBase]:
"""Fetch this event's auth_events, from database or remote
@@ -1829,12 +1869,19 @@ class FederationEventHandler:
Args:
destination: where to send the /event_auth request. Typically the server
that sent us `event` in the first place.
+
+ If this is None, no attempt is made to load any missing auth events:
+ rather, an AssertionError is raised if there are any missing events.
+
event: the event whose auth_events we want
Returns:
all of the events listed in `event.auth_events_ids`, after deduplication
Raises:
+ AssertionError if some auth events were missing and no `destination` was
+ supplied.
+
AuthError if we were unable to fetch the auth_events for any reason.
"""
event_auth_event_ids = set(event.auth_event_ids())
@@ -1846,6 +1893,13 @@ class FederationEventHandler:
)
if not missing_auth_event_ids:
return event_auth_events.values()
+ if destination is None:
+ # this shouldn't happen: destination must be set unless we know we have already
+ # persisted the auth events.
+ raise AssertionError(
+ "_load_or_fetch_auth_events_for_event() called with no destination for "
+ "an event with missing auth_events"
+ )
logger.info(
"Event %s refers to unknown auth events %s: fetching auth chain",
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9571d461c8..e5afe84df9 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -26,6 +26,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
@@ -415,6 +416,48 @@ class IdentityHandler:
return session_id
+ async def request_email_token(
+ self,
+ id_server: str,
+ email: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str] = None,
+ ) -> JsonDict:
+ """
+ Request that an external server send an email on our behalf, for the purposes
+ of threepid validation.
+
+ Args:
+ id_server: The identity server to proxy to
+ email: The email to send the message to
+ client_secret: The unique client_secret sent by the user
+ send_attempt: Which attempt this is
+ next_link: A link to redirect the user to once they submit the token
+
+ Returns:
+ The json response body from the server
+ """
+ params = {
+ "email": email,
+ "client_secret": client_secret,
+ "send_attempt": send_attempt,
+ }
+ if next_link:
+ params["next_link"] = next_link
+
+ try:
+ data = await self.http_client.post_json_get_json(
+ id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
+ params,
+ )
+ return data
+ except HttpResponseException as e:
+ logger.info("Proxied requestToken failed: %r", e)
+ raise e.to_synapse_error()
+ except RequestTimedOutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+
async def requestMsisdnToken(
self,
id_server: str,
@@ -488,7 +531,18 @@ class IdentityHandler:
validation_session = None
# Try to validate as email
- if self.hs.config.email.can_verify_email:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ # Remote emails will only be used if a valid identity server is provided.
+ assert (
+ self.hs.config.registration.account_threepid_delegate_email is not None
+ )
+
+ # Ask our delegated email identity server
+ validation_session = await self.threepid_from_creds(
+ self.hs.config.registration.account_threepid_delegate_email,
+ threepid_creds,
+ )
+ elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index bd7baef051..e85b540451 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -41,6 +41,7 @@ from synapse.api.errors import (
NotFoundError,
ShadowBanError,
SynapseError,
+ UnstableSpecAuthError,
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -149,7 +150,11 @@ class MessageHandler:
"Attempted to retrieve data from a room for a user that has never been in it. "
"This should not have happened."
)
- raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
+ raise UnstableSpecAuthError(
+ 403,
+ "User not in room",
+ errcode=Codes.NOT_JOINED,
+ )
return data
@@ -334,7 +339,11 @@ class MessageHandler:
break
else:
# Loop fell through, AS has no interested users in room
- raise AuthError(403, "Appservice not in room")
+ raise UnstableSpecAuthError(
+ 403,
+ "Appservice not in room",
+ errcode=Codes.NOT_JOINED,
+ )
return {
user_id: {
@@ -1135,6 +1144,10 @@ class EventCreationHandler:
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map_for_event,
+ # TODO(faster_joins): check how MSC2716 works and whether we can have
+ # partial state here
+ # https://github.com/matrix-org/synapse/issues/13003
+ partial_state=False,
)
else:
context = await self.state.compute_event_context(event)
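The switch from `AuthError`/`SynapseError` to `UnstableSpecAuthError` with `Codes.NOT_JOINED` presumably implements MSC3848-style fine-grained membership errors: the unstable errcode is exposed only when the experimental feature is enabled, with the stable `M_FORBIDDEN` as the fallback. A rough sketch of that assumed behaviour (attribute and config names here are illustrative, not Synapse's actual ones):

    from synapse.api.errors import AuthError, Codes

    class UnstableSpecAuthErrorSketch(AuthError):
        def __init__(self, code: int, msg: str, errcode: str):
            super().__init__(code, msg, errcode=Codes.FORBIDDEN)
            self._unstable_errcode = errcode  # e.g. Codes.NOT_JOINED

        def error_dict(self, config):
            # Only deployments opted in to the unstable spec see the fine-grained
            # code; everyone else gets the stable M_FORBIDDEN.
            if config is not None and getattr(config.experimental, "msc3848_enabled", False):
                return {"errcode": self._unstable_errcode, "error": self.msg}
            return {"errcode": self.errcode, "error": self.msg}
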
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 895ea63ed3..741504ba9f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,6 @@ from typing import (
Callable,
Collection,
Dict,
- FrozenSet,
Generator,
Iterable,
List,
@@ -42,7 +41,6 @@ from typing import (
Set,
Tuple,
Type,
- Union,
)
from prometheus_client import Counter
@@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# doesn't return. C.f. #5503.
return [], max_token
- # Figure out which other users this user should receive updates for
- users_interested_in = await self._get_interested_in(user, explicit_room_id)
+ # Figure out which other users this user should explicitly receive
+ # updates for
+ additional_users_interested_in = (
+ await self.get_presence_router().get_interested_users(user.to_string())
+ )
# We have a set of users that we're interested in the presence of. We want to
# cross-reference that with the users that have actually changed their presence.
# Check whether this user should see all user updates
- if users_interested_in == PresenceRouter.ALL_USERS:
+ if additional_users_interested_in == PresenceRouter.ALL_USERS:
# Provide presence state for all users
presence_updates = await self._filter_all_presence_updates_for_user(
user_id, include_offline, from_key
@@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
return presence_updates, max_token
- # Make mypy happy. users_interested_in should now be a set
+ # Make mypy happy. additional_users_interested_in should now be a set
- assert not isinstance(users_interested_in, str)
+ assert not isinstance(additional_users_interested_in, str)
+
+ # We always care about our own presence.
+ additional_users_interested_in.add(user_id)
+
+ if explicit_room_id:
+ user_ids = await self.store.get_users_in_room(explicit_room_id)
+ additional_users_interested_in.update(user_ids)
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
- interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
+ interested_and_updated_users: Collection[str]
if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
- # Use a slightly-optimised method for processing smaller sets of updates.
- if updated_users is not None and len(updated_users) < 500:
- # For small deltas, it's quicker to get all changes and then
- # cross-reference with the users we're interested in
+ if updated_users is not None:
+ # If we have the full list of changes for presence we can
+ # simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
- for other_user_id in updated_users:
- if other_user_id in users_interested_in:
- # mypy thinks this variable could be a FrozenSet as it's possibly set
- # to one in the `get_entities_changed` call below, and `add()` is not
- # method on a FrozenSet. That doesn't affect us here though, as
- # `interested_and_updated_users` is clearly a set() above.
- interested_and_updated_users.add(other_user_id) # type: ignore
+
+ sharing_users = await self.store.do_users_share_a_room(
+ user_id, updated_users
+ )
+
+ interested_and_updated_users = (
+ sharing_users.union(additional_users_interested_in)
+ ).intersection(updated_users)
+
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
get_updates_counter.labels("full").inc()
+ users_interested_in = (
+ await self.store.get_users_who_share_room_with_user(user_id)
+ )
+ users_interested_in.update(additional_users_interested_in)
+
interested_and_updated_users = (
stream_change_cache.get_entities_changed(
users_interested_in, from_key
@@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
else:
# No from_key has been specified. Return the presence for all users
# this user is interested in
- interested_and_updated_users = users_interested_in
+ interested_and_updated_users = (
+ await self.store.get_users_who_share_room_with_user(user_id)
+ )
+ interested_and_updated_users.update(additional_users_interested_in)
# Retrieve the current presence state for each user
users_to_state = await self.get_presence_handler().current_state_for_users(
@@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
def get_current_key(self) -> int:
return self.store.get_current_presence_token()
- @cached(num_args=2, cache_context=True)
- async def _get_interested_in(
- self,
- user: UserID,
- explicit_room_id: Optional[str] = None,
- cache_context: Optional[_CacheContext] = None,
- ) -> Union[Set[str], str]:
- """Returns the set of users that the given user should see presence
- updates for.
-
- Args:
- user: The user to retrieve presence updates for.
- explicit_room_id: The users that are in the room will be returned.
-
- Returns:
- A set of user IDs to return presence updates for, or "ALL" to return all
- known updates.
- """
- user_id = user.to_string()
- users_interested_in = set()
- users_interested_in.add(user_id) # So that we receive our own presence
-
- # cache_context isn't likely to ever be None due to the @cached decorator,
- # but we can't have a non-optional argument after the optional argument
- # explicit_room_id either. Assert cache_context is not None so we can use it
- # without mypy complaining.
- assert cache_context
-
- # Check with the presence router whether we should poll additional users for
- # their presence information
- additional_users = await self.get_presence_router().get_interested_users(
- user.to_string()
- )
- if additional_users == PresenceRouter.ALL_USERS:
- # If the module requested that this user see the presence updates of *all*
- # users, then simply return that instead of calculating what rooms this
- # user shares
- return PresenceRouter.ALL_USERS
-
- # Add the additional users from the router
- users_interested_in.update(additional_users)
-
- # Find the users who share a room with this user
- users_who_share_room = await self.store.get_users_who_share_room_with_user(
- user_id, on_invalidate=cache_context.invalidate
- )
- users_interested_in.update(users_who_share_room)
-
- if explicit_room_id:
- user_ids = await self.store.get_users_in_room(
- explicit_room_id, on_invalidate=cache_context.invalidate
- )
- users_interested_in.update(user_ids)
-
- return users_interested_in
-
def handle_timeouts(
user_states: List[UserPresenceState],
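The net effect of the presence rewrite: instead of caching a per-user "interested in" set, the interesting users are derived with plain set algebra at query time. When the stream cache can enumerate everyone who changed, the result is (users sharing a room with us, union the explicitly-interested extras) intersected with the changed users. A worked example with literal sets (the real code awaits store lookups for `sharing` and `updated`):

    me = "@me:example.org"
    additional = {me, "@router-pick:example.org"}        # self + presence-router picks
    updated = {me, "@a:example.org", "@b:example.org"}   # changed since from_key
    sharing = {"@a:example.org"}                         # share a room with us

    interested_and_updated = (sharing | additional) & updated
    assert interested_and_updated == {me, "@a:example.org"}
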
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index a02fc45e71..72d25df8c8 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -74,7 +74,6 @@ class RelationsHandler:
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
@@ -90,7 +89,6 @@ class RelationsHandler:
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
@@ -123,7 +121,6 @@ class RelationsHandler:
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
- aggregation_key=aggregation_key,
limit=limit,
direction=direction,
from_token=from_token,
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 29868eb743..bb0bdb8e6f 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -182,7 +182,7 @@ class RoomListHandler:
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
"join_rule": room["join_rules"],
- "org.matrix.msc3827.room_type": room["room_type"],
+ "room_type": room["room_type"],
}
# Filter out Nones – rather omit the field altogether
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 30b4cb23df..520c52e013 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1679,7 +1679,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise SynapseError(
+ 404,
+ "Can't join remote room because no servers "
+ "that are in the room have been provided.",
+ )
check_complexity = self.hs.config.server.limit_remote_rooms.enabled
if (
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 13098f56ed..ebd445adca 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,11 +28,11 @@ from synapse.api.constants import (
RoomTypes,
)
from synapse.api.errors import (
- AuthError,
Codes,
NotFoundError,
StoreError,
SynapseError,
+ UnstableSpecAuthError,
UnsupportedRoomVersionError,
)
from synapse.api.ratelimiting import Ratelimiter
@@ -175,10 +175,11 @@ class RoomSummaryHandler:
# First of all, check that the room is accessible.
if not await self._is_local_room_accessible(requested_room_id, requester):
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
"User %s not in room %s, and room previews are disabled"
% (requester, requested_room_id),
+ errcode=Codes.NOT_JOINED,
)
# If this is continuing a previous session, pull the persisted data.
@@ -452,7 +453,6 @@ class RoomSummaryHandler:
"type": e.type,
"state_key": e.state_key,
"content": e.content,
- "room_id": e.room_id,
"sender": e.sender,
"origin_server_ts": e.origin_server_ts,
}
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index a305a66860..e2844799e8 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -23,10 +23,12 @@ from pkg_resources import parse_version
import twisted
from twisted.internet.defer import Deferred
-from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
+from twisted.internet.interfaces import IOpenSSLContextFactory
+from twisted.internet.ssl import optionsForClientTLS
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
from synapse.logging.context import make_deferred_yieldable
+from synapse.types import ISynapseReactor
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -48,7 +50,7 @@ class _NoTLSESMTPSender(ESMTPSender):
async def _sendmail(
- reactor: IReactorTCP,
+ reactor: ISynapseReactor,
smtphost: str,
smtpport: int,
from_addr: str,
@@ -59,6 +61,7 @@ async def _sendmail(
require_auth: bool = False,
require_tls: bool = False,
enable_tls: bool = True,
+ force_tls: bool = False,
) -> None:
"""A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
@@ -73,8 +76,9 @@ async def _sendmail(
password: password to give when authenticating
require_auth: if auth is not offered, fail the request
require_tls: if TLS is not offered, fail the request
- enable_tls: True to enable TLS. If this is False and require_tls is True,
+ enable_tls: True to enable STARTTLS. If this is False and require_tls is True,
the request will fail.
+ force_tls: True to enable implicit TLS (TLS from the first byte, rather
+ than STARTTLS).
"""
msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred()
@@ -105,13 +109,23 @@ async def _sendmail(
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
- reactor.connectTCP(
- smtphost,
- smtpport,
- factory,
- timeout=30,
- bindAddress=None,
- )
+ if force_tls:
+ reactor.connectSSL(
+ smtphost,
+ smtpport,
+ factory,
+ optionsForClientTLS(smtphost),
+ timeout=30,
+ bindAddress=None,
+ )
+ else:
+ reactor.connectTCP(
+ smtphost,
+ smtpport,
+ factory,
+ timeout=30,
+ bindAddress=None,
+ )
await make_deferred_yieldable(d)
@@ -132,6 +146,7 @@ class SendEmailHandler:
self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None
self._require_transport_security = hs.config.email.require_transport_security
self._enable_tls = hs.config.email.enable_smtp_tls
+ self._force_tls = hs.config.email.force_tls
self._sendmail = _sendmail
@@ -189,4 +204,5 @@ class SendEmailHandler:
require_auth=self._smtp_user is not None,
require_tls=self._require_transport_security,
enable_tls=self._enable_tls,
+ force_tls=self._force_tls,
)
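To spell out the distinction between the TLS knobs: `enable_tls` governs STARTTLS (a plain TCP connection upgraded in-protocol), while the new `force_tls` uses implicit TLS, i.e. the connection speaks TLS from the first byte, as is conventional on port 465. A hedged usage sketch of the wrapper above (hostnames and message are placeholders; parameter names follow the signature and body shown in this diff):

    async def send_via_smtps(reactor) -> None:
        await _sendmail(
            reactor,
            smtphost="smtp.example.com",
            smtpport=465,  # implicit-TLS submission port
            from_addr="noreply@example.com",
            to_addr="user@example.com",
            msg_bytes=b"Subject: hi\r\n\r\nhello",
            force_tls=True,  # connectSSL from the first byte; no STARTTLS
        )
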
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d42a414c90..8dc05d648d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -25,7 +25,12 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
-from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
+from synapse.logging.tracing import (
+ SynapseTags,
+ log_kv,
+ set_attribute,
+ start_active_span,
+)
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary
@@ -391,12 +396,12 @@ class SyncHandler:
indoctrination.
"""
with start_active_span("sync.current_sync_for_user"):
- log_kv({"since_token": since_token})
+ log_kv({"since_token": str(since_token)})
sync_result = await self.generate_sync_result(
sync_config, since_token, full_state
)
- set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
+ set_attribute(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result
async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
@@ -1084,7 +1089,7 @@ class SyncHandler:
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
now_token = self.event_sources.get_current_token()
- log_kv({"now_token": now_token})
+ log_kv({"now_token": str(now_token)})
logger.debug(
"Calculating sync response for %r between %s and %s",
@@ -1337,7 +1342,7 @@ class SyncHandler:
# `/sync`
message_id = message.pop("message_id", None)
if message_id:
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+ set_attribute(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
@@ -1997,13 +2002,13 @@ class SyncHandler:
upto_token = room_builder.upto_token
with start_active_span("sync.generate_room_entry"):
- set_tag("room_id", room_id)
+ set_attribute("room_id", room_id)
log_kv({"events": len(events or ())})
log_kv(
{
- "since_token": since_token,
- "upto_token": upto_token,
+ "since_token": str(since_token),
+ "upto_token": str(upto_token),
}
)
@@ -2018,7 +2023,7 @@ class SyncHandler:
log_kv(
{
"batch_events": len(batch.events),
- "prev_batch": batch.prev_batch,
+ "prev_batch": str(batch.prev_batch),
"batch_limited": batch.limited,
}
)
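The scattering of `str(...)` calls is not cosmetic: OpenTelemetry attribute and event values must be primitives (str, bool, int, float, or homogeneous sequences thereof), so rich objects such as `StreamToken` have to be stringified before being recorded. A minimal OpenTelemetry sketch of the same constraint:

    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    class StreamTokenStandIn:
        def __str__(self) -> str:
            return "s72594_4483_1934"

    with tracer.start_as_current_span("sync.current_sync_for_user") as span:
        token = StreamTokenStandIn()
        span.set_attribute("since_token", str(token))  # OK: a primitive
        # span.set_attribute("since_token", token)     # invalid type: dropped with a warning
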
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index d104ea07fe..27aa0d3126 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -489,8 +489,15 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
handler = self.get_typing_handler()
events = []
- for room_id in handler._room_serials.keys():
- if handler._room_serials[room_id] <= from_key:
+
+ # Work on a copy of the handler's state here, as it may change while we
+ # wait for the AS `is_interested_in_room` call to complete.
+ # A shallow copy is safe as no nested data is present.
+ latest_room_serial = handler._latest_room_serial
+ room_serials = handler._room_serials.copy()
+
+ for room_id, serial in room_serials.items():
+ if serial <= from_key:
continue
if not await service.is_interested_in_room(room_id, self._main_store):
@@ -498,7 +505,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
events.append(self._make_event_for(room_id))
- return events, handler._latest_room_serial
+ return events, latest_room_serial
async def get_new_events(
self,
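The fix above is an instance of a general snapshot-before-await pattern: read the stream position and copy the mutable mapping before the first `await`, then iterate the copy, so updates made while we are blocked on the appservice call can neither be half-seen nor advance the returned token past events we did not return. A generic sketch (names are illustrative):

    async def collect_updates(handler, from_key, is_interested):
        latest = handler._latest_room_serial    # read the position first...
        serials = handler._room_serials.copy()  # ...then snapshot the mapping
        events = []
        for room_id, serial in serials.items():
            if serial <= from_key:
                continue
            if await is_interested(room_id):    # may yield to other coroutines
                events.append(room_id)
        return events, latest                   # token matches the snapshot
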
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index a744d68c64..05cebb5d4d 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -19,6 +19,7 @@ from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -152,7 +153,7 @@ class _BaseThreepidAuthChecker:
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
- # msisdns are currently always verified via the IS
+ # msisdn validation currently always behaves as ThreepidBehaviour.REMOTE
+ # (i.e. it is verified via the IS)
if medium == "msisdn":
if not self.hs.config.registration.account_threepid_delegate_msisdn:
raise SynapseError(
@@ -163,7 +164,18 @@ class _BaseThreepidAuthChecker:
threepid_creds,
)
elif medium == "email":
- if self.hs.config.email.can_verify_email:
+ if (
+ self.hs.config.email.threepid_behaviour_email
+ == ThreepidBehaviour.REMOTE
+ ):
+ assert self.hs.config.registration.account_threepid_delegate_email
+ threepid = await identity_handler.threepid_from_creds(
+ self.hs.config.registration.account_threepid_delegate_email,
+ threepid_creds,
+ )
+ elif (
+ self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ ):
threepid = None
row = await self.store.get_threepid_validation_session(
medium,
@@ -215,7 +227,10 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
_BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self) -> bool:
- return self.hs.config.email.can_verify_email
+ return self.hs.config.email.threepid_behaviour_email in (
+ ThreepidBehaviour.REMOTE,
+ ThreepidBehaviour.LOCAL,
+ )
async def check_auth(self, authdict: dict, clientip: str) -> Any:
return await self._check_threepid("email", authdict)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 084d0a5b84..89bd403312 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -75,7 +75,13 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
from synapse.http.proxyagent import ProxyAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
-from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.logging.tracing import (
+ SpanAttributes,
+ SpanKind,
+ StatusCode,
+ set_status,
+ start_active_span,
+)
from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -402,12 +408,11 @@ class SimpleHttpClient:
with start_active_span(
"outgoing-client-request",
- tags={
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
- tags.HTTP_METHOD: method,
- tags.HTTP_URL: uri,
+ kind=SpanKind.CLIENT,
+ attributes={
+ SpanAttributes.HTTP_METHOD: method,
+ SpanAttributes.HTTP_URL: uri,
},
- finish_on_close=True,
):
try:
body_producer = None
@@ -459,8 +464,7 @@ class SimpleHttpClient:
type(e).__name__,
e.args[0],
)
- set_tag(tags.ERROR, True)
- set_tag("error_reason", e.args[0])
+ set_status(StatusCode.ERROR, e)
raise
async def post_urlencoded_get_json(
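The new `synapse.logging.tracing` helpers track the OpenTelemetry API: spans get a first-class `SpanKind` instead of a `span.kind` tag, and failures are recorded as a span status rather than an `error` tag pair. Roughly, in plain OpenTelemetry terms (a sketch; Synapse's wrappers add whitelisting and logcontext integration on top):

    from opentelemetry import trace
    from opentelemetry.semconv.trace import SpanAttributes
    from opentelemetry.trace import SpanKind, Status, StatusCode

    tracer = trace.get_tracer(__name__)
    with tracer.start_as_current_span(
        "outgoing-client-request",
        kind=SpanKind.CLIENT,
        attributes={
            SpanAttributes.HTTP_METHOD: "GET",
            SpanAttributes.HTTP_URL: "https://example.com/",
        },
    ) as span:
        try:
            raise ConnectionError("boom")  # stand-in for a failed request
        except Exception as e:
            # Replaces set_tag(tags.ERROR, True) / set_tag("error_reason", ...)
            span.set_status(Status(StatusCode.ERROR, str(e)))
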
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3c35b1d2c7..00704a6a7c 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -72,9 +72,14 @@ from synapse.http.client import (
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.types import QueryParams
-from synapse.logging import opentracing
+from synapse.logging import tracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.logging.tracing import (
+ SpanAttributes,
+ SpanKind,
+ set_attribute,
+ start_active_span,
+)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
@@ -517,18 +522,19 @@ class MatrixFederationHttpClient:
scope = start_active_span(
"outgoing-federation-request",
- tags={
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
- tags.PEER_ADDRESS: request.destination,
- tags.HTTP_METHOD: request.method,
- tags.HTTP_URL: request.path,
+ kind=SpanKind.CLIENT,
+ attributes={
+ SpanAttributes.HTTP_HOST: request.destination,
+ SpanAttributes.HTTP_METHOD: request.method,
+ SpanAttributes.HTTP_URL: request.path,
},
- finish_on_close=True,
)
# Inject the span into the headers
headers_dict: Dict[bytes, List[bytes]] = {}
- opentracing.inject_header_dict(headers_dict, request.destination)
+ tracing.inject_active_tracing_context_into_header_dict(
+ headers_dict, request.destination
+ )
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -614,7 +620,7 @@ class MatrixFederationHttpClient:
request.method, response.code
).inc()
- set_tag(tags.HTTP_STATUS_CODE, response.code)
+ set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.code)
response_phrase = response.phrase.decode("ascii", errors="replace")
if 200 <= response.code < 300:
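`inject_active_tracing_context_into_header_dict` is the OpenTelemetry counterpart of the old `inject_header_dict`: rather than a tracer-specific carrier format, it delegates to the configured propagator (W3C `traceparent` by default). In bare OpenTelemetry terms this looks like the sketch below (the Synapse wrapper additionally honours the per-destination whitelist):

    from opentelemetry.propagate import inject

    def tracing_headers_for_federation() -> dict:
        carrier: dict = {}
        inject(carrier)  # e.g. adds a "traceparent" key for the active span
        # Synapse's federation client keeps headers as Dict[bytes, List[bytes]]:
        return {k.encode("ascii"): [v.encode("ascii")] for k, v in carrier.items()}
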
diff --git a/synapse/http/server.py b/synapse/http/server.py
index cf2d6f904b..6420c0837b 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -58,15 +58,16 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
+from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
-from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
+from synapse.logging.tracing import get_active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
- import opentracing
+ import opentelemetry
from synapse.server import HomeServer
@@ -155,15 +156,16 @@ def is_method_cancellable(method: Callable[..., Any]) -> bool:
return getattr(method, "cancellable", False)
-def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+def return_json_error(
+ f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
+) -> None:
"""Sends a JSON error response to clients."""
if f.check(SynapseError):
# mypy doesn't understand that f.check asserts the type.
exc: SynapseError = f.value # type: ignore
error_code = exc.code
- error_dict = exc.error_dict()
-
+ error_dict = exc.error_dict(config)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
elif f.check(CancelledError):
error_code = HTTP_STATUS_REQUEST_CANCELLED
@@ -327,7 +329,7 @@ class HttpServer(Protocol):
subsequent arguments will be any matched groups from the regex.
This should return either tuple of (code, response), or None.
servlet_classname (str): The name of the handler to be used in prometheus
- and opentracing logs.
+ and tracing logs.
"""
@@ -338,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
requests by method, or override `_async_render` to handle all requests.
Args:
- extract_context: Whether to attempt to extract the opentracing
+ extract_context: Whether to attempt to extract the tracing
context from the request the servlet is handling.
"""
@@ -450,7 +452,7 @@ class DirectServeJsonResource(_AsyncResource):
request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response"""
- return_json_error(f, request)
+ return_json_error(f, request, None)
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -508,7 +510,7 @@ class JsonResource(DirectServeJsonResource):
callback: The handler for the request. Usually a Servlet
servlet_classname: The name of the handler to be used in prometheus
- and opentracing logs.
+ and tracing logs.
"""
method_bytes = method.encode("utf-8")
@@ -575,6 +577,14 @@ class JsonResource(DirectServeJsonResource):
return callback_return
+ def _send_error_response(
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
+ ) -> None:
+ """Implements _AsyncResource._send_error_response"""
+ return_json_error(f, request, self.hs.config)
+
class DirectServeHtmlResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
@@ -868,19 +878,19 @@ async def _async_write_json_to_request_in_thread(
expensive.
"""
- def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes:
+ def encode(tracing_span: Optional["opentelemetry.trace.Span"]) -> bytes:
# it might take a while for the threadpool to schedule us, so we write
- # opentracing logs once we actually get scheduled, so that we can see how
+ # tracing logs once we actually get scheduled, so that we can see how
# much that contributed.
- if opentracing_span:
- opentracing_span.log_kv({"event": "scheduled"})
+ if tracing_span:
+ tracing_span.add_event("scheduled", attributes={"event": "scheduled"})
res = json_encoder(json_object)
- if opentracing_span:
- opentracing_span.log_kv({"event": "encoded"})
+ if tracing_span:
+ tracing_span.add_event("scheduled", attributes={"event": "encoded"})
return res
with start_active_span("encode_json_response"):
- span = active_span()
+ span = get_active_span()
json_str = await defer_to_thread(request.reactor, encode, span)
_write_bytes_to_request(request, json_str)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index eeec74b78a..d82c046dd7 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -37,7 +37,7 @@ from synapse.logging.context import (
from synapse.types import Requester
if TYPE_CHECKING:
- import opentracing
+ import opentelemetry
logger = logging.getLogger(__name__)
@@ -85,9 +85,9 @@ class SynapseRequest(Request):
# server name, for client requests this is the Requester object.
self._requester: Optional[Union[Requester, str]] = None
- # An opentracing span for this request. Will be closed when the request is
+ # A tracing span for this request. Will be closed when the request is
# completely processed.
- self._opentracing_span: "Optional[opentracing.Span]" = None
+ self._tracing_span: Optional["opentelemetry.trace.Span"] = None
# we can't yet create the logcontext, as we don't know the method.
self.logcontext: Optional[LoggingContext] = None
@@ -164,12 +164,12 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
- def set_opentracing_span(self, span: "opentracing.Span") -> None:
- """attach an opentracing span to this request
+ def set_tracing_span(self, span: "opentelemetry.trace.Span") -> None:
+ """attach an tracing span to this request
Doing so will cause the span to be closed when we finish processing the request
"""
- self._opentracing_span = span
+ self._tracing_span = span
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
@@ -309,8 +309,10 @@ class SynapseRequest(Request):
self._processing_finished_time = time.time()
self._is_processing = False
- if self._opentracing_span:
- self._opentracing_span.log_kv({"event": "finished processing"})
+ if self._tracing_span:
+ self._tracing_span.add_event(
+ "finished processing", attributes={"event": "finished processing"}
+ )
# if we've already sent the response, log it now; otherwise, we wait for the
# response to be sent.
@@ -325,8 +327,10 @@ class SynapseRequest(Request):
"""
self.finish_time = time.time()
Request.finish(self)
- if self._opentracing_span:
- self._opentracing_span.log_kv({"event": "response sent"})
+ if self._tracing_span:
+ self._tracing_span.add_event(
+ "response sent", attributes={"event": "response sent"}
+ )
if not self._is_processing:
assert self.logcontext is not None
with PreserveLoggingContext(self.logcontext):
@@ -361,9 +365,13 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
logger.info("Connection from client lost before response was sent")
- if self._opentracing_span:
- self._opentracing_span.log_kv(
- {"event": "client connection lost", "reason": str(reason.value)}
+ if self._tracing_span:
+ self._tracing_span.add_event(
+ "client connection lost",
+ attributes={
+ "event": "client connection lost",
+ "reason": str(reason.value),
+ },
)
if self._is_processing:
@@ -471,9 +479,9 @@ class SynapseRequest(Request):
usage.evt_db_fetch_count,
)
- # complete the opentracing span, if any.
- if self._opentracing_span:
- self._opentracing_span.finish()
+ # complete the tracing span, if any.
+ if self._tracing_span:
+ self._tracing_span.end()
try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
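The request span's lifecycle, reduced to the bare OpenTelemetry calls used above: `add_event` replaces opentracing's `log_kv` for point-in-time annotations, and the span must be explicitly `end()`-ed (where opentracing used `finish()`) once the request has both finished processing and been sent. A sketch:

    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    span = tracer.start_span("incoming-request")  # attached via set_tracing_span()
    span.add_event("finished processing")
    span.add_event("response sent")
    span.end()  # closes the span once the request is completely done
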
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index fd9cb97920..dde9c151f5 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -46,7 +46,6 @@ from twisted.internet import defer, threads
from twisted.python.threadpool import ThreadPool
if TYPE_CHECKING:
- from synapse.logging.scopecontextmanager import _LogContextScope
from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@@ -221,14 +220,13 @@ LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
class _Sentinel:
"""Sentinel to represent the root context"""
- __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+ __slots__ = ["previous_context", "finished", "request", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.finished = False
self.request = None
- self.scope = None
self.tag = None
def __str__(self) -> str:
@@ -281,7 +279,6 @@ class LoggingContext:
"finished",
"request",
"tag",
- "scope",
]
def __init__(
@@ -302,7 +299,6 @@ class LoggingContext:
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
- self.scope: Optional["_LogContextScope"] = None
# keep track of whether we have hit the __exit__ block for this context
# (suggesting that the thing that created the context thinks it should
@@ -315,9 +311,6 @@ class LoggingContext:
# we track the current request_id
self.request = self.parent_context.request
- # we also track the current scope:
- self.scope = self.parent_context.scope
-
if request is not None:
# the request param overrides the request from the parent context
self.request = request
@@ -337,10 +330,8 @@ class LoggingContext:
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
-
This exists for backwards compatibility. ``current_context()`` should be
called directly.
-
Returns:
LoggingContext: the current logging context
"""
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
deleted file mode 100644
index ad5cbf46a4..0000000000
--- a/synapse/logging/opentracing.py
+++ /dev/null
@@ -1,972 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# NOTE
-# This is a small wrapper around opentracing because opentracing is not currently
-# packaged downstream (specifically debian). Since opentracing instrumentation is
-# fairly invasive it was awkward to make it optional. As a result we opted to encapsulate
-# all opentracing state in these methods which effectively noop if opentracing is
-# not present. We should strongly consider encouraging the downstream distributers
-# to package opentracing and making opentracing a full dependency. In order to facilitate
-# this move the methods have work very similarly to opentracing's and it should only
-# be a matter of few regexes to move over to opentracing's access patterns proper.
-
-"""
-============================
-Using OpenTracing in Synapse
-============================
-
-Python-specific tracing concepts are at https://opentracing.io/guides/python/.
-Note that Synapse wraps OpenTracing in a small module (this one) in order to make the
-OpenTracing dependency optional. That means that the access patterns are
-different to those demonstrated in the OpenTracing guides. However, it is
-still useful to know, especially if OpenTracing is included as a full dependency
-in the future or if you are modifying this module.
-
-
-OpenTracing is encapsulated so that
-no span objects from OpenTracing are exposed in Synapse's code. This allows
-OpenTracing to be easily disabled in Synapse and thereby have OpenTracing as
-an optional dependency. This does however limit the number of modifiable spans
-at any point in the code to one. From here out references to `opentracing`
-in the code snippets refer to Synapse's module.
-Most methods provided in the module have a direct correlation to those provided
-by opentracing. Refer to docs there for a more in-depth documentation on some of
-the args and methods.
-
-Tracing
--------
-
-In Synapse it is not possible to start a non-active span. Spans can be started
-using the ``start_active_span`` method. This returns a scope (see
-OpenTracing docs) which is a context manager that needs to be entered and
-exited. This is usually done by using ``with``.
-
-.. code-block:: python
-
- from synapse.logging.opentracing import start_active_span
-
- with start_active_span("operation name"):
- # Do something we want to trace
-
-Forgetting to enter or exit a scope will result in some mysterious and grievous log
-context errors.
-
-At any time when there is an active span, ``opentracing.set_tag`` can be used to
-set a tag on the current active span.
-
-Tracing functions
------------------
-
-Functions can be easily traced using decorators. The name of
-the function becomes the operation name for the span.
-
-.. code-block:: python
-
- from synapse.logging.opentracing import trace
-
- # Start a span using 'interesting_function' as the operation name
- @trace
- def interesting_function(*args, **kwargs):
- # Does all kinds of cool and expected things
- return something_usual_and_useful
-
-
-Operation names can be explicitly set for a function by using ``trace_with_opname``:
-
-.. code-block:: python
-
- from synapse.logging.opentracing import trace_with_opname
-
- @trace_with_opname("a_better_operation_name")
- def interesting_badly_named_function(*args, **kwargs):
- # Does all kinds of cool and expected things
- return something_usual_and_useful
-
-Setting Tags
-------------
-
-To set a tag on the active span do
-
-.. code-block:: python
-
- from synapse.logging.opentracing import set_tag
-
- set_tag(tag_name, tag_value)
-
-There's a convenient decorator to tag all the args of the method. It uses
-inspection in order to use the formal parameter names prefixed with 'ARG_' as
-tag names. It uses kwarg names as tag names without the prefix.
-
-.. code-block:: python
-
- from synapse.logging.opentracing import tag_args
-
- @tag_args
- def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
- pass
-
- set_fates("the story", "the end", "the act")
- # This will have the following tags
- # - ARG_clotho: "the story"
- # - ARG_lachesis: "the end"
- # - ARG_atropos: "the act"
- # - father: "Zues"
- # - mother: "Themis"
-
-Contexts and carriers
----------------------
-
-There are a selection of wrappers for injecting and extracting contexts from
-carriers provided. Unfortunately OpenTracing's three context injection
-techniques are not adequate for our injection of OpenTracing span-contexts into
-Twisted's http headers, EDU contents and our database tables. Also note that
-the binary encoding format mandated by OpenTracing is not actually implemented
-by jaeger_client v4.0.0 - it will silently noop.
-Please refer to the end of ``logging/opentracing.py`` for the available
-injection and extraction methods.
-
-Homeserver whitelisting
------------------------
-
-Most of the whitelist checks are encapsulated in the modules's injection
-and extraction method but be aware that using custom carriers or crossing
-uncharted waters will require the enforcement of the whitelist.
-``logging/opentracing.py`` has a ``whitelisted_homeserver`` method which takes
-in a destination and compares it to the whitelist.
-
-Most injection methods take a 'destination' arg. The context will only be injected
-if the destination matches the whitelist or the destination is None.
-
-=======
-Gotchas
-=======
-
-- Checking whitelists on span propagation
-- Inserting pii
-- Forgetting to enter or exit a scope
-- Span source: make sure that the span you expect to be active across a
- function call really will be that one. Does the current function have more
- than one caller? Will all of those calling functions have be in a context
- with an active span?
-"""
-import contextlib
-import enum
-import inspect
-import logging
-import re
-from functools import wraps
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Collection,
- Dict,
- Generator,
- Iterable,
- List,
- Optional,
- Pattern,
- Type,
- TypeVar,
- Union,
- cast,
- overload,
-)
-
-import attr
-from typing_extensions import ParamSpec
-
-from twisted.internet import defer
-from twisted.web.http import Request
-from twisted.web.http_headers import Headers
-
-from synapse.config import ConfigError
-from synapse.util import json_decoder, json_encoder
-
-if TYPE_CHECKING:
- from synapse.http.site import SynapseRequest
- from synapse.server import HomeServer
-
-# Helper class
-
-
-class _DummyTagNames:
- """wrapper of opentracings tags. We need to have them if we
- want to reference them without opentracing around. Clearly they
- should never actually show up in a trace. `set_tags` overwrites
- these with the correct ones."""
-
- INVALID_TAG = "invalid-tag"
- COMPONENT = INVALID_TAG
- DATABASE_INSTANCE = INVALID_TAG
- DATABASE_STATEMENT = INVALID_TAG
- DATABASE_TYPE = INVALID_TAG
- DATABASE_USER = INVALID_TAG
- ERROR = INVALID_TAG
- HTTP_METHOD = INVALID_TAG
- HTTP_STATUS_CODE = INVALID_TAG
- HTTP_URL = INVALID_TAG
- MESSAGE_BUS_DESTINATION = INVALID_TAG
- PEER_ADDRESS = INVALID_TAG
- PEER_HOSTNAME = INVALID_TAG
- PEER_HOST_IPV4 = INVALID_TAG
- PEER_HOST_IPV6 = INVALID_TAG
- PEER_PORT = INVALID_TAG
- PEER_SERVICE = INVALID_TAG
- SAMPLING_PRIORITY = INVALID_TAG
- SERVICE = INVALID_TAG
- SPAN_KIND = INVALID_TAG
- SPAN_KIND_CONSUMER = INVALID_TAG
- SPAN_KIND_PRODUCER = INVALID_TAG
- SPAN_KIND_RPC_CLIENT = INVALID_TAG
- SPAN_KIND_RPC_SERVER = INVALID_TAG
-
-
-try:
- import opentracing
- import opentracing.tags
-
- tags = opentracing.tags
-except ImportError:
- opentracing = None # type: ignore[assignment]
- tags = _DummyTagNames # type: ignore[assignment]
-try:
- from jaeger_client import Config as JaegerConfig
-
- from synapse.logging.scopecontextmanager import LogContextScopeManager
-except ImportError:
- JaegerConfig = None # type: ignore
- LogContextScopeManager = None # type: ignore
-
-
-try:
- from rust_python_jaeger_reporter import Reporter
-
- # jaeger-client 4.7.0 requires that reporters inherit from BaseReporter, which
- # didn't exist before that version.
- try:
- from jaeger_client.reporter import BaseReporter
- except ImportError:
-
- class BaseReporter: # type: ignore[no-redef]
- pass
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class _WrappedRustReporter(BaseReporter):
- """Wrap the reporter to ensure `report_span` never throws."""
-
- _reporter: Reporter = attr.Factory(Reporter)
-
- def set_process(self, *args: Any, **kwargs: Any) -> None:
- return self._reporter.set_process(*args, **kwargs)
-
- def report_span(self, span: "opentracing.Span") -> None:
- try:
- return self._reporter.report_span(span)
- except Exception:
- logger.exception("Failed to report span")
-
- RustReporter: Optional[Type[_WrappedRustReporter]] = _WrappedRustReporter
-except ImportError:
- RustReporter = None
-
-
-logger = logging.getLogger(__name__)
-
-
-class SynapseTags:
- # The message ID of any to_device message processed
- TO_DEVICE_MESSAGE_ID = "to_device.message_id"
-
- # Whether the sync response has new data to be returned to the client.
- SYNC_RESULT = "sync.new_data"
-
- # incoming HTTP request ID (as written in the logs)
- REQUEST_ID = "request_id"
-
- # HTTP request tag (used to distinguish full vs incremental syncs, etc)
- REQUEST_TAG = "request_tag"
-
- # Text description of a database transaction
- DB_TXN_DESC = "db.txn_desc"
-
- # Uniqueish ID of a database transaction
- DB_TXN_ID = "db.txn_id"
-
- # The name of the external cache
- CACHE_NAME = "cache.name"
-
-
-class SynapseBaggage:
- FORCE_TRACING = "synapse-force-tracing"
-
-
-# Block everything by default
-# A regex which matches the server_names to expose traces for.
-# None means 'block everything'.
-_homeserver_whitelist: Optional[Pattern[str]] = None
-
-# Util methods
-
-
-class _Sentinel(enum.Enum):
- # defining a sentinel in this way allows mypy to correctly handle the
- # type of a dictionary lookup.
- sentinel = object()
-
-
-P = ParamSpec("P")
-R = TypeVar("R")
-T = TypeVar("T")
-
-
-def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
- """Executes the function only if we're tracing. Otherwise returns None."""
-
- @wraps(func)
- def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
- if opentracing:
- return func(*args, **kwargs)
- else:
- return None
-
- return _only_if_tracing_inner
-
-
-@overload
-def ensure_active_span(
- message: str,
-) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
- ...
-
-
-@overload
-def ensure_active_span(
- message: str, ret: T
-) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
- ...
-
-
-def ensure_active_span(
- message: str, ret: Optional[T] = None
-) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
- """Executes the operation only if opentracing is enabled and there is an active span.
- If there is no active span it logs message at the error level.
-
- Args:
- message: Message which fills in "There was no active span when trying to %s"
- in the error log if there is no active span and opentracing is enabled.
- ret: return value if opentracing is None or there is no active span.
-
- Returns:
- The result of the func, falling back to ret if opentracing is disabled or there
- was no active span.
- """
-
- def ensure_active_span_inner_1(
- func: Callable[P, R]
- ) -> Callable[P, Union[Optional[T], R]]:
- @wraps(func)
- def ensure_active_span_inner_2(
- *args: P.args, **kwargs: P.kwargs
- ) -> Union[Optional[T], R]:
- if not opentracing:
- return ret
-
- if not opentracing.tracer.active_span:
- logger.error(
- "There was no active span when trying to %s."
- " Did you forget to start one or did a context slip?",
- message,
- stack_info=True,
- )
-
- return ret
-
- return func(*args, **kwargs)
-
- return ensure_active_span_inner_2
-
- return ensure_active_span_inner_1
-
-
-# Setup
-
-
-def init_tracer(hs: "HomeServer") -> None:
- """Set the whitelists and initialise the JaegerClient tracer"""
- global opentracing
- if not hs.config.tracing.opentracer_enabled:
- # We don't have a tracer
- opentracing = None # type: ignore[assignment]
- return
-
- if not opentracing or not JaegerConfig:
- raise ConfigError(
- "The server has been configured to use opentracing but opentracing is not "
- "installed."
- )
-
- # Pull out the jaeger config if it was given. Otherwise set it to something sensible.
- # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
-
- set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist)
-
- from jaeger_client.metrics.prometheus import PrometheusMetricsFactory
-
- config = JaegerConfig(
- config=hs.config.tracing.jaeger_config,
- service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
- scope_manager=LogContextScopeManager(),
- metrics_factory=PrometheusMetricsFactory(),
- )
-
- # If we have the rust jaeger reporter available let's use that.
- if RustReporter:
- logger.info("Using rust_python_jaeger_reporter library")
- assert config.sampler is not None
- tracer = config.create_tracer(RustReporter(), config.sampler)
- opentracing.set_global_tracer(tracer)
- else:
- config.initialize_tracer()
-
-
-# Whitelisting
-
-
-@only_if_tracing
-def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None:
- """Sets the homeserver whitelist
-
- Args:
- homeserver_whitelist: regexes specifying whitelisted homeservers
- """
- global _homeserver_whitelist
- if homeserver_whitelist:
- # Makes a single regex which accepts all passed in regexes in the list
- _homeserver_whitelist = re.compile(
- "({})".format(")|(".join(homeserver_whitelist))
- )
-
-
-@only_if_tracing
-def whitelisted_homeserver(destination: str) -> bool:
- """Checks if a destination matches the whitelist
-
- Args:
- destination
- """
-
- if _homeserver_whitelist:
- return _homeserver_whitelist.match(destination) is not None
- return False
-
-
-# Start spans and scopes
-
-# Could use kwargs but I want these to be explicit
-def start_active_span(
- operation_name: str,
- child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None,
- references: Optional[List["opentracing.Reference"]] = None,
- tags: Optional[Dict[str, str]] = None,
- start_time: Optional[float] = None,
- ignore_active_span: bool = False,
- finish_on_close: bool = True,
- *,
- tracer: Optional["opentracing.Tracer"] = None,
-) -> "opentracing.Scope":
- """Starts an active opentracing span.
-
- Records the start time for the span, and sets it as the "active span" in the
- scope manager.
-
- Args:
- See opentracing.tracer
- Returns:
- scope (Scope) or contextlib.nullcontext
- """
-
- if opentracing is None:
- return contextlib.nullcontext() # type: ignore[unreachable]
-
- if tracer is None:
- # use the global tracer by default
- tracer = opentracing.tracer
-
- return tracer.start_active_span(
- operation_name,
- child_of=child_of,
- references=references,
- tags=tags,
- start_time=start_time,
- ignore_active_span=ignore_active_span,
- finish_on_close=finish_on_close,
- )
-
-
-def start_active_span_follows_from(
- operation_name: str,
- contexts: Collection,
- child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None,
- start_time: Optional[float] = None,
- *,
- inherit_force_tracing: bool = False,
- tracer: Optional["opentracing.Tracer"] = None,
-) -> "opentracing.Scope":
- """Starts an active opentracing span, with additional references to previous spans
-
- Args:
- operation_name: name of the operation represented by the new span
- contexts: the previous spans to inherit from
-
- child_of: optionally override the parent span. If unset, the currently active
- span will be the parent. (If there is no currently active span, the first
- span in `contexts` will be the parent.)
-
- start_time: optional override for the start time of the created span. Seconds
- since the epoch.
-
- inherit_force_tracing: if set, and any of the previous contexts have had tracing
- forced, the new span will also have tracing forced.
- tracer: override the opentracing tracer. By default the global tracer is used.
- """
- if opentracing is None:
- return contextlib.nullcontext() # type: ignore[unreachable]
-
- references = [opentracing.follows_from(context) for context in contexts]
- scope = start_active_span(
- operation_name,
- child_of=child_of,
- references=references,
- start_time=start_time,
- tracer=tracer,
- )
-
- if inherit_force_tracing and any(
- is_context_forced_tracing(ctx) for ctx in contexts
- ):
- force_tracing(scope.span)
-
- return scope
-
-
-def start_active_span_from_edu(
- edu_content: Dict[str, Any],
- operation_name: str,
- references: Optional[List["opentracing.Reference"]] = None,
- tags: Optional[Dict[str, str]] = None,
- start_time: Optional[float] = None,
- ignore_active_span: bool = False,
- finish_on_close: bool = True,
-) -> "opentracing.Scope":
- """
- Extracts a span context from an edu and uses it to start a new active span
-
- Args:
- edu_content: an edu_content with a `context` field whose value is
- canonical json for a dict which contains opentracing information.
-
- For the other args see opentracing.tracer
- """
- references = references or []
-
- if opentracing is None:
- return contextlib.nullcontext() # type: ignore[unreachable]
-
- carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
- "opentracing", {}
- )
- context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
- _references = [
- opentracing.child_of(span_context_from_string(x))
- for x in carrier.get("references", [])
- ]
-
- # For some reason jaeger decided not to support the visualization of multiple parent
- # spans or explicitly show references. I include the span context as a tag here as
- # an aid to people debugging but it's really not an ideal solution.
-
- references += _references
-
- scope = opentracing.tracer.start_active_span(
- operation_name,
- child_of=context,
- references=references,
- tags=tags,
- start_time=start_time,
- ignore_active_span=ignore_active_span,
- finish_on_close=finish_on_close,
- )
-
- scope.span.set_tag("references", carrier.get("references", []))
- return scope
-
-
-# Opentracing setters for tags, logs, etc
-@only_if_tracing
-def active_span() -> Optional["opentracing.Span"]:
- """Get the currently active span, if any"""
- return opentracing.tracer.active_span
-
-
-@ensure_active_span("set a tag")
-def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
- """Sets a tag on the active span"""
- assert opentracing.tracer.active_span is not None
- opentracing.tracer.active_span.set_tag(key, value)
-
-
-@ensure_active_span("log")
-def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
- """Log to the active span"""
- assert opentracing.tracer.active_span is not None
- opentracing.tracer.active_span.log_kv(key_values, timestamp)
-
-
-@ensure_active_span("set the traces operation name")
-def set_operation_name(operation_name: str) -> None:
- """Sets the operation name of the active span"""
- assert opentracing.tracer.active_span is not None
- opentracing.tracer.active_span.set_operation_name(operation_name)
-
-
-@only_if_tracing
-def force_tracing(
- span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel
-) -> None:
- """Force sampling for the active/given span and its children.
-
- Args:
- span: span to force tracing for. By default, the active span.
- """
- if isinstance(span, _Sentinel):
- span_to_trace = opentracing.tracer.active_span
- else:
- span_to_trace = span
- if span_to_trace is None:
- logger.error("No active span in force_tracing")
- return
-
- span_to_trace.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
-
- # also set a bit of baggage, so that we have a way of figuring out if
- # it is enabled later
- span_to_trace.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
-
-
-def is_context_forced_tracing(
- span_context: Optional["opentracing.SpanContext"],
-) -> bool:
- """Check if sampling has been force for the given span context."""
- if span_context is None:
- return False
- return span_context.baggage.get(SynapseBaggage.FORCE_TRACING) is not None
-
-
-# Injection and extraction
-
-
-@ensure_active_span("inject the span into a header dict")
-def inject_header_dict(
- headers: Dict[bytes, List[bytes]],
- destination: Optional[str] = None,
- check_destination: bool = True,
-) -> None:
- """
- Injects a span context into a dict of HTTP headers
-
- Args:
- headers: the dict to inject headers into
- destination: address of entity receiving the span context. Must be given unless
- check_destination is False. The context will only be injected if the
- destination matches the opentracing whitelist
- check_destination (bool): If false, destination will be ignored and the context
- will always be injected.
-
- Note:
- The headers set by the tracer are custom to the tracer implementation which
- should be unique enough that they don't interfere with any headers set by
- synapse or twisted. If we're still using jaeger these headers would be those
- here:
- https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
- """
- if check_destination:
- if destination is None:
- raise ValueError(
- "destination must be given unless check_destination is False"
- )
- if not whitelisted_homeserver(destination):
- return
-
- span = opentracing.tracer.active_span
-
- carrier: Dict[str, str] = {}
- assert span is not None
- opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
-
- for key, value in carrier.items():
- headers[key.encode()] = [value.encode()]
-
-
-def inject_response_headers(response_headers: Headers) -> None:
- """Inject the current trace id into the HTTP response headers"""
- if not opentracing:
- return
- span = opentracing.tracer.active_span
- if not span:
- return
-
- # This is a bit implementation-specific.
- #
- # Jaeger's Spans have a trace_id property; other implementations (including the
- # dummy opentracing.span.Span which we use if init_tracer is not called) do not
- # expose it
- trace_id = getattr(span, "trace_id", None)
-
- if trace_id is not None:
- response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
-
-
-@ensure_active_span(
- "get the active span context as a dict", ret=cast(Dict[str, str], {})
-)
-def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
- """
- Gets a span context as a dict. This can be used instead of manually
- injecting a span into an empty carrier.
-
- Args:
- destination: the name of the remote server.
-
- Returns:
- dict: the active span's context if opentracing is enabled, otherwise empty.
- """
-
- if destination and not whitelisted_homeserver(destination):
- return {}
-
- carrier: Dict[str, str] = {}
- assert opentracing.tracer.active_span is not None
- opentracing.tracer.inject(
- opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
- )
-
- return carrier
-
-
-@ensure_active_span("get the span context as a string.", ret={})
-def active_span_context_as_string() -> str:
- """
- Returns:
- The active span context encoded as a string.
- """
- carrier: Dict[str, str] = {}
- if opentracing:
- assert opentracing.tracer.active_span is not None
- opentracing.tracer.inject(
- opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
- )
- return json_encoder.encode(carrier)
-
-
-def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
- """Extract an opentracing context from the headers on an HTTP request
-
- This is useful when we have received an HTTP request from another part of our
- system, and want to link our spans to those of the remote system.
- """
- if not opentracing:
- return None
- header_dict = {
- k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
- }
- return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
-
-
-@only_if_tracing
-def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]:
- """
- Returns:
- The active span context decoded from a string.
- """
- payload: Dict[str, str] = json_decoder.decode(carrier)
- return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload)
-
-
-@only_if_tracing
-def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]:
- """
- Wrapper method for opentracing's tracer.extract for TEXT_MAP.
- Args:
- carrier: a dict possibly containing a span context.
-
- Returns:
- The active span context extracted from carrier.
- """
- return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
-
-
-# Tracing decorators
-
-
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
- """
- Decorator to trace a function with a custom opname.
-
- See the module's doc string for usage examples.
-
- """
-
- def decorator(func: Callable[P, R]) -> Callable[P, R]:
- if opentracing is None:
- return func # type: ignore[unreachable]
-
- if inspect.iscoroutinefunction(func):
-
- @wraps(func)
- async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- with start_active_span(opname):
- return await func(*args, **kwargs) # type: ignore[misc]
-
- else:
- # The other case here handles both sync functions and those
- # decorated with inlineDeferred.
- @wraps(func)
- def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- scope = start_active_span(opname)
- scope.__enter__()
-
- try:
- result = func(*args, **kwargs)
- if isinstance(result, defer.Deferred):
-
- def call_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- def err_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- result.addCallbacks(call_back, err_back)
-
- else:
- if inspect.isawaitable(result):
- logger.error(
- "@trace may not have wrapped %s correctly! "
- "The function is not async but returned a %s.",
- func.__qualname__,
- type(result).__name__,
- )
-
- scope.__exit__(None, None, None)
-
- return result
-
- except Exception as e:
- scope.__exit__(type(e), None, e.__traceback__)
- raise
-
- return _trace_inner # type: ignore[return-value]
-
- return decorator
-
-
-def trace(func: Callable[P, R]) -> Callable[P, R]:
- """
- Decorator to trace a function.
-
- Sets the operation name to that of the function's name.
-
- See the module's doc string for usage examples.
- """
-
- return trace_with_opname(func.__name__)(func)
-
-
-def tag_args(func: Callable[P, R]) -> Callable[P, R]:
- """
- Tags all of the args to the active span.
- """
-
- if not opentracing:
- return func
-
- @wraps(func)
- def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- argspec = inspect.getfullargspec(func)
- for i, arg in enumerate(argspec.args[1:]):
- set_tag("ARG_" + arg, args[i]) # type: ignore[index]
- set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
- set_tag("kwargs", str(kwargs))
- return func(*args, **kwargs)
-
- return _tag_args_inner
-
-
-@contextlib.contextmanager
-def trace_servlet(
- request: "SynapseRequest", extract_context: bool = False
-) -> Generator[None, None, None]:
- """Returns a context manager which traces a request. It starts a span
- with some servlet specific tags such as the request metrics name and
- request information.
-
- Args:
- request
- extract_context: Whether to attempt to extract the opentracing
- context from the request the servlet is handling.
- """
-
- if opentracing is None:
- yield # type: ignore[unreachable]
- return
-
- request_tags = {
- SynapseTags.REQUEST_ID: request.get_request_id(),
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
- tags.HTTP_METHOD: request.get_method(),
- tags.HTTP_URL: request.get_redacted_uri(),
- tags.PEER_HOST_IPV6: request.getClientAddress().host,
- }
-
- request_name = request.request_metrics.name
- context = span_context_from_request(request) if extract_context else None
-
- # we configure the scope not to finish the span immediately on exit, and instead
- # pass the span into the SynapseRequest, which will finish it once we've finished
- # sending the response to the client.
- scope = start_active_span(request_name, child_of=context, finish_on_close=False)
- request.set_opentracing_span(scope.span)
-
- with scope:
- inject_response_headers(request.responseHeaders)
- try:
- yield
- finally:
- # We set the operation name again in case its changed (which happens
- # with JsonResource).
- scope.span.set_operation_name(request.request_metrics.name)
-
- # set the tags *after* the servlet completes, in case it decided to
- # prioritise the span (tags will get dropped on unprioritised spans)
- request_tags[
- SynapseTags.REQUEST_TAG
- ] = request.request_metrics.start_context.tag
-
- for k, v in request_tags.items():
- scope.span.set_tag(k, v)
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
deleted file mode 100644
index 10877bdfc5..0000000000
--- a/synapse/logging/scopecontextmanager.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.import logging
-
-import logging
-from types import TracebackType
-from typing import Optional, Type
-
-from opentracing import Scope, ScopeManager, Span
-
-import twisted
-
-from synapse.logging.context import (
- LoggingContext,
- current_context,
- nested_logging_context,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class LogContextScopeManager(ScopeManager):
- """
- The LogContextScopeManager tracks the active scope in opentracing
- by using the log contexts which are native to synapse. This is so
- that the basic opentracing api can be used across twisted defereds.
-
- It would be nice just to use opentracing's ContextVarsScopeManager,
- but currently that doesn't work due to https://twistedmatrix.com/trac/ticket/10301.
- """
-
- def __init__(self) -> None:
- pass
-
- @property
- def active(self) -> Optional[Scope]:
- """
- Returns the currently active Scope which can be used to access the
- currently active Scope.span.
- If there is a non-null Scope, its wrapped Span
- becomes an implicit parent of any newly-created Span at
- Tracer.start_active_span() time.
-
- Return:
- The Scope that is active, or None if not available.
- """
- ctx = current_context()
- return ctx.scope
-
- def activate(self, span: Span, finish_on_close: bool) -> Scope:
- """
- Makes a Span active.
- Args
- span: the span that should become active.
- finish_on_close: whether Span should be automatically finished when
- Scope.close() is called.
-
- Returns:
- Scope to control the end of the active period for
- *span*. It is a programming error to neglect to call
- Scope.close() on the returned instance.
- """
-
- ctx = current_context()
-
- if not ctx:
- logger.error("Tried to activate scope outside of loggingcontext")
- return Scope(None, span) # type: ignore[arg-type]
-
- if ctx.scope is not None:
- # start a new logging context as a child of the existing one.
- # Doing so -- rather than updating the existing logcontext -- means that
- # creating several concurrent spans under the same logcontext works
- # correctly.
- ctx = nested_logging_context("")
- enter_logcontext = True
- else:
- # if there is no span currently associated with the current logcontext, we
- # just store the scope in it.
- #
- # This feels a bit dubious, but it does hack around a problem where a
- # span outlasts its parent logcontext (which would otherwise lead to
- # "Re-starting finished log context" errors).
- enter_logcontext = False
-
- scope = _LogContextScope(self, span, ctx, enter_logcontext, finish_on_close)
- ctx.scope = scope
- if enter_logcontext:
- ctx.__enter__()
-
- return scope
-
-
-class _LogContextScope(Scope):
- """
- A custom opentracing scope, associated with a LogContext
-
- * filters out _DefGen_Return exceptions which arise from calling
- `defer.returnValue` in Twisted code
-
- * When the scope is closed, the logcontext's active scope is reset to None.
- and - if enter_logcontext was set - the logcontext is finished too.
- """
-
- def __init__(
- self,
- manager: LogContextScopeManager,
- span: Span,
- logcontext: LoggingContext,
- enter_logcontext: bool,
- finish_on_close: bool,
- ):
- """
- Args:
- manager:
- the manager that is responsible for this scope.
- span:
- the opentracing span which this scope represents the local
- lifetime for.
- logcontext:
- the log context to which this scope is attached.
- enter_logcontext:
- if True the log context will be exited when the scope is finished
- finish_on_close:
- if True finish the span when the scope is closed
- """
- super().__init__(manager, span)
- self.logcontext = logcontext
- self._finish_on_close = finish_on_close
- self._enter_logcontext = enter_logcontext
-
- def __exit__(
- self,
- exc_type: Optional[Type[BaseException]],
- value: Optional[BaseException],
- traceback: Optional[TracebackType],
- ) -> None:
- if exc_type == twisted.internet.defer._DefGen_Return:
- # filter out defer.returnValue() calls
- exc_type = value = traceback = None
- super().__exit__(exc_type, value, traceback)
-
- def __str__(self) -> str:
- return f"Scope<{self.span}>"
-
- def close(self) -> None:
- active_scope = self.manager.active
- if active_scope is not self:
- logger.error(
- "Closing scope %s which is not the currently-active one %s",
- self,
- active_scope,
- )
-
- if self._finish_on_close:
- self.span.finish()
-
- self.logcontext.scope = None
-
- if self._enter_logcontext:
- self.logcontext.__exit__(None, None, None)
diff --git a/synapse/logging/tracing.py b/synapse/logging/tracing.py
new file mode 100644
index 0000000000..e3a1a010a2
--- /dev/null
+++ b/synapse/logging/tracing.py
@@ -0,0 +1,942 @@
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# NOTE This is a small wrapper around opentelemetry because tracing is optional
+# and not always packaged downstream. Since opentelemetry instrumentation is
+# fairly invasive it was awkward to make it optional. As a result we opted to
+# encapsulate all opentelemetry state in these methods which effectively noop if
+# opentelemetry is not present. We should strongly consider encouraging the
+# downstream distributors to package opentelemetry and making opentelemetry a
+# full dependency. In order to facilitate this move, the methods have been made
+# to work very similarly to opentelemetry's, and it should only be a matter of
+# a few regexes to move over to opentelemetry's access patterns proper.
+
+"""
+==============================
+Using OpenTelemetry in Synapse
+==============================
+
+Python-specific tracing concepts are at
+https://opentelemetry.io/docs/instrumentation/python/. Note that Synapse wraps
+OpenTelemetry in a small module (this one) in order to make the OpenTelemetry
+dependency optional. That means that some access patterns are different to those
+demonstrated in the OpenTelemetry guides. However, it is still useful to know,
+especially if OpenTelemetry is included as a full dependency in the future or if
+you are modifying this module.
+
+
+OpenTelemetry is encapsulated so that no span objects from OpenTelemetry are
+exposed in Synapse's code. This allows OpenTelemetry to be easily disabled in
+Synapse and thereby treated as an optional dependency. This does, however,
+limit the number of modifiable spans at any point in the code to one. From here
+on out, references to ``tracing`` in the code snippets refer to Synapse's
+module. Most methods provided in the module correlate directly to those
+provided by OpenTelemetry. Refer to the OpenTelemetry docs for more in-depth
+documentation on some of the args and methods.
+
+Tracing
+-------
+
+In Synapse, spans are almost always started as *active* spans, using the
+``start_active_span`` method. This returns a context manager that
+needs to be entered and exited to expose the ``span``. This is usually done by
+using a ``with`` statement.
+
+.. code-block:: python
+
+ from synapse.logging.tracing import start_active_span
+
+ with start_active_span("operation name"):
+ # Do something we want to trace
+
+Forgetting to enter or exit a scope will result in unstarted and unfinished
+spans that will not be reported (exported).
+
+At any time when there is an active span, ``set_attribute`` can be used to set
+an attribute on the current active span.
+
+Tracing functions
+-----------------
+
+Functions can be easily traced using decorators. The name of the function
+becomes the operation name for the span.
+
+.. code-block:: python
+
+ from synapse.logging.tracing import trace
+
+ # Start a span using 'interesting_function' as the operation name
+ @trace
+ def interesting_function(*args, **kwargs):
+        # Does all kinds of cool and expected things
+        return something_usual_and_useful
+
+
+Operation names can be explicitly set for a function by using
+``trace_with_opname``:
+
+.. code-block:: python
+
+ from synapse.logging.tracing import trace_with_opname
+
+ @trace_with_opname("a_better_operation_name")
+ def interesting_badly_named_function(*args, **kwargs):
+        # Does all kinds of cool and expected things
+        return something_usual_and_useful
+
+Setting Tags
+------------
+
+To set a tag on the active span do
+
+.. code-block:: python
+
+ from synapse.logging.tracing import set_attribute
+
+ set_attribute(tag_name, tag_value)
+
+There's a convenient decorator to tag all the args of the method. It uses
+inspection in order to use the formal parameter names prefixed with 'ARG_' as
+tag names. It uses kwarg names as tag names without the prefix.
+
+.. code-block:: python
+
+    from synapse.logging.tracing import tag_args
+
+    @tag_args
+    def set_fates(clotho, lachesis, atropos, father="Zeus", mother="Themis"):
+        pass
+
+    set_fates("the story", "the end", "the act")
+
+    # This will have the following tags
+    # - ARG_clotho: "the story"
+    # - ARG_lachesis: "the end"
+    # - ARG_atropos: "the act"
+    # - father: "Zeus"
+    # - mother: "Themis"
+
+Contexts and carriers
+---------------------
+
+A selection of wrappers is provided for injecting and extracting contexts
+from carriers. We use these to inject OpenTelemetry Contexts into Twisted's
+HTTP headers, EDU contents and our database tables. Please refer to the end of
+``logging/tracing.py`` for the available injection and extraction methods.
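+
+As a minimal sketch (using only helpers defined later in this module), a
+tracing context can be round-tripped through a plain dict carrier like so:
+
+.. code-block:: python
+
+    from synapse.logging.tracing import (
+        extract_text_map,
+        get_active_span_text_map,
+        start_active_span,
+    )
+
+    # Serialize the active tracing context into a plain dict ("carrier")
+    carrier = get_active_span_text_map()
+
+    # ... ship the carrier over HTTP, in an EDU, or via the database ...
+
+    # Restore the context on the other side and continue the trace
+    context = extract_text_map(carrier)
+    with start_active_span("continued-operation", context=context):
+        ...  # the continued work happens here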
+
+Homeserver whitelisting
+-----------------------
+
+Most of the whitelist checks are encapsulated in the module's injection and
+extraction methods, but be aware that using custom carriers or crossing
+uncharted waters will require you to enforce the whitelist yourself.
+``logging/tracing.py`` has a ``whitelisted_homeserver`` method which takes
+in a destination and compares it to the whitelist.
+
+Most injection methods take a 'destination' arg. The context will only be
+injected if the destination matches the whitelist or the destination is None.
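+
+For example, a custom carrier might guard itself like this (a sketch only; the
+``context_for_destination`` helper is hypothetical, not part of this module):
+
+.. code-block:: python
+
+    from synapse.logging.tracing import (
+        get_active_span_text_map,
+        whitelisted_homeserver,
+    )
+
+    def context_for_destination(destination: str) -> dict:
+        # Only hand the tracing context to whitelisted homeservers
+        if not whitelisted_homeserver(destination):
+            return {}
+        return get_active_span_text_map()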
+
+=======
+Gotchas
+=======
+
+- Checking whitelists on span propagation
+- Inserting pii
+- Forgetting to enter or exit a scope
+- Span source: make sure that the span you expect to be active across a function
+ call really will be that one. Does the current function have more than one
+  caller? Will all of those calling functions be in a context with an
+  active span?
+"""
+import contextlib
+import inspect
+import logging
+import re
+from abc import ABC
+from functools import wraps
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ ContextManager,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Pattern,
+ Sequence,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import ParamSpec
+
+from twisted.internet import defer
+from twisted.web.http import Request
+from twisted.web.http_headers import Headers
+
+from synapse.api.constants import EventContentFields
+from synapse.config import ConfigError
+from synapse.util import json_decoder
+
+if TYPE_CHECKING:
+ from synapse.http.site import SynapseRequest
+ from synapse.server import HomeServer
+
+# Helper class
+
+T = TypeVar("T")
+
+
+class _DummyLookup:
+    """Always returns the fixed value given, for any accessed property"""
+
+ def __init__(self, value: T) -> None:
+ self.value = value
+
+ def __getattribute__(self, name: str) -> T:
+ return object.__getattribute__(self, "value")
+
+
+class DummyLink(ABC):
+ """Dummy placeholder for `opentelemetry.trace.Link`"""
+
+ def __init__(self) -> None:
+ self.not_implemented_message = (
+ "opentelemetry isn't installed so this is just a dummy link placeholder"
+ )
+
+ @property
+ def context(self) -> None:
+ raise NotImplementedError(self.not_implemented_message)
+
+ @property
+ def attributes(self) -> None:
+ raise NotImplementedError(self.not_implemented_message)
+
+
+# These dependencies are optional, so they can fail to import. If they do, we
+# fall back to the dummy implementations defined above and tracing is disabled.
+try:
+ import opentelemetry
+ import opentelemetry.exporter.jaeger.thrift
+ import opentelemetry.propagate
+ import opentelemetry.sdk.resources
+ import opentelemetry.sdk.trace
+ import opentelemetry.sdk.trace.export
+ import opentelemetry.semconv.trace
+ import opentelemetry.trace
+ import opentelemetry.trace.propagation
+ import opentelemetry.trace.status
+
+ SpanKind = opentelemetry.trace.SpanKind
+ SpanAttributes = opentelemetry.semconv.trace.SpanAttributes
+ StatusCode = opentelemetry.trace.status.StatusCode
+ Link = opentelemetry.trace.Link
+except ImportError:
+ opentelemetry = None # type: ignore[assignment]
+ SpanKind = _DummyLookup(0) # type: ignore
+ SpanAttributes = _DummyLookup("fake-attribute") # type: ignore
+ StatusCode = _DummyLookup(0) # type: ignore
+ Link = DummyLink # type: ignore
+
+
+logger = logging.getLogger(__name__)
+
+
+class SynapseTags:
+ """FIXME: Rename to `SynapseAttributes` so it matches OpenTelemetry `SpanAttributes`"""
+
+ # The message ID of any to_device message processed
+ TO_DEVICE_MESSAGE_ID = "to_device.message_id"
+
+ # Whether the sync response has new data to be returned to the client.
+ SYNC_RESULT = "sync.new_data"
+
+ # incoming HTTP request ID (as written in the logs)
+ REQUEST_ID = "request_id"
+
+ # HTTP request tag (used to distinguish full vs incremental syncs, etc)
+ REQUEST_TAG = "request_tag"
+
+ # Text description of a database transaction
+ DB_TXN_DESC = "db.txn_desc"
+
+ # Uniqueish ID of a database transaction
+ DB_TXN_ID = "db.txn_id"
+
+ # The name of the external cache
+ CACHE_NAME = "cache.name"
+
+
+class SynapseBaggage:
+ FORCE_TRACING = "synapse-force-tracing"
+
+
+# Block everything by default
+# A regex which matches the server_names to expose traces for.
+# None means 'block everything'.
+_homeserver_whitelist: Optional[Pattern[str]] = None
+
+# Util methods
+
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
+ """Decorator function that executes the function only if we're tracing. Otherwise returns None."""
+
+ @wraps(func)
+ def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
+ if opentelemetry:
+ return func(*args, **kwargs)
+ else:
+ return None
+
+ return _only_if_tracing_inner
+
+
+@overload
+def ensure_active_span(
+ message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+ ...
+
+
+@overload
+def ensure_active_span(
+ message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+ ...
+
+
+def ensure_active_span(
+ message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
+ """Executes the operation only if opentelemetry is enabled and there is an active span.
+    If there is no active span, it logs the message at the error level.
+
+ Args:
+ message: Message which fills in "There was no active span when trying to %s"
+ in the error log if there is no active span and opentelemetry is enabled.
+ ret: return value if opentelemetry is None or there is no active span.
+
+ Returns:
+ The result of the func, falling back to ret if opentelemetry is disabled or there
+ was no active span.
+ """
+
+ def ensure_active_span_inner_1(
+ func: Callable[P, R]
+ ) -> Callable[P, Union[Optional[T], R]]:
+ @wraps(func)
+ def ensure_active_span_inner_2(
+ *args: P.args, **kwargs: P.kwargs
+ ) -> Union[Optional[T], R]:
+ if not opentelemetry:
+ return ret
+
+            # `get_current_span` returns the `INVALID_SPAN` sentinel (which is
+            # truthy) rather than `None` when there is no active span, so we
+            # compare against the sentinel explicitly.
+            if (
+                opentelemetry.trace.get_current_span()
+                is opentelemetry.trace.INVALID_SPAN
+            ):
+ logger.error(
+ "There was no active span when trying to %s."
+ " Did you forget to start one or did a context slip?",
+ message,
+ stack_info=True,
+ )
+
+ return ret
+
+ return func(*args, **kwargs)
+
+ return ensure_active_span_inner_2
+
+ return ensure_active_span_inner_1
+
+
+# Setup
+
+
+def init_tracer(hs: "HomeServer") -> None:
+ """Set the whitelists and initialise the OpenTelemetry tracer"""
+ global opentelemetry
+ if not hs.config.tracing.tracing_enabled:
+ # We don't have a tracer
+ opentelemetry = None # type: ignore[assignment]
+ return
+
+ if not opentelemetry:
+ raise ConfigError(
+ "The server has been configured to use OpenTelemetry but OpenTelemetry is not "
+ "installed."
+ )
+
+    # Pull the homeserver whitelist out of the config, if it was given.
+ set_homeserver_whitelist(hs.config.tracing.homeserver_whitelist)
+
+ resource = opentelemetry.sdk.resources.Resource(
+ attributes={
+ opentelemetry.sdk.resources.SERVICE_NAME: f"{hs.config.server.server_name} {hs.get_instance_name()}"
+ }
+ )
+
+ # TODO: `force_tracing_for_users` is not compatible with OTEL samplers
+ # because you can only determine `opentelemetry.trace.TraceFlags.SAMPLED`
+    # and whether it uses a recording span when the span is created, and we
+    # don't have enough information at that time (we can only determine it
+    # later, in `synapse/api/auth.py`). There isn't a way to change the trace
+    # flags after the fact, so there is no way to programmatically force
+    # recording/tracing/sampling like there was in opentracing.
+ sampler = opentelemetry.sdk.trace.sampling.ParentBasedTraceIdRatio(
+ hs.config.tracing.sample_rate
+ )
+
+ tracer_provider = opentelemetry.sdk.trace.TracerProvider(
+ resource=resource, sampler=sampler
+ )
+
+ # consoleProcessor = opentelemetry.sdk.trace.export.BatchSpanProcessor(
+ # opentelemetry.sdk.trace.export.ConsoleSpanExporter()
+ # )
+ # tracer_provider.add_span_processor(consoleProcessor)
+
+ jaeger_exporter = opentelemetry.exporter.jaeger.thrift.JaegerExporter(
+ **hs.config.tracing.jaeger_exporter_config
+ )
+ jaeger_processor = opentelemetry.sdk.trace.export.BatchSpanProcessor(
+ jaeger_exporter
+ )
+ tracer_provider.add_span_processor(jaeger_processor)
+
+ # Sets the global default tracer provider
+ opentelemetry.trace.set_tracer_provider(tracer_provider)
+
+
+# Whitelisting
+
+
+@only_if_tracing
+def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None:
+ """Sets the homeserver whitelist
+
+ Args:
+ homeserver_whitelist: regexes specifying whitelisted homeservers
+ """
+ global _homeserver_whitelist
+ if homeserver_whitelist:
+ # Makes a single regex which accepts all passed in regexes in the list
+ _homeserver_whitelist = re.compile(
+ "({})".format(")|(".join(homeserver_whitelist))
+ )
+
+
+@only_if_tracing
+def whitelisted_homeserver(destination: str) -> bool:
+ """Checks if a destination matches the whitelist
+
+ Args:
+ destination
+ """
+
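+    # `_homeserver_whitelist` is `None` when no whitelist was configured, in
+    # which case we block everything by default.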
+ if _homeserver_whitelist:
+ return _homeserver_whitelist.match(destination) is not None
+ return False
+
+
+# Start spans and scopes
+
+
+def use_span(
+ span: "opentelemetry.trace.Span",
+ end_on_exit: bool = True,
+) -> ContextManager["opentelemetry.trace.Span"]:
+ if opentelemetry is None:
+ return contextlib.nullcontext() # type: ignore[unreachable]
+
+ return opentelemetry.trace.use_span(span=span, end_on_exit=end_on_exit)
+
+
+def create_non_recording_span() -> "opentelemetry.trace.Span":
+ """Create a no-op span that does not record or become part of a recorded trace"""
+
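+    # NOTE: unlike most helpers in this module there is no `opentelemetry is
+    # None` guard here, so this must only be called when tracing is enabled.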
+ return opentelemetry.trace.NonRecordingSpan(
+ opentelemetry.trace.INVALID_SPAN_CONTEXT
+ )
+
+
+def start_span(
+ name: str,
+ *,
+ context: Optional["opentelemetry.context.context.Context"] = None,
+ kind: Optional["opentelemetry.trace.SpanKind"] = SpanKind.INTERNAL,
+ attributes: "opentelemetry.util.types.Attributes" = None,
+ links: Optional[Sequence["opentelemetry.trace.Link"]] = None,
+ start_time: Optional[int] = None,
+ record_exception: bool = True,
+ set_status_on_exception: bool = True,
+ end_on_exit: bool = True,
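+    # NOTE: `end_on_exit` is unused here; only `start_active_span` acts on it.
+    # It is presumably kept so that the two functions share a signature.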
+ # For testing only
+ tracer: Optional["opentelemetry.trace.Tracer"] = None,
+) -> "opentelemetry.trace.Span":
+ if opentelemetry is None:
+ raise Exception("Not able to create span without opentelemetry installed.")
+
+ if tracer is None:
+ tracer = opentelemetry.trace.get_tracer(__name__)
+
+    # `kind` is typed `Optional` so callers may pass `None` explicitly, which
+    # makes mypy flag the call below with:
+    # ` error: Argument "kind" to "start_span" of "Tracer" has incompatible type "Optional[SpanKind]"; expected "SpanKind" [arg-type]`
+    # Normalize back to the default to keep the call well-typed.
+    if kind is None:
+        kind = SpanKind.INTERNAL
+
+ return tracer.start_span(
+ name=name,
+ context=context,
+ kind=kind,
+ attributes=attributes,
+ links=links,
+ start_time=start_time,
+ record_exception=record_exception,
+ set_status_on_exception=set_status_on_exception,
+ )
+
+
+def start_active_span(
+ name: str,
+ *,
+ context: Optional["opentelemetry.context.context.Context"] = None,
+ kind: Optional["opentelemetry.trace.SpanKind"] = SpanKind.INTERNAL,
+ attributes: "opentelemetry.util.types.Attributes" = None,
+ links: Optional[Sequence["opentelemetry.trace.Link"]] = None,
+ start_time: Optional[int] = None,
+ record_exception: bool = True,
+ set_status_on_exception: bool = True,
+ end_on_exit: bool = True,
+ # For testing only
+ tracer: Optional["opentelemetry.trace.Tracer"] = None,
+) -> ContextManager["opentelemetry.trace.Span"]:
+ if opentelemetry is None:
+ return contextlib.nullcontext() # type: ignore[unreachable]
+
+    # As in `start_span` above: normalize an explicit `None` back to the
+    # default so the nested call is well-typed.
+    if kind is None:
+        kind = SpanKind.INTERNAL
+
+ span = start_span(
+ name=name,
+ context=context,
+ kind=kind,
+ attributes=attributes,
+ links=links,
+ start_time=start_time,
+ record_exception=record_exception,
+ set_status_on_exception=set_status_on_exception,
+ tracer=tracer,
+ )
+
+ # Equivalent to `tracer.start_as_current_span`
+ return opentelemetry.trace.use_span(
+ span,
+ end_on_exit=end_on_exit,
+ record_exception=record_exception,
+ set_status_on_exception=set_status_on_exception,
+ )
+
+
+def start_active_span_from_edu(
+ operation_name: str,
+ *,
+ edu_content: Dict[str, Any],
+) -> ContextManager["opentelemetry.trace.Span"]:
+ """
+ Extracts a span context from an edu and uses it to start a new active span
+
+ Args:
+ operation_name: The label for the chunk of time used to process the given edu.
+ edu_content: an edu_content with a `context` field whose value is
+ canonical json for a dict which contains tracing information.
+ """
+ if opentelemetry is None:
+ return contextlib.nullcontext() # type: ignore[unreachable]
+
+ carrier = json_decoder.decode(
+ edu_content.get(EventContentFields.TRACING_CONTEXT, "{}")
+ )
+
+ context = extract_text_map(carrier)
+
+ return start_active_span(name=operation_name, context=context)
+
+
+# OpenTelemetry setters for attributes, logs, etc
+@only_if_tracing
+def get_active_span() -> Optional["opentelemetry.trace.Span"]:
+ """Get the currently active span, if any"""
+ return opentelemetry.trace.get_current_span()
+
+
+def get_span_context_from_context(
+ context: "opentelemetry.context.context.Context",
+) -> Optional["opentelemetry.trace.SpanContext"]:
+ """Utility function to convert a `Context` to a `SpanContext`
+
+ Based on https://github.com/open-telemetry/opentelemetry-python/blob/43288ca9a36144668797c11ca2654836ec8b5e99/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py#L99-L102
+ """
+ span = opentelemetry.trace.get_current_span(context=context)
+ span_context = span.get_span_context()
+ if span_context == opentelemetry.trace.INVALID_SPAN_CONTEXT:
+ return None
+ return span_context
+
+
+def get_context_from_span(
+ span: "opentelemetry.trace.Span",
+) -> "opentelemetry.context.context.Context":
+    # This doesn't affect the current context at all; it just wraps the span
+    # in a `Context` object (`set_span_in_context` is a misleading name).
+ ctx = opentelemetry.trace.propagation.set_span_in_context(span)
+ return ctx
+
+
+@ensure_active_span("set a tag")
+def set_attribute(key: str, value: Union[str, bool, int, float]) -> None:
+ """Sets a tag on the active span"""
+ active_span = get_active_span()
+ assert active_span is not None
+ active_span.set_attribute(key, value)
+
+
+@ensure_active_span("set the status")
+def set_status(
+ status_code: "opentelemetry.trace.status.StatusCode", exc: Optional[Exception]
+) -> None:
+ """Sets a tag on the active span"""
+ active_span = get_active_span()
+ assert active_span is not None
+ active_span.set_status(opentelemetry.trace.status.Status(status_code=status_code))
+ if exc:
+ active_span.record_exception(exc)
+
+
+DEFAULT_LOG_NAME = "log"
+
+
+@ensure_active_span("log")
+def log_kv(key_values: Dict[str, Any], timestamp: Optional[int] = None) -> None:
+ """Log to the active span"""
+ active_span = get_active_span()
+ assert active_span is not None
+ event_name = key_values.get("event", DEFAULT_LOG_NAME)
+ active_span.add_event(event_name, attributes=key_values, timestamp=timestamp)
+
+
+@only_if_tracing
+def force_tracing(span: Optional["opentelemetry.trace.Span"] = None) -> None:
+ """Force sampling for the active/given span and its children.
+
+ Args:
+ span: span to force tracing for. By default, the active span.
+ """
+ # TODO
+ pass
+
+
+def is_context_forced_tracing(
+ context: "opentelemetry.context.context.Context",
+) -> bool:
+ """Check if sampling has been force for the given span context."""
+ # TODO
+ return False
+
+
+# Injection and extraction
+
+
+@ensure_active_span("inject the active tracing context into a header dict")
+def inject_active_tracing_context_into_header_dict(
+ headers: Dict[bytes, List[bytes]],
+ destination: Optional[str] = None,
+ check_destination: bool = True,
+) -> None:
+ """
+ Injects the active tracing context into a dict of HTTP headers
+
+ Args:
+ headers: the dict to inject headers into
+ destination: address of entity receiving the span context. Must be given unless
+ `check_destination` is False.
+        check_destination: If False, destination will be ignored and the context
+ will always be injected. If True, the context will only be injected if the
+ destination matches the tracing allowlist
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted.
+ """
+ if check_destination:
+ if destination is None:
+ raise ValueError(
+ "destination must be given unless check_destination is False"
+ )
+ if not whitelisted_homeserver(destination):
+ return
+
+ active_span = get_active_span()
+ assert active_span is not None
+ ctx = get_context_from_span(active_span)
+
+ propagator = opentelemetry.propagate.get_global_textmap()
+ # Put all of SpanContext properties into the headers dict
+ propagator.inject(headers, context=ctx)
+
+
+def inject_trace_id_into_response_headers(response_headers: Headers) -> None:
+ """Inject the current trace id into the HTTP response headers"""
+ if not opentelemetry:
+ return
+    active_span = get_active_span()
+    if active_span is None:
+        return
+
+    # With no active span, `get_active_span` returns the INVALID_SPAN sentinel,
+    # whose span context is all zeroes; there is no trace id worth exposing.
+    span_context = active_span.get_span_context()
+    if not span_context.is_valid:
+        return
+
+    response_headers.addRawHeader(
+        "Synapse-Trace-Id", f"{span_context.trace_id:x}"
+    )
+
+
+@ensure_active_span(
+ "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
+def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
+ """
+ Gets the active tracing Context serialized as a dict. This can be used
+ instead of manually injecting a span into an empty carrier.
+
+ Args:
+ destination: the name of the remote server.
+
+ Returns:
+ dict: the serialized active span's context if opentelemetry is enabled, otherwise
+ empty.
+ """
+ if destination and not whitelisted_homeserver(destination):
+ return {}
+
+ active_span = get_active_span()
+ assert active_span is not None
+ ctx = get_context_from_span(active_span)
+
+ carrier_text_map: Dict[str, str] = {}
+ propagator = opentelemetry.propagate.get_global_textmap()
+ # Put all of Context properties onto the carrier text map that we can return
+ propagator.inject(carrier_text_map, context=ctx)
+
+ return carrier_text_map
+
+
+def context_from_request(
+ request: Request,
+) -> Optional["opentelemetry.context.context.Context"]:
+ """Extract an opentelemetry context from the headers on an HTTP request
+
+ This is useful when we have received an HTTP request from another part of our
+ system, and want to link our spans to those of the remote system.
+ """
+ if not opentelemetry:
+ return None
+ header_dict = {
+ k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+ }
+
+    # Extract all of the relevant values from the headers to construct a
+    # `Context` to return.
+ return extract_text_map(header_dict)
+
+
+@only_if_tracing
+def extract_text_map(
+ carrier: Dict[str, str]
+) -> Optional["opentelemetry.context.context.Context"]:
+ """
+ Wrapper method for opentelemetry's propagator.extract for TEXT_MAP.
+ Args:
+ carrier: a dict possibly containing a context.
+
+ Returns:
+ The active context extracted from carrier.
+ """
+ propagator = opentelemetry.propagate.get_global_textmap()
+ # Extract all of the relevant values from the `carrier` to construct a
+ # Context to return.
+ return propagator.extract(carrier)
+
+
+# Tracing decorators
+
+
+def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+ """
+ Decorator to trace a function with a custom opname.
+
+ See the module's doc string for usage examples.
+
+ """
+
+ def decorator(func: Callable[P, R]) -> Callable[P, R]:
+ if opentelemetry is None:
+ return func # type: ignore[unreachable]
+
+ if inspect.iscoroutinefunction(func):
+
+ @wraps(func)
+ async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ with start_active_span(opname):
+ return await func(*args, **kwargs) # type: ignore[misc]
+
+ else:
+            # The other case here handles both sync functions and those
+            # decorated with `inlineCallbacks`.
+ @wraps(func)
+ def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ scope = start_active_span(opname)
+ scope.__enter__()
+
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result: R) -> R:
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result: R) -> R:
+ scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+
+ else:
+ if inspect.isawaitable(result):
+ logger.error(
+ "@trace may not have wrapped %s correctly! "
+ "The function is not async but returned a %s.",
+ func.__qualname__,
+ type(result).__name__,
+ )
+
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _trace_inner # type: ignore[return-value]
+
+ return decorator
+
+
+def trace(func: Callable[P, R]) -> Callable[P, R]:
+ """
+ Decorator to trace a function.
+
+ Sets the operation name to that of the function's name.
+
+ See the module's doc string for usage examples.
+ """
+
+ return trace_with_opname(func.__name__)(func)
+
+
+def tag_args(func: Callable[P, R]) -> Callable[P, R]:
+ """
+ Tags all of the args to the active span.
+ """
+
+ if not opentelemetry:
+ return func
+
+ @wraps(func)
+ def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ argspec = inspect.getfullargspec(func)
+        # `argspec.args[1:]` skips the first formal parameter (`self`/`cls` on
+        # methods), so the matching values live at `args[i + 1]`.
+        for i, arg in enumerate(argspec.args[1:]):
+            set_attribute("ARG_" + arg, str(args[i + 1]))  # type: ignore[index]
+ set_attribute("args", str(args[len(argspec.args) :])) # type: ignore[index]
+ set_attribute("kwargs", str(kwargs))
+ return func(*args, **kwargs)
+
+ return _tag_args_inner
+
+
+@contextlib.contextmanager
+def trace_servlet(
+ request: "SynapseRequest", extract_context: bool = False
+) -> Generator[None, None, None]:
+ """Returns a context manager which traces a request. It starts a span
+ with some servlet specific tags such as the request metrics name and
+ request information.
+
+ Args:
+ request
+ extract_context: Whether to attempt to extract the tracing
+ context from the request the servlet is handling.
+ """
+
+ if opentelemetry is None:
+ yield # type: ignore[unreachable]
+ return
+
+ attrs = {
+ SynapseTags.REQUEST_ID: request.get_request_id(),
+ SpanAttributes.HTTP_METHOD: request.get_method(),
+ SpanAttributes.HTTP_URL: request.get_redacted_uri(),
+ SpanAttributes.HTTP_HOST: request.getClientAddress().host,
+ }
+
+ request_name = request.request_metrics.name
+ tracing_context = context_from_request(request) if extract_context else None
+
+    # This will end up being the root span for all of the servlet's traces, and
+    # we aren't able to determine whether to force tracing yet. We can determine
+    # whether to force tracing later, in `synapse/api/auth.py`.
+ with start_active_span(
+ request_name,
+ kind=SpanKind.SERVER,
+ context=tracing_context,
+ attributes=attrs,
+ # we configure the span not to finish immediately on exiting the scope,
+ # and instead pass the span into the SynapseRequest (via
+ # `request.set_tracing_span(span)`), which will finish it once we've
+ # finished sending the response to the client.
+ end_on_exit=False,
+ ) as span:
+ request.set_tracing_span(span)
+
+ inject_trace_id_into_response_headers(request.responseHeaders)
+ try:
+ yield
+ finally:
+            # We set the operation name again in case it's changed (which
+            # happens with JsonResource).
+ span.update_name(request.request_metrics.name)
+
+ if request.request_metrics.start_context.tag is not None:
+ span.set_attribute(
+ SynapseTags.REQUEST_TAG, request.request_metrics.start_context.tag
+ )
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 7a1516d3a8..59d956dd9d 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -42,7 +42,7 @@ from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
)
-from synapse.logging.opentracing import SynapseTags, start_active_span
+from synapse.logging.tracing import SynapseTags, start_active_span
from synapse.metrics._types import Collector
if TYPE_CHECKING:
@@ -208,7 +208,7 @@ def run_as_background_process(
Args:
desc: a description for this background process type
func: a function, which may return a Deferred or a coroutine
- bg_start_span: Whether to start an opentracing span. Defaults to True.
+        bg_start_span: Whether to start a tracing span. Defaults to True.
Should only be disabled for processes that will not log to or tag
a span.
args: positional args for func
@@ -232,7 +232,8 @@ def run_as_background_process(
try:
if bg_start_span:
ctx = start_active_span(
- f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
+ f"bgproc.{desc}",
+ attributes={SynapseTags.REQUEST_ID: str(context)},
)
else:
ctx = nullcontext() # type: ignore[assignment]
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c42bb8266a..8fd8cb8100 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -39,7 +39,7 @@ from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging import issue9533_logger
from synapse.logging.context import PreserveLoggingContext
-from synapse.logging.opentracing import log_kv, start_active_span
+from synapse.logging.tracing import log_kv, start_active_span
from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
from synapse.types import (
@@ -536,7 +536,7 @@ class Notifier:
log_kv(
{
"wait_for_events": "sleep",
- "token": prev_token,
+ "token": str(prev_token),
}
)
@@ -546,7 +546,7 @@ class Notifier:
log_kv(
{
"wait_for_events": "woken",
- "token": user_stream.current_token,
+ "token": str(user_stream.current_token),
}
)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e96fb45e9f..11299367d2 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -23,7 +23,7 @@ from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.logging import opentracing
+from synapse.logging import tracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.storage.databases.main.event_push_actions import HttpPushAction
@@ -198,9 +198,9 @@ class HttpPusher(Pusher):
)
for push_action in unprocessed:
- with opentracing.start_active_span(
+ with tracing.start_active_span(
"http-push",
- tags={
+ attributes={
"authenticated_entity": self.user_id,
"event_id": push_action.event_id,
"app_id": self.app_id,
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 561ad5bf04..af160e31aa 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -28,8 +28,8 @@ from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer, is_method_cancellable
from synapse.http.site import SynapseRequest
-from synapse.logging import opentracing
-from synapse.logging.opentracing import trace_with_opname
+from synapse.logging import tracing
+from synapse.logging.tracing import trace_with_opname
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -248,7 +248,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
- opentracing.inject_header_dict(headers, check_destination=False)
+ tracing.inject_active_tracing_context_into_header_dict(
+ headers, check_destination=False
+ )
try:
# Keep track of attempts made so we can bail if we don't manage to
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index a448dd7eb1..e928fded36 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Optional
from prometheus_client import Counter, Histogram
-from synapse.logging import opentracing
+from synapse.logging import tracing
from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder
@@ -94,9 +94,9 @@ class ExternalCache:
logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
- with opentracing.start_active_span(
+ with tracing.start_active_span(
"ExternalCache.set",
- tags={opentracing.SynapseTags.CACHE_NAME: cache_name},
+ attributes={tracing.SynapseTags.CACHE_NAME: cache_name},
):
with response_timer.labels("set").time():
return await make_deferred_yieldable(
@@ -113,9 +113,9 @@ class ExternalCache:
if self._redis_connection is None:
return None
- with opentracing.start_active_span(
+ with tracing.start_active_span(
"ExternalCache.get",
- tags={opentracing.SynapseTags.CACHE_NAME: cache_name},
+ attributes={tracing.SynapseTags.CACHE_NAME: cache_name},
):
with response_timer.labels("get").time():
result = await make_deferred_yieldable(
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index 1ba850369a..cf72df0a2a 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -138,7 +138,7 @@
<div class="username_input" id="username_input">
<label for="field-username">Username (required)</label>
<div class="prefix">@</div>
- <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
+ <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus autocorrect="off" autocapitalize="none">
<div class="postfix">:{{ server_name }}</div>
</div>
<output for="username_input" id="field-username-output"></output>
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 0cc87a4001..50edc6b7d3 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
SynapseError,
ThreepidValidationError,
)
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
@@ -63,7 +64,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.config = hs.config
self.identity_handler = hs.get_identity_handler()
- if self.config.email.can_verify_email:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -72,10 +73,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.config.email.can_verify_email:
- logger.warning(
- "User password resets have been disabled due to lack of email config"
- )
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "User password resets have been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based password resets have been disabled on this server"
)
@@ -127,21 +129,35 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
- # Send password reset emails from Synapse
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_password_reset_mail,
- next_link,
- )
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.registration.account_threepid_delegate_email
+
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.request_email_token(
+ self.hs.config.registration.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+ else:
+ # Send password reset emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_password_reset_mail,
+ next_link,
+ )
+
+ # Wrap the session id in a JSON object
+ ret = {"sid": sid}
threepid_send_requests.labels(type="email", reason="password_reset").observe(
send_attempt
)
- # Wrap the session id in a JSON object
- return 200, {"sid": sid}
+ return 200, ret
class PasswordRestServlet(RestServlet):
@@ -333,7 +349,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.store = self.hs.get_datastores().main
- if self.config.email.can_verify_email:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -342,10 +358,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.config.email.can_verify_email:
- logger.warning(
- "Adding emails have been disabled due to lack of an email config"
- )
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
raise SynapseError(
400, "Adding an email to your account is disabled on this server"
)
@@ -396,20 +413,35 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_add_threepid_mail,
- next_link,
- )
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.registration.account_threepid_delegate_email
+
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.request_email_token(
+ self.hs.config.registration.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+ else:
+ # Send threepid validation emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_add_threepid_mail,
+ next_link,
+ )
+
+ # Wrap the session id in a JSON object
+ ret = {"sid": sid}
threepid_send_requests.labels(type="email", reason="add_threepid").observe(
send_attempt
)
- # Wrap the session id in a JSON object
- return 200, {"sid": sid}
+ return 200, ret
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -502,19 +534,25 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- if self.config.email.can_verify_email:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
self.config.email.email_add_threepid_template_failure_html
)
async def on_GET(self, request: Request) -> None:
- if not self.config.email.can_verify_email:
- logger.warning(
- "Adding emails have been disabled due to lack of an email config"
- )
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
raise SynapseError(
400, "Adding an email to your account is disabled on this server"
)
+ elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ raise SynapseError(
+ 400,
+ "This homeserver is not validating threepids.",
+ )
sid = parse_string(request, "sid", required=True)
token = parse_string(request, "token", required=True)
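
The account.py hunks above all restore the same three-way dispatch on `ThreepidBehaviour`. A condensed sketch of that control flow, with stand-in async helpers in place of Synapse's real identity-handler methods (illustrative only):

```python
from enum import Enum

class ThreepidBehaviour(Enum):
    OFF = "off"        # no usable email config: reject the request
    REMOTE = "remote"  # delegate validation to a configured identity server
    LOCAL = "local"    # Synapse sends the validation email itself

async def request_email_token_from_delegate() -> dict:
    # stand-in for identity_handler.request_email_token(...)
    return {"sid": "remote-session-id"}

async def send_threepid_validation() -> str:
    # stand-in for identity_handler.send_threepid_validation(...)
    return "local-session-id"

async def handle_request_token(behaviour: ThreepidBehaviour) -> dict:
    if behaviour is ThreepidBehaviour.OFF:
        raise RuntimeError("threepid validation is disabled on this server")
    if behaviour is ThreepidBehaviour.REMOTE:
        # the delegate hands back a ready-made JSON body, sid included
        return await request_email_token_from_delegate()
    # LOCAL: send the mail locally and wrap the session id in a JSON object
    return {"sid": await send_threepid_validation()}
```

This is why the hunks change `return 200, {"sid": sid}` to `return 200, ret`: only the LOCAL branch produces a bare `sid` that needs wrapping.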
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index e3f454896a..a592fd2cfb 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
+from synapse.logging.tracing import log_kv, set_attribute, trace_with_opname
from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
@@ -88,7 +88,7 @@ class KeyUploadServlet(RestServlet):
user_id
)
if dehydrated_device is not None and device_id != dehydrated_device[0]:
- set_tag("error", True)
+ set_attribute("error", True)
log_kv(
{
"message": "Client uploading keys for a different device",
@@ -204,13 +204,13 @@ class KeyChangesServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from", required=True)
- set_tag("from", from_token_string)
+ set_attribute("from", from_token_string)
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
#
# XXX This does not enforce that "to" is passed.
- set_tag("to", str(parse_string(request, "to")))
+ set_attribute("to", str(parse_string(request, "to")))
from_token = await StreamToken.from_string(self.store, from_token_string)
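
The `set_tag` → `set_attribute` rename recurs throughout the rest of this diff. Assuming `synapse.logging.tracing.set_attribute` is a thin wrapper over the upstream OpenTelemetry API, it is roughly equivalent to:

```python
from typing import Union

from opentelemetry import trace

def set_attribute(key: str, value: Union[str, bool, int, float]) -> None:
    # Attach an attribute to the currently active span, if one is recording.
    span = trace.get_current_span()
    if span.is_recording():
        span.set_attribute(key, value)
```

Unlike opentracing tags, OpenTelemetry attributes are type-restricted (strings, bools, ints, floats, or sequences thereof), which is why several call sites in this diff stringify their values first.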
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index ad025c8a45..8201f2bd86 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -24,7 +24,7 @@ from synapse.http.servlet import (
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag
+from synapse.logging.tracing import set_attribute
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict, RoomAlias, RoomID
@@ -97,7 +97,7 @@ class KnockRoomAliasServlet(RestServlet):
def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
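
Every `on_PUT` touched below follows the same idempotency recipe: tag the span with the client-chosen `txn_id`, then let `HttpTransactionCache` replay the first response for retries. A toy sketch of that replay behaviour (illustrative names; asyncio rather than Twisted):

```python
import asyncio
from typing import Awaitable, Callable, Dict

class TransactionCacheSketch:
    def __init__(self) -> None:
        self._results: Dict[str, "asyncio.Task"] = {}

    def fetch_or_execute(
        self, txn_key: str, fn: Callable[[], Awaitable]
    ) -> "asyncio.Task":
        # The first request for a txn_key starts the work; retried requests
        # (e.g. a client resend after a dropped response) share the same
        # Task and therefore observe the identical result.
        if txn_key not in self._results:
            self._results[txn_key] = asyncio.ensure_future(fn())
        return self._results[txn_key]
```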
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index a8402cdb3a..b7ab090bbd 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -31,6 +31,7 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.server import is_threepid_reserved
@@ -73,7 +74,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.config = hs.config
- if self.hs.config.email.can_verify_email:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -82,10 +83,13 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.hs.config.email.can_verify_email:
- logger.warning(
- "Email registration has been disabled due to lack of email config"
- )
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if (
+ self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
+ ):
+ logger.warning(
+ "Email registration has been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based registration has been disabled on this server"
)
@@ -134,21 +138,35 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- # Send registration emails from Synapse
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_registration_mail,
- next_link,
- )
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.registration.account_threepid_delegate_email
+
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.request_email_token(
+ self.hs.config.registration.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+ else:
+ # Send registration emails from Synapse,
+ # wrapping the session id in a JSON object.
+ ret = {
+ "sid": await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_registration_mail,
+ next_link,
+ )
+ }
threepid_send_requests.labels(type="email", reason="register").observe(
send_attempt
)
- # Wrap the session id in a JSON object
- return 200, {"sid": sid}
+ return 200, ret
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -242,7 +260,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- if self.config.email.can_verify_email:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
self.config.email.email_registration_template_failure_html
)
@@ -252,10 +270,11 @@ class RegistrationSubmitTokenServlet(RestServlet):
raise SynapseError(
400, "This medium is currently not supported for registration"
)
- if not self.config.email.can_verify_email:
- logger.warning(
- "User registration via email has been disabled due to lack of email config"
- )
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "User registration via email has been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based registration is disabled on this server"
)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 2f513164cb..3880846e9a 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -46,7 +46,7 @@ from synapse.http.servlet import (
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag
+from synapse.logging.tracing import set_attribute
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.state import StateFilter
@@ -82,7 +82,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
@@ -194,7 +194,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if txn_id:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
content = parse_json_object_from_request(request)
@@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
except ShadowBanError:
event_id = "$" + random_string(43)
- set_tag("event_id", event_id)
+ set_attribute("event_id", event_id)
ret = {"event_id": event_id}
return 200, ret
@@ -279,7 +279,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
except ShadowBanError:
event_id = "$" + random_string(43)
- set_tag("event_id", event_id)
+ set_attribute("event_id", event_id)
return 200, {"event_id": event_id}
def on_GET(
@@ -290,7 +290,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
@@ -348,7 +348,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
@@ -816,7 +816,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
@@ -916,7 +916,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
@@ -962,13 +962,13 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
except ShadowBanError:
event_id = "$" + random_string(43)
- set_tag("event_id", event_id)
+ set_attribute("event_id", event_id)
return 200, {"event_id": event_id}
def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("txn_id", txn_id)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 1a8e9a96d4..a3f8fdb317 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace_with_opname
+from synapse.logging.tracing import set_attribute, trace_with_opname
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
@@ -47,8 +47,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
- set_tag("message_type", message_type)
- set_tag("txn_id", txn_id)
+ set_attribute("message_type", message_type)
+ set_attribute("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index c2989765ce..5ddb08eb2f 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import trace_with_opname
+from synapse.logging.tracing import trace_with_opname
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index f4f06563dd..0366986755 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -95,8 +95,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving private read receipts as per MSC2285
"org.matrix.msc2285": self.config.experimental.msc2285_enabled,
- # Supports filtering of /publicRooms by room type MSC3827
- "org.matrix.msc3827": self.config.experimental.msc3827_enabled,
+ # Supports filtering of /publicRooms by room type as per MSC3827
+ "org.matrix.msc3827.stable": True,
# Adds support for importing historical messages as per MSC2716
"org.matrix.msc2716": self.config.experimental.msc2716_enabled,
# Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index b9402cfb75..6ac9dbc7c9 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import DirectServeHtmlResource
from synapse.http.servlet import parse_string
from synapse.util.stringutils import assert_valid_client_secret
@@ -45,6 +46,9 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
+ self._local_threepid_handling_disabled_due_to_email_config = (
+ hs.config.email.local_threepid_handling_disabled_due_to_email_config
+ )
self._confirmation_email_template = (
hs.config.email.email_password_reset_template_confirmation_html
)
@@ -55,8 +59,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
hs.config.email.email_password_reset_template_failure_html
)
- # This resource should only be mounted if email validation is enabled
- assert hs.config.email.can_verify_email
+ # This resource should only be mounted if threepid behaviour is LOCAL
+ assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
sid = parse_string(request, "sid", required=True)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 87ccd52f0a..c355e4f98a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -255,7 +255,7 @@ class StateHandler:
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
- partial_state: bool = False,
+ partial_state: Optional[bool] = None,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -270,10 +270,18 @@ class StateHandler:
it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling.
- partial_state: True if `state_ids_before_event` is partial and omits
- non-critical membership events
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
Returns:
The event context.
+
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
"""
assert not event.internal_metadata.is_outlier()
@@ -298,12 +306,14 @@ class StateHandler:
)
)
+ # the partial_state flag must be provided
+ assert partial_state is not None
else:
# otherwise, we'll need to resolve the state across the prev_events.
# partial_state should not be set explicitly in this case:
# we work it out dynamically
- assert not partial_state
+ assert partial_state is None
# if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use
@@ -313,13 +323,13 @@ class StateHandler:
incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids
)
- if any(incomplete_prev_events.values()):
+ partial_state = any(incomplete_prev_events.values())
+ if partial_state:
logger.debug(
"New/incoming event %s refers to prev_events %s with partial state",
event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v],
)
- partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
@@ -426,6 +436,10 @@ class StateHandler:
Returns:
The resolved state
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (i.e. they are outliers or unknown)
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a2f8310388..e30f9c76d4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -80,6 +80,10 @@ class SQLBaseStore(metaclass=ABCMeta):
)
self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
+ # There's no easy way of invalidating this cache for just the users
+ # that have changed, so we just clear the entire thing.
+ self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
+
for user_id in members_changed:
self._attempt_to_invalidate_cache(
"get_user_in_room_with_profile", (room_id, user_id)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index b4b904ff1d..f87f5098a5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,8 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.tracing import Link, get_active_span, start_active_span, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
@@ -118,7 +118,7 @@ times_pruned_extremities = Counter(
class _PersistEventsTask:
"""A batch of events to persist."""
- name: ClassVar[str] = "persist_event_batch" # used for opentracing
+ name: ClassVar[str] = "persist_event_batch" # used for tracing
events_and_contexts: List[Tuple[EventBase, EventContext]]
backfilled: bool
@@ -139,7 +139,7 @@ class _PersistEventsTask:
class _UpdateCurrentStateTask:
"""A room whose current state needs recalculating."""
- name: ClassVar[str] = "update_current_state" # used for opentracing
+ name: ClassVar[str] = "update_current_state" # used for tracing
def try_merge(self, task: "_EventPersistQueueTask") -> bool:
"""Deduplicates consecutive recalculations of current state."""
@@ -154,11 +154,11 @@ class _EventPersistQueueItem:
task: _EventPersistQueueTask
deferred: ObservableDeferred
- parent_opentracing_span_contexts: List = attr.ib(factory=list)
- """A list of opentracing spans waiting for this batch"""
+ parent_tracing_span_contexts: List = attr.ib(factory=list)
+ """A list of tracing spans waiting for this batch"""
- opentracing_span_context: Any = None
- """The opentracing span under which the persistence actually happened"""
+ tracing_span_context: Any = None
+ """The tracing span under which the persistence actually happened"""
_PersistResult = TypeVar("_PersistResult")
@@ -222,10 +222,10 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
)
queue.append(end_item)
- # also add our active opentracing span to the item so that we get a link back
- span = opentracing.active_span()
+ # also add our active tracing span to the item so that we get a link back
+ span = get_active_span()
if span:
- end_item.parent_opentracing_span_contexts.append(span.context)
+ end_item.parent_tracing_span_contexts.append(span.get_span_context())
# start a processor for the queue, if there isn't one already
self._handle_queue(room_id)
@@ -233,9 +233,10 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
# wait for the queue item to complete
res = await make_deferred_yieldable(end_item.deferred.observe())
- # add another opentracing span which links to the persist trace.
- with opentracing.start_active_span_follows_from(
- f"{task.name}_complete", (end_item.opentracing_span_context,)
+ # add another tracing span which links to the persist trace.
+ with start_active_span(
+ f"{task.name}_complete",
+ links=[Link(end_item.tracing_span_context)],
):
pass
@@ -266,13 +267,15 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
- with opentracing.start_active_span_follows_from(
+ with start_active_span(
item.task.name,
- item.parent_opentracing_span_contexts,
- inherit_force_tracing=True,
- ) as scope:
- if scope:
- item.opentracing_span_context = scope.span.context
+ links=[
+ Link(span_context)
+ for span_context in item.parent_tracing_span_contexts
+ ],
+ ) as span:
+ if span:
+ item.tracing_span_context = span.get_span_context()
ret = await self._per_item_callback(room_id, item.task)
except Exception:
@@ -355,7 +358,7 @@ class EventsPersistenceStorageController:
f"Found an unexpected task type in event persistence queue: {task}"
)
- @opentracing.trace
+ @trace
async def persist_events(
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -418,7 +421,7 @@ class EventsPersistenceStorageController:
self.main_store.get_room_max_token(),
)
- @opentracing.trace
+ @trace
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
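
Opentracing's `FOLLOWS_FROM` references (`start_active_span_follows_from`) have no direct OpenTelemetry equivalent, so the queue code above switches to span links. A sketch of the pattern against the upstream `opentelemetry-api`, assuming the enqueueing side has captured a `SpanContext`:

```python
from opentelemetry import trace

tracer = trace.get_tracer(__name__)

# on the enqueueing side: remember where we came from
enqueue_ctx = trace.get_current_span().get_span_context()

# on the processing side: start an independent span that *links* back to
# the enqueuer instead of being parented under it
with tracer.start_as_current_span(
    "persist_event_batch",
    links=[trace.Link(enqueue_ctx)],
) as span:
    batch_ctx = span.get_span_context()  # handed back via the queue item
```

Links must be supplied when the span is started, which is why the diff threads the contexts through `_EventPersistQueueItem` rather than attaching them afterwards.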
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f584e6c92e..0d480f1014 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -83,13 +83,15 @@ class StateStorageController:
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
- self, _room_id: str, event_ids: Collection[str]
+ self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
+ await_full_state: if `True`, will block if we do not yet have complete
+ state at these events.
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -101,7 +103,9 @@ class StateStorageController:
if not event_ids:
return {}
- event_to_groups = await self.get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -339,6 +343,10 @@ class StateStorageController:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (i.e. they are outliers or unknown)
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b394a6658b..ca0f606797 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -47,7 +47,7 @@ from twisted.internet.interfaces import IReactorCore
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
-from synapse.logging import opentracing
+from synapse.logging import tracing
from synapse.logging.context import (
LoggingContext,
current_context,
@@ -422,11 +422,11 @@ class LoggingTransaction:
start = time.time()
try:
- with opentracing.start_active_span(
+ with tracing.start_active_span(
"db.query",
- tags={
- opentracing.tags.DATABASE_TYPE: "sql",
- opentracing.tags.DATABASE_STATEMENT: one_line_sql,
+ attributes={
+ tracing.SpanAttributes.DB_SYSTEM: "sql",
+ tracing.SpanAttributes.DB_STATEMENT: one_line_sql,
},
):
return func(sql, *args, **kwargs)
@@ -701,15 +701,15 @@ class DatabasePool:
exception_callbacks=exception_callbacks,
)
try:
- with opentracing.start_active_span(
+ with tracing.start_active_span(
"db.txn",
- tags={
- opentracing.SynapseTags.DB_TXN_DESC: desc,
- opentracing.SynapseTags.DB_TXN_ID: name,
+ attributes={
+ tracing.SynapseTags.DB_TXN_DESC: desc,
+ tracing.SynapseTags.DB_TXN_ID: name,
},
):
r = func(cursor, *args, **kwargs)
- opentracing.log_kv({"message": "commit"})
+ tracing.log_kv({"message": "commit"})
conn.commit()
return r
except self.engine.module.OperationalError as e:
@@ -725,7 +725,7 @@ class DatabasePool:
if i < N:
i += 1
try:
- with opentracing.start_active_span("db.rollback"):
+ with tracing.start_active_span("db.rollback"):
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
@@ -739,7 +739,7 @@ class DatabasePool:
if i < N:
i += 1
try:
- with opentracing.start_active_span("db.rollback"):
+ with tracing.start_active_span("db.rollback"):
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning(
@@ -845,7 +845,7 @@ class DatabasePool:
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
- with opentracing.start_active_span(f"db.{desc}"):
+ with tracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
@@ -928,9 +928,7 @@ class DatabasePool:
with LoggingContext(
str(curr_context), parent_context=parent_context
) as context:
- with opentracing.start_active_span(
- operation_name="db.connection",
- ):
+ with tracing.start_active_span("db.connection"):
sched_duration_sec = monotonic_time() - start_time
sql_scheduling_timer.observe(sched_duration_sec)
context.add_database_scheduled(sched_duration_sec)
@@ -944,15 +942,13 @@ class DatabasePool:
"Reconnecting database connection over transaction limit"
)
conn.reconnect()
- opentracing.log_kv(
- {"message": "reconnected due to txn limit"}
- )
+ tracing.log_kv({"message": "reconnected due to txn limit"})
self._txn_counters[tid] = 1
if self.engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
- opentracing.log_kv({"message": "reconnected"})
+ tracing.log_kv({"message": "reconnected"})
if self._txn_limit > 0:
self._txn_counters[tid] = 1
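
`tracing.SpanAttributes` above presumably re-exports the OpenTelemetry semantic-convention constants. A standalone equivalent of the `db.query` span, keeping the diff's (loose) `"sql"` value for `DB_SYSTEM`:

```python
from opentelemetry import trace
from opentelemetry.semconv.trace import SpanAttributes

tracer = trace.get_tracer(__name__)

one_line_sql = "SELECT 1"

with tracer.start_as_current_span(
    "db.query",
    attributes={
        SpanAttributes.DB_SYSTEM: "sql",  # semconv suggests e.g. "postgresql" or "sqlite"
        SpanAttributes.DB_STATEMENT: one_line_sql,
    },
):
    pass  # run the query here
```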
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 73c95ffb6f..1503d74b1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -27,7 +27,7 @@ from typing import (
)
from synapse.logging import issue9533_logger
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.tracing import log_kv, set_attribute, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(user_id, device_id), None
)
- set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
+ set_attribute("last_deleted_stream_id", str(last_deleted_stream_id))
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
@@ -485,10 +485,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
A list of messages for the device and where in the stream the messages got to.
"""
- set_tag("destination", destination)
- set_tag("last_stream_id", last_stream_id)
- set_tag("current_stream_id", current_stream_id)
- set_tag("limit", limit)
+ set_attribute("destination", destination)
+ set_attribute("last_stream_id", last_stream_id)
+ set_attribute("current_stream_id", current_stream_id)
+ set_attribute("limit", limit)
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index ca0fe8c4be..7ceb7a202b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -30,11 +30,11 @@ from typing import (
from typing_extensions import Literal
-from synapse.api.constants import EduTypes
+from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import Codes, StoreError
-from synapse.logging.opentracing import (
+from synapse.logging.tracing import (
get_active_span_text_map,
- set_tag,
+ set_attribute,
trace,
whitelisted_homeserver,
)
@@ -333,12 +333,12 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
- # maps (user_id, device_id) -> (stream_id, opentracing_context)
+ # maps (user_id, device_id) -> (stream_id, tracing_context)
#
- # opentracing_context contains the opentracing metadata for the request
+ # tracing_context contains the OpenTelemetry metadata for the request
# that created the poke
#
- # The most recent request's opentracing_context is used as the
+ # The most recent request's tracing_context is used as the
# context which created the Edu.
# This is the stream ID that we will return for the consumer to resume
@@ -401,8 +401,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
if update_stream_id > previous_update_stream_id:
# FIXME If this overwrites an older update, this discards the
- # previous OpenTracing context.
- # It might make it harder to track down issues using OpenTracing.
+ # previous tracing context.
+ # It might make it harder to track down issues using tracing.
# If there's a good reason why it doesn't matter, a comment here
# about that would not hurt.
query_map[key] = (update_stream_id, update_context)
@@ -468,11 +468,11 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
- user_id
- device_id
- stream_id
- - opentracing_context
+ - tracing_context
"""
# get the list of device updates that need to be sent
sql = """
- SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
+ SELECT user_id, device_id, stream_id, tracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id
LIMIT ?
@@ -493,7 +493,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
destination: The host the device updates are intended for
from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map: Dictionary mapping (user_id, device_id) to
- (update stream_id, the relevant json-encoded opentracing context)
+ (update stream_id, the relevant json-encoded tracing context)
Returns:
List of objects representing a device update EDU.
@@ -531,13 +531,13 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
for device_id in device_ids:
device = user_devices[device_id]
- stream_id, opentracing_context = query_map[(user_id, device_id)]
+ stream_id, tracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
- "org.matrix.opentracing_context": opentracing_context,
+ EventContentFields.TRACING_CONTEXT: tracing_context,
}
prev_id = stream_id
@@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)
- set_tag("in_cache", str(results))
- set_tag("not_in_cache", str(user_ids_not_in_cache))
+ set_attribute("in_cache", str(results))
+ set_attribute("not_in_cache", str(user_ids_not_in_cache))
return user_ids_not_in_cache, results
@@ -1801,7 +1801,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id",
"sent",
"ts",
- "opentracing_context",
+ "tracing_context",
),
values=[
(
@@ -1846,7 +1846,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"room_id",
"stream_id",
"converted_to_destinations",
- "opentracing_context",
+ "tracing_context",
),
values=[
(
@@ -1870,11 +1870,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
written to `device_lists_outbound_pokes`.
Returns:
- A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ A list of user ID, device ID, room ID, stream ID and optional OpenTelemetry context.
"""
sql = """
- SELECT user_id, device_id, room_id, stream_id, opentracing_context
+ SELECT user_id, device_id, room_id, stream_id, tracing_context
FROM device_lists_changes_in_room
WHERE NOT converted_to_destinations
ORDER BY stream_id
@@ -1892,9 +1892,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id,
room_id,
stream_id,
- db_to_json(opentracing_context),
+ db_to_json(tracing_context),
)
- for user_id, device_id, room_id, stream_id, opentracing_context in txn
+ for user_id, device_id, room_id, stream_id, tracing_context in txn
]
return await self.db_pool.runInteraction(
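
The renamed `tracing_context` column still stores a JSON-serialised carrier of the active span, so the worker that turns the poke into an EDU can restore the trace; `get_active_span_text_map`, imported above, produces that carrier. With the plain OpenTelemetry propagator API the equivalent looks roughly like:

```python
import json

from opentelemetry.propagate import inject

carrier: dict = {}
inject(carrier)  # fills in "traceparent" (and "tracestate", when present)
tracing_context = json.dumps(carrier)  # the value written to the renamed column
```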
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index af59be6b48..6d565102ac 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -18,7 +18,7 @@ from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
from synapse.api.errors import StoreError
-from synapse.logging.opentracing import log_kv, trace
+from synapse.logging.tracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, JsonSerializable, StreamKeyType
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 46c0d06157..2df8101390 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -36,7 +36,7 @@ from synapse.appservice import (
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.tracing import log_kv, set_attribute, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
@@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
- set_tag("query_list", str(query_list))
+ set_attribute("query_list", str(query_list))
if not query_list:
return {}
@@ -228,8 +228,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Dict mapping from user-id to dict mapping from device_id to
key data.
"""
- set_tag("include_all_devices", include_all_devices)
- set_tag("include_deleted_devices", include_deleted_devices)
+ set_attribute("include_all_devices", include_all_devices)
+ set_attribute("include_deleted_devices", include_deleted_devices)
result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
@@ -416,9 +416,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("new_keys", str(new_keys))
+ set_attribute("user_id", user_id)
+ set_attribute("device_id", device_id)
+ set_attribute("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1158,10 +1158,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"""
def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("time_now", time_now)
- set_tag("device_keys", str(device_keys))
+ set_attribute("user_id", user_id)
+ set_attribute("device_id", device_id)
+ set_attribute("time_now", time_now)
+ set_attribute("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5914a35420..29c99c6357 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2110,11 +2110,29 @@ class EventsWorkerStore(SQLBaseStore):
def _get_partial_state_events_batch_txn(
txn: LoggingTransaction, room_id: str
) -> List[str]:
+ # we want to work through the events from oldest to newest, so
+ # we only want events whose prev_events do *not* have partial state - hence
+ # the 'NOT EXISTS' clause in the below.
+ #
+ # This is necessary because ordering by stream ordering isn't quite enough
+ # to ensure that we work from oldest to newest event (in particular,
+ # if an event is initially persisted as an outlier and later de-outliered,
+ # it can end up with a lower stream_ordering than its prev_events).
+ #
+ # Typically this means we'll only return one event per batch, but that's
+ # hard to do much about.
+ #
+ # See also: https://github.com/matrix-org/synapse/issues/13001
txn.execute(
"""
SELECT event_id FROM partial_state_events AS pse
JOIN events USING (event_id)
- WHERE pse.room_id = ?
+ WHERE pse.room_id = ? AND
+ NOT EXISTS(
+ SELECT 1 FROM event_edges AS ee
+ JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+ WHERE ee.event_id=pse.event_id
+ )
ORDER BY events.stream_ordering
LIMIT 100
""",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
- if aggregation_key:
- where_clause.append("aggregation_key = ?")
- where_args.append(aggregation_key)
-
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d6d485507b..0f1f0d11ea 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,7 +207,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
) -> Tuple[Union[str, None], List[str]]:
- if not room_types or not self.config.experimental.msc3827_enabled:
+ if not room_types:
return None, []
else:
# We use None when we want get rooms without a type
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index df6b82660e..e2cccc688c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -55,6 +56,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -183,7 +185,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
@@ -561,7 +563,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results_dict.get("membership"), results_dict.get("event_id")
- @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -732,25 +734,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return frozenset(r.room_id for r in rooms)
- @cached(
- max_entries=500000,
- cache_context=True,
- iterable=True,
- prune_unread_entries=False,
+ @cached(max_entries=10000)
+ async def does_pair_of_users_share_a_room(
+ self, user_id: str, other_user_id: str
+ ) -> bool:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
)
- async def get_users_who_share_room_with_user(
- self, user_id: str, cache_context: _CacheContext
+ async def _do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
+ ) -> Mapping[str, Optional[bool]]:
+ """Return mapping from user ID to whether they share a room with the
+ given user.
+
+ Note: `None` and `False` are equivalent and mean they don't share a
+ room.
+ """
+
+ def do_users_share_a_room_txn(
+ txn: LoggingTransaction, user_ids: Collection[str]
+ ) -> Dict[str, bool]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ # This query works by fetching both the list of rooms for the target
+ # user and the set of other users, and then checking if there is any
+ # overlap.
+ sql = f"""
+ SELECT b.state_key
+ FROM (
+ SELECT room_id FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+ ) AS a
+ INNER JOIN (
+ SELECT room_id, state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+ ) AS b using (room_id)
+ LIMIT 1
+ """
+
+ txn.execute(sql, (user_id, *args))
+ return {u: True for u, in txn}
+
+ to_return = {}
+ for batch_user_ids in batch_iter(other_user_ids, 1000):
+ res = await self.db_pool.runInteraction(
+ "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+ )
+ to_return.update(res)
+
+ return to_return
+
+ async def do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
) -> Set[str]:
+ """Return the set of users who share a room with the first users"""
+
+ user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+ return {u for u, share_room in user_dict.items() if share_room}
+
+ async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
"""Returns the set of users who share a room with `user_id`"""
- room_ids = await self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate
- )
+ room_ids = await self.get_rooms_for_user(user_id)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = await self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
+ user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
return user_who_share_room
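
The `@cached` stub plus `@cachedList` pairing above is Synapse's standard pattern for filling many single-key cache entries from one query; `batch_iter` then caps each database round trip at 1,000 users. A rough equivalent of `synapse.util.iterutils.batch_iter`, for reference:

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    # Yield successive chunks of at most `size` items; the final chunk
    # may be shorter.
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == size:
            yield tuple(batch)
            batch = []
    if batch:
        yield tuple(batch)
```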
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 9674c4a757..f70705a0af 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,13 +419,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# anything that was rejected should have the same state as its
# predecessor.
if context.rejected:
- assert context.state_group == context.state_group_before_event
+ state_group = context.state_group_before_event
+ else:
+ state_group = context.state_group
self.db_pool.simple_update_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event.event_id},
- updatevalues={"state_group": context.state_group},
+ updatevalues={"state_group": state_group},
)
self.db_pool.simple_delete_one_txn(
@@ -440,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.call_after(
self._get_state_group_for_event.prefill,
(event.event_id,),
- context.state_group,
+ state_group,
)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index a9a88c8bfd..dd187f7422 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -75,6 +75,8 @@ Changes in SCHEMA_VERSION = 71:
Changes in SCHEMA_VERSION = 72:
- event_edges.(room_id, is_state) are no longer written to.
- Tables related to groups are dropped.
+ - Rename column in `device_lists_outbound_pokes` and `device_lists_changes_in_room`
+ from `opentracing_context` to generalized `tracing_context`.
"""
diff --git a/synapse/storage/schema/main/delta/72/04rename_opentelemtetry_tracing_context.sql b/synapse/storage/schema/main/delta/72/04rename_opentelemtetry_tracing_context.sql
new file mode 100644
index 0000000000..ae904863f8
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/04rename_opentelemtetry_tracing_context.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Rename to generalized `tracing_context` since we're moving from OpenTracing to OpenTelemetry
+ALTER TABLE device_lists_outbound_pokes RENAME COLUMN opentracing_context TO tracing_context;
+ALTER TABLE device_lists_changes_in_room RENAME COLUMN opentracing_context TO tracing_context;
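
`ALTER TABLE ... RENAME COLUMN` needs SQLite 3.25.0 or newer (PostgreSQL has supported the syntax for much longer). A quick runtime sanity check, as an illustration only:

```python
import sqlite3

libversion = tuple(int(part) for part in sqlite3.sqlite_version.split("."))
assert libversion >= (3, 25, 0), f"SQLite {sqlite3.sqlite_version} cannot RENAME COLUMN"
```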
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a3eb5f741b..1dd2d3e62e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -29,11 +29,7 @@ import attr
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.logging.opentracing import (
- active_span,
- start_active_span,
- start_active_span_follows_from,
-)
+from synapse.logging.tracing import Link, get_active_span, start_active_span
from synapse.util import Clock
from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
from synapse.util.caches import register_cache
@@ -41,7 +37,7 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
- import opentracing
+ import opentelemetry
# the type of the key in the cache
KV = TypeVar("KV")
@@ -82,8 +78,8 @@ class ResponseCacheEntry:
easier to cache Failure results.
"""
- opentracing_span_context: "Optional[opentracing.SpanContext]"
- """The opentracing span which generated/is generating the result"""
+ tracing_span_context: Optional["opentelemetry.trace.SpanContext"]
+ """The tracing span which generated/is generating the result"""
class ResponseCache(Generic[KV]):
@@ -141,7 +137,7 @@ class ResponseCache(Generic[KV]):
self,
context: ResponseCacheContext[KV],
deferred: "defer.Deferred[RV]",
- opentracing_span_context: "Optional[opentracing.SpanContext]",
+ tracing_span_context: Optional["opentelemetry.trace.SpanContext"],
) -> ResponseCacheEntry:
"""Set the entry for the given key to the given deferred.
@@ -152,14 +148,14 @@ class ResponseCache(Generic[KV]):
Args:
context: Information about the cache miss
deferred: The deferred which resolves to the result.
- opentracing_span_context: An opentracing span wrapping the calculation
+ tracing_span_context: A tracing span context wrapping the calculation
Returns:
The cache entry object.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
key = context.cache_key
- entry = ResponseCacheEntry(result, opentracing_span_context)
+ entry = ResponseCacheEntry(result, tracing_span_context)
self._result_cache[key] = entry
def on_complete(r: RV) -> RV:
@@ -234,15 +230,15 @@ class ResponseCache(Generic[KV]):
if cache_context:
kwargs["cache_context"] = context
- span_context: Optional[opentracing.SpanContext] = None
+ span_context: Optional["opentelemetry.trace.SpanContext"] = None
async def cb() -> RV:
# NB it is important that we do not `await` before setting span_context!
nonlocal span_context
with start_active_span(f"ResponseCache[{self._name}].calculate"):
- span = active_span()
+ span = get_active_span()
if span:
- span_context = span.context
+ span_context = span.get_span_context()
return await callback(*args, **kwargs)
d = run_in_background(cb)
@@ -257,9 +253,9 @@ class ResponseCache(Generic[KV]):
"[%s]: using incomplete cached result for [%s]", self._name, key
)
- span_context = entry.opentracing_span_context
- with start_active_span_follows_from(
+ span_context = entry.tracing_span_context
+ with start_active_span(
f"ResponseCache[{self._name}].wait",
- contexts=(span_context,) if span_context else (),
+ links=[Link(span_context)] if span_context else None,
):
return await make_deferred_yieldable(result)
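
On the cache-hit path above, a request that joins an in-flight calculation gets its own `.wait` span, linked back to the span doing the work rather than parented under it. Stripped of the Synapse wrappers, that is approximately:

```python
from typing import Optional

from opentelemetry import trace

tracer = trace.get_tracer(__name__)

def wait_span(cache_name: str, calc_ctx: Optional[trace.SpanContext]):
    # Link (rather than parent) to the calculating span, mirroring the
    # old opentracing follows-from relationship.
    return tracer.start_as_current_span(
        f"ResponseCache[{cache_name}].wait",
        links=[trace.Link(calc_ctx)] if calc_ctx else None,
    )
```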