diff options
author | Ben Banfield-Zanin <benbz@matrix.org> | 2020-10-15 14:48:13 +0100 |
---|---|---|
committer | Ben Banfield-Zanin <benbz@matrix.org> | 2020-10-15 14:48:13 +0100 |
commit | 8d9ae573f33110e0420204bceb111fd8df649e7c (patch) | |
tree | c8113c67df9769a14e8bb0a03620026dbe9aa0ba /synapse | |
parent | Merge remote-tracking branch 'origin/anoa/3pid_check_invite_exemption' into b... (diff) | |
parent | Remove racey assertion in MultiWriterIDGenerator (#8530) (diff) | |
download | synapse-8d9ae573f33110e0420204bceb111fd8df649e7c.tar.xz |
Merge remote-tracking branch 'origin/release-v1.21.2' into bbz/info-mainline-1.21.2 github/bbz/info-mainline-1.21.2 bbz/info-mainline-1.21.2
Diffstat (limited to 'synapse')
243 files changed, 4587 insertions, 2194 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index e40b582bd5..722b53a67d 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.20.1" +__version__ = "1.21.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 55cce2db22..da0996edbc 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import argparse import getpass import hashlib diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 75388643ee..1071a0576e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -218,11 +218,7 @@ class Auth: # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = await self.store.get_expiration_ts_for_user(user_id) - if ( - expiration_ts is not None - and self.clock.time_msec() >= expiration_ts - ): + if await self.store.is_account_expired(user_id, self.clock.time_msec()): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 94a9e58eae..cd6670d0a2 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -87,7 +87,7 @@ class CodeMessageException(RuntimeError): """ def __init__(self, code: Union[int, HTTPStatus], msg: str): - super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) + super().__init__("%d: %s" % (code, msg)) # Some calls to this method pass instances of http.HTTPStatus for `code`. # While HTTPStatus is a subclass of int, it has magic __str__ methods @@ -138,7 +138,7 @@ class SynapseError(CodeMessageException): msg: The human-readable error message. errcode: The matrix error code e.g 'M_FORBIDDEN' """ - super(SynapseError, self).__init__(code, msg) + super().__init__(code, msg) self.errcode = errcode def error_dict(self): @@ -159,7 +159,7 @@ class ProxiedRequestError(SynapseError): errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, ): - super(ProxiedRequestError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} # type: Dict else: @@ -181,7 +181,7 @@ class ConsentNotGivenError(SynapseError): msg: The human-readable error message consent_url: The URL where the user can give their consent """ - super(ConsentNotGivenError, self).__init__( + super().__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri @@ -201,7 +201,7 @@ class UserDeactivatedError(SynapseError): Args: msg: The human-readable error message """ - super(UserDeactivatedError, self).__init__( + super().__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED ) @@ -225,7 +225,7 @@ class FederationDeniedError(SynapseError): self.destination = destination - super(FederationDeniedError, self).__init__( + super().__init__( code=403, msg="Federation denied with %s." % (self.destination,), errcode=Codes.FORBIDDEN, @@ -244,9 +244,7 @@ class InteractiveAuthIncompleteError(Exception): """ def __init__(self, session_id: str, result: "JsonDict"): - super(InteractiveAuthIncompleteError, self).__init__( - "Interactive auth not yet complete" - ) + super().__init__("Interactive auth not yet complete") self.session_id = session_id self.result = result @@ -261,14 +259,14 @@ class UnrecognizedRequestError(SynapseError): message = "Unrecognized request" else: message = args[0] - super(UnrecognizedRequestError, self).__init__(400, message, **kwargs) + super().__init__(400, message, **kwargs) class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND): - super(NotFoundError, self).__init__(404, msg, errcode=errcode) + super().__init__(404, msg, errcode=errcode) class AuthError(SynapseError): @@ -279,7 +277,7 @@ class AuthError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(AuthError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class InvalidClientCredentialsError(SynapseError): @@ -335,7 +333,7 @@ class ResourceLimitError(SynapseError): ): self.admin_contact = admin_contact self.limit_type = limit_type - super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) + super().__init__(code, msg, errcode=errcode) def error_dict(self): return cs_error( @@ -352,7 +350,7 @@ class EventSizeError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.TOO_LARGE - super(EventSizeError, self).__init__(413, *args, **kwargs) + super().__init__(413, *args, **kwargs) class EventStreamError(SynapseError): @@ -361,7 +359,7 @@ class EventStreamError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.BAD_PAGINATION - super(EventStreamError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class LoginError(SynapseError): @@ -384,7 +382,7 @@ class InvalidCaptchaError(SynapseError): error_url: Optional[str] = None, errcode: str = Codes.CAPTCHA_INVALID, ): - super(InvalidCaptchaError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) self.error_url = error_url def error_dict(self): @@ -402,7 +400,7 @@ class LimitExceededError(SynapseError): retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super(LimitExceededError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms def error_dict(self): @@ -418,9 +416,7 @@ class RoomKeysVersionError(SynapseError): Args: current_version: the current version of the store they should have used """ - super(RoomKeysVersionError, self).__init__( - 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION - ) + super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION) self.current_version = current_version @@ -429,7 +425,7 @@ class UnsupportedRoomVersionError(SynapseError): not support.""" def __init__(self, msg: str = "Homeserver does not support this room version"): - super(UnsupportedRoomVersionError, self).__init__( + super().__init__( code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, ) @@ -440,7 +436,7 @@ class ThreepidValidationError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(ThreepidValidationError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class IncompatibleRoomVersionError(SynapseError): @@ -451,7 +447,7 @@ class IncompatibleRoomVersionError(SynapseError): """ def __init__(self, room_version: str): - super(IncompatibleRoomVersionError, self).__init__( + super().__init__( code=400, msg="Your homeserver does not support the features required to " "join this room", @@ -473,7 +469,7 @@ class PasswordRefusedError(SynapseError): msg: str = "This password doesn't comply with the server's policy", errcode: str = Codes.WEAK_PASSWORD, ): - super(PasswordRefusedError, self).__init__( + super().__init__( code=400, msg=msg, errcode=errcode, ) @@ -488,7 +484,7 @@ class RequestSendFailed(RuntimeError): """ def __init__(self, inner_exception, can_retry): - super(RequestSendFailed, self).__init__( + super().__init__( "Failed to send request: %s: %s" % (type(inner_exception).__name__, inner_exception) ) @@ -542,7 +538,7 @@ class FederationError(RuntimeError): self.source = source msg = "%s %s: %s" % (level, code, reason) - super(FederationError, self).__init__(msg) + super().__init__(msg) def get_dict(self): return { @@ -570,7 +566,7 @@ class HttpResponseException(CodeMessageException): msg: reason phrase from HTTP response status line response: body of response """ - super(HttpResponseException, self).__init__(code, msg) + super().__init__(code, msg) self.response = response def to_synapse_error(self): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 2a2c9e6f13..5caf336fd0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -15,10 +15,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import List import jsonschema -from canonicaljson import json from jsonschema import FormatChecker from synapse.api.constants import EventContentFields @@ -132,7 +132,7 @@ def matrix_user_id_validator(user_id_str): class Filtering: def __init__(self, hs): - super(Filtering, self).__init__() + super().__init__() self.store = hs.get_datastore() async def get_user_filter(self, user_localpart, filter_id): diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bbfccf955e..6379c86dde 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -21,6 +21,7 @@ from urllib.parse import urlencode from synapse.config import ConfigError +SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client" CLIENT_API_PREFIX = "/_matrix/client" FEDERATION_PREFIX = "/_matrix/federation" FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b6c9085670..7d309b1bb0 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import json import logging import os import sys import tempfile -from canonicaljson import json - from twisted.internet import defer, task import synapse diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f985810e88..c38413c893 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -152,7 +152,7 @@ class PresenceStatusStubServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status") def __init__(self, hs): - super(PresenceStatusStubServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -176,7 +176,7 @@ class KeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(KeyUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.http_client = hs.get_simple_http_client() @@ -646,7 +646,7 @@ class GenericWorkerServer(HomeServer): class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): - super(GenericWorkerReplicationHandler, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6014adc850..dff739e106 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -15,8 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import gc import logging import math @@ -48,6 +46,7 @@ from synapse.api.urls import ( from synapse.app import _base from synapse.app._base import listen_ssl, listen_tcp, quit_with_error from synapse.config._base import ConfigError +from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer @@ -209,6 +208,15 @@ class SynapseHomeServer(HomeServer): resources["/_matrix/saml2"] = SAML2Resource(self) + if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL: + from synapse.rest.synapse.client.password_reset import ( + PasswordResetSubmitTokenResource, + ) + + resources[ + "/_synapse/client/password_reset/email/submit_token" + ] = PasswordResetSubmitTokenResource(self) + if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bb6fa8299a..c526c28b93 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -88,7 +88,7 @@ class ApplicationServiceApi(SimpleHttpClient): """ def __init__(self, hs): - super(ApplicationServiceApi, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.protocol_meta_cache = ResponseCache( @@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - info = await self.get_json(uri, {}) + info = await self.get_json(uri) if not _is_valid_3pe_metadata(info): logger.warning( diff --git a/synapse/config/_base.py b/synapse/config/_base.py index f8ab8e38df..85f65da4d9 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -242,12 +242,11 @@ class Config: env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters - env.filters.update( - { - "format_ts": _format_ts_filter, - "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), - } - ) + env.filters.update({"format_ts": _format_ts_filter}) + if self.public_baseurl: + env.filters.update( + {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} + ) for filename in filenames: # Load the template @@ -838,11 +837,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -852,7 +866,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index eb911e8f9f..b8faafa9bd 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/_util.py b/synapse/config/_util.py index cd31b1c3c9..c74969a977 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Iterable import jsonschema @@ -20,7 +20,9 @@ from synapse.config._base import ConfigError from synapse.types import JsonDict -def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None: +def validate_config( + json_schema: JsonDict, config: Any, config_path: Iterable[str] +) -> None: """Validates a config setting against a JsonSchema definition This can be used to validate a section of the config file against a schema diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 82f04d7966..cb00958165 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -28,6 +28,9 @@ class CaptchaConfig(Config): "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", ) + self.recaptcha_template = self.read_templates( + ["recaptcha.html"], autoescape=True + )[0] def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index aec9c4bbce..6efa59b110 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -77,7 +77,7 @@ class ConsentConfig(Config): section = "consent" def __init__(self, *args): - super(ConsentConfig, self).__init__(*args) + super().__init__(*args) self.user_consent_version = None self.user_consent_template_dir = None @@ -89,6 +89,8 @@ class ConsentConfig(Config): def read_config(self, config, **kwargs): consent_config = config.get("user_consent") + self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0] + if consent_config is None: return self.user_consent_version = str(consent_config["version"]) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 7a796996c0..cceffbfee2 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function # This file can't be called email.py because if it is, we cannot: import email.utils @@ -228,6 +227,7 @@ class EmailConfig(Config): self.email_registration_template_text, self.email_add_threepid_template_html, self.email_add_threepid_template_text, + self.email_password_reset_template_confirmation_html, self.email_password_reset_template_failure_html, self.email_registration_template_failure_html, self.email_add_threepid_template_failure_html, @@ -242,6 +242,7 @@ class EmailConfig(Config): registration_template_text, add_threepid_template_html, add_threepid_template_text, + "password_reset_confirmation.html", password_reset_template_failure_html, registration_template_failure_html, add_threepid_template_failure_html, @@ -404,9 +405,13 @@ class EmailConfig(Config): # * The contents of password reset emails sent by the homeserver: # 'password_reset.html' and 'password_reset.txt' # - # * HTML pages for success and failure that a user will see when they follow - # the link in the password reset email: 'password_reset_success.html' and - # 'password_reset_failure.html' + # * An HTML page that a user will see when they follow the link in the password + # reset email. The user will be asked to confirm the action before their + # password is reset: 'password_reset_confirmation.html' + # + # * HTML pages for success and failure that a user will see when they confirm + # the password reset flow using the page above: 'password_reset_success.html' + # and 'password_reset_failure.html' # # * The contents of address verification emails sent during registration: # 'registration.html' and 'registration.txt' diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 2c77d8f85b..ffd8fca54e 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -17,7 +17,8 @@ from typing import Optional from netaddr import IPSet -from ._base import Config, ConfigError +from synapse.config._base import Config, ConfigError +from synapse.config._util import validate_config class FederationConfig(Config): @@ -52,8 +53,18 @@ class FederationConfig(Config): "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e ) + federation_metrics_domains = config.get("federation_metrics_domains") or [] + validate_config( + _METRICS_FOR_DOMAINS_SCHEMA, + federation_metrics_domains, + ("federation_metrics_domains",), + ) + self.federation_metrics_domains = set(federation_metrics_domains) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ + ## Federation ## + # Restrict federation to the following whitelist of domains. # N.B. we recommend also firewalling your federation listener to limit # inbound federation traffic as early as possible, rather than relying @@ -85,4 +96,18 @@ class FederationConfig(Config): - '::1/128' - 'fe80::/64' - 'fc00::/7' + + # Report prometheus metrics on the age of PDUs being sent to and received from + # the following domains. This can be used to give an idea of "delay" on inbound + # and outbound federation, though be aware that any delay can be due to problems + # at either end or with the intermediate network. + # + # By default, no domains are monitored in this way. + # + #federation_metrics_domains: + # - matrix.org + # - example.com """ + + +_METRICS_FOR_DOMAINS_SCHEMA = {"type": "array", "items": {"type": "string"}} diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 556e291495..be65554524 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -92,5 +92,4 @@ class HomeServerConfig(RootConfig): TracerConfig, WorkerConfig, RedisConfig, - FederationConfig, ] diff --git a/synapse/config/logger.py b/synapse/config/logger.py index c96e6ef62a..13d6f6a3ea 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -17,6 +17,7 @@ import logging import logging.config import os import sys +import threading from string import Template import yaml @@ -25,6 +26,7 @@ from twisted.logger import ( ILogObserver, LogBeginner, STDLibLogObserver, + eventAsText, globalLogBeginner, ) @@ -216,8 +218,9 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): # system. observer = STDLibLogObserver() - def _log(event): + threadlocal = threading.local() + def _log(event): if "log_text" in event: if event["log_text"].startswith("DNSDatagramProtocol starting on "): return @@ -228,7 +231,25 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): if event["log_text"].startswith("Timing out client"): return - return observer(event) + # this is a workaround to make sure we don't get stack overflows when the + # logging system raises an error which is written to stderr which is redirected + # to the logging system, etc. + if getattr(threadlocal, "active", False): + # write the text of the event, if any, to the *real* stderr (which may + # be redirected to /dev/null, but there's not much we can do) + try: + event_text = eventAsText(event) + print("logging during logging: %s" % event_text, file=sys.__stderr__) + except Exception: + # gah. + pass + return + + try: + threadlocal.active = True + return observer(event) + finally: + threadlocal.active = False logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio) if not config.no_redirect_stdio: diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index e0939bce84..f924116819 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -56,6 +56,7 @@ class OIDCConfig(Config): self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) @@ -158,6 +159,11 @@ class OIDCConfig(Config): # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # @@ -198,6 +204,14 @@ class OIDCConfig(Config): # If unset, no displayname will be set. # #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + + # Jinja2 templates for extra attributes to send back to the client during + # login. + # + # Note that these are non-standard and clients will ignore them without modifications. + # + #extra_attributes: + #birthdate: "{{{{ user.birthdate }}}}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a670080fe2..aeae5bcaea 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -30,7 +30,7 @@ class AccountValidityConfig(Config): def __init__(self, config, synapse_config): if config is None: return - super(AccountValidityConfig, self).__init__() + super().__init__() self.enabled = config.get("enabled", False) self.renew_by_email_enabled = "renew_at" in config @@ -190,6 +190,11 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime + # The success template used during fallback auth. + self.fallback_success_template = self.read_templates( + ["auth_success.html"], autoescape=True + )[0] + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 755478e2ff..99aa8b3bf1 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -169,12 +169,6 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "15m") ) - # We enable autoescape here as the message may potentially come from a - # remote resource - self.saml2_error_html_template = self.read_templates( - ["saml_error.html"], saml2_config.get("template_dir"), autoescape=True - )[0] - def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set ): @@ -227,11 +221,14 @@ class SAML2Config(Config): # At least one of `sp_config` or `config_path` must be set in this section to # enable SAML login. # - # (You will probably also want to set the following options to `false` to + # You will probably also want to set the following options to `false` to # disable the regular login/registration flows: # * enable_registration # * password_config.enabled # + # You will also want to investigate the settings under the "sso" configuration + # section below. + # # Once SAML support is enabled, a metadata file will be exposed at # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to # use to configure your SAML IdP with. Alternatively, you can manually configure @@ -353,31 +350,6 @@ class SAML2Config(Config): # value: "staff" # - attribute: department # value: "sales" - - # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to display to users if something goes wrong during the - # authentication process: 'saml_error.html'. - # - # When rendering, this template is given the following variables: - # * code: an HTML error code corresponding to the error that is being - # returned (typically 400 or 500) - # - # * msg: a textual message describing the error. - # - # The variables will automatically be HTML-escaped. - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" """ % { "config_dir_path": config_dir_path } diff --git a/synapse/config/server.py b/synapse/config/server.py index e85c6a0840..ef6d70e3f8 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml @@ -542,6 +542,19 @@ class ServerConfig(Config): users_new_default_push_rules ) # type: set + # Whitelist of domain names that given next_link parameters must have + next_link_domain_whitelist = config.get( + "next_link_domain_whitelist" + ) # type: Optional[List[str]] + + self.next_link_domain_whitelist = None # type: Optional[Set[str]] + if next_link_domain_whitelist is not None: + if not isinstance(next_link_domain_whitelist, list): + raise ConfigError("'next_link_domain_whitelist' must be a list") + + # Turn the list into a set to improve lookup speed. + self.next_link_domain_whitelist = set(next_link_domain_whitelist) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -628,10 +641,23 @@ class ServerConfig(Config): """\ ## Server ## - # The domain name of the server, with optional explicit port. - # This is used by remote servers to connect to this server, - # e.g. matrix.org, localhost:8080, etc. - # This is also the last part of your UserID. + # The public-facing domain of the server + # + # The server_name name will appear at the end of usernames and room addresses + # created on this server. For example if the server_name was example.com, + # usernames on this server would be in the format @user:example.com + # + # In most cases you should avoid using a matrix specific subdomain such as + # matrix.example.com or synapse.example.com as the server_name for the same + # reasons you wouldn't use user@email.example.com as your email address. + # See https://github.com/matrix-org/synapse/blob/master/docs/delegate.md + # for information on how to host Synapse on a subdomain while preserving + # a clean server_name. + # + # The server_name cannot be changed later so it is important to + # configure this correctly before you start Synapse. It should be all + # lowercase and may contain an explicit port. + # Examples: matrix.org, localhost:8080 # server_name: "%(server_name)s" @@ -1014,6 +1040,24 @@ class ServerConfig(Config): # act as if no error happened and return a fake session ID ('sid') to clients. # #request_token_inhibit_3pid_errors: true + + # A list of domains that the domain portion of 'next_link' parameters + # must match. + # + # This parameter is optionally provided by clients while requesting + # validation of an email or phone number, and maps to a link that + # users will be automatically redirected to after validation + # succeeds. Clients can make use this parameter to aid the validation + # process. + # + # The whitelist is applied whether the homeserver or an + # identity server is handling validation. + # + # The default value is no whitelist functionality; all domains are + # allowed. Setting this value to an empty list will instead disallow + # all domains. + # + #next_link_domain_whitelist: ["matrix.org"] """ % locals() ) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 6c427b6f92..57f69dc8e2 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -62,7 +62,7 @@ class ServerNoticesConfig(Config): section = "servernotices" def __init__(self, *args): - super(ServerNoticesConfig, self).__init__(*args) + super().__init__(*args) self.server_notices_mxid = None self.server_notices_mxid_display_name = None self.server_notices_mxid_avatar_url = None diff --git a/synapse/config/stats.py b/synapse/config/stats.py index 62485189ea..b559bfa411 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import division - import sys from ._base import Config diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e368ea564d..9ddb8b546b 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -471,7 +471,6 @@ class TlsConfig(Config): # or by checking matrix.org/federationtester/api/report?server_name=$host # #tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] - """ # Lowercase the string representation of boolean values % { diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c784a71508..f23e42cdf9 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 2b03f5ac76..79668a402e 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -45,7 +45,11 @@ _TLS_VERSION_MAP = { class ServerContextFactory(ContextFactory): """Factory for PyOpenSSL SSL contexts that are used to handle incoming - connections.""" + connections. + + TODO: replace this with an implementation of IOpenSSLServerConnectionCreator, + per https://github.com/matrix-org/synapse/issues/1691 + """ def __init__(self, config): # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 32c31b1cd1..c04ad77cf9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -42,7 +42,6 @@ from synapse.api.errors import ( ) from synapse.logging.context import ( PreserveLoggingContext, - current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -233,8 +232,6 @@ class Keyring: """ try: - ctx = current_context() - # map from server name to a set of outstanding request ids server_to_request_ids = {} @@ -265,12 +262,8 @@ class Keyring: # if there are no more requests for this server, we can drop the lock. if not server_requests: - with PreserveLoggingContext(ctx): - logger.debug("Releasing key lookup lock on %s", server_name) - - # ... but not immediately, as that can cause stack explosions if - # we get a long queue of lookups. - self.clock.call_later(0, drop_server_lock, server_name) + logger.debug("Releasing key lookup lock on %s", server_name) + drop_server_lock(server_name) return res @@ -335,20 +328,32 @@ class Keyring: ) # look for any requests which weren't satisfied - with PreserveLoggingContext(): - for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) + while remaining_requests: + verify_request = remaining_requests.pop() + rq_str = ( + "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, ) + ) + + # If we run the errback immediately, it may cancel our + # loggingcontext while we are still in it, so instead we + # schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we + # has a massive queue of lookups waiting for this server). + self.clock.call_later( + 0, + verify_request.key_ready.errback, + SynapseError( + 401, + "Failed to find any key to satisfy %s" % (rq_str,), + Codes.UNAUTHORIZED, + ), + ) except Exception as err: # we don't really expect to get here, because any errors should already # have been caught and logged. But if we do, let's log the error and make @@ -410,10 +415,23 @@ class Keyring: # key was not valid at this point continue - with PreserveLoggingContext(): - verify_request.key_ready.callback( - (server_name, key_id, fetch_key_result.verify_key) - ) + # we have a valid key for this request. If we run the callback + # immediately, it may cancel our loggingcontext while we are still in + # it, so instead we schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we had + # a massive queue of lookups waiting for this server). + logger.debug( + "Found key %s:%s for %s", + server_name, + key_id, + verify_request.request_name, + ) + self.clock.call_later( + 0, + verify_request.key_ready.callback, + (server_name, key_id, fetch_key_result.verify_key), + ) completed.append(verify_request) break @@ -558,7 +576,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the "perspectives" servers""" def __init__(self, hs): - super(PerspectivesKeyFetcher, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_http_client() self.key_servers = self.config.key_servers @@ -728,7 +746,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the origin servers""" def __init__(self, hs): - super(ServerKeyFetcher, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_http_client() diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index bf800a3852..dc49df0812 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -118,8 +118,8 @@ class _EventInternalMetadata: # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before = DictProperty("before") # type: str - after = DictProperty("after") # type: str + before = DictProperty("before") # type: RoomStreamToken + after = DictProperty("after") # type: RoomStreamToken order = DictProperty("order") # type: Tuple[int, int] def get_dict(self) -> JsonDict: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d42930d1b9..302b2f69bc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -24,10 +24,12 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Sequence, Tuple, TypeVar, + Union, ) from prometheus_client import Counter @@ -79,7 +81,7 @@ class InvalidResponseError(RuntimeError): class FederationClient(FederationBase): def __init__(self, hs): - super(FederationClient, self).__init__(hs) + super().__init__(hs) self.pdu_destination_tried = {} self._clock.looping_call(self._clear_tried_cache, 60 * 1000) @@ -501,7 +503,7 @@ class FederationClient(FederationBase): user_id: str, membership: str, content: dict, - params: Dict[str, str], + params: Optional[Mapping[str, Union[str, Iterable[str]]]], ) -> Tuple[str, EventBase, RoomVersion]: """ Creates an m.room.member event, with context, without participating in the room. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ff00f0b302..24329dd0e3 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,7 +28,7 @@ from typing import ( Union, ) -from prometheus_client import Counter, Histogram +from prometheus_client import Counter, Gauge, Histogram from twisted.internet import defer from twisted.internet.abstract import isIPAddress @@ -88,9 +88,16 @@ pdu_process_time = Histogram( ) +last_pdu_age_metric = Gauge( + "synapse_federation_last_received_pdu_age", + "The age (in seconds) of the last PDU successfully received from the given domain", + labelnames=("server_name",), +) + + class FederationServer(FederationBase): def __init__(self, hs): - super(FederationServer, self).__init__(hs) + super().__init__(hs) self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler @@ -118,6 +125,10 @@ class FederationServer(FederationBase): hs, "state_ids_resp", timeout_ms=30000 ) + self._federation_metrics_domains = ( + hs.get_config().federation.federation_metrics_domains + ) + async def on_backfill_request( self, origin: str, room_id: str, versions: List[str], limit: int ) -> Tuple[int, Dict[str, Any]]: @@ -262,7 +273,11 @@ class FederationServer(FederationBase): pdus_by_room = {} # type: Dict[str, List[EventBase]] + newest_pdu_ts = 0 + for p in transaction.pdus: # type: ignore + # FIXME (richardv): I don't think this works: + # https://github.com/matrix-org/synapse/issues/8429 if "unsigned" in p: unsigned = p["unsigned"] if "age" in unsigned: @@ -300,6 +315,9 @@ class FederationServer(FederationBase): event = event_from_pdu_json(p, room_version) pdus_by_room.setdefault(room_id, []).append(event) + if event.origin_server_ts > newest_pdu_ts: + newest_pdu_ts = event.origin_server_ts + pdu_results = {} # we can process different rooms in parallel (which is useful if they @@ -340,6 +358,10 @@ class FederationServer(FederationBase): process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) + if newest_pdu_ts and origin in self._federation_metrics_domains: + newest_pdu_age = self._clock.time_msec() - newest_pdu_ts + last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000) + return pdu_results async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 552519e82c..8bb17b3a05 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -55,6 +55,15 @@ sent_pdus_destination_dist_total = Counter( "Total number of PDUs queued for sending across all destinations", ) +# Time (in s) after Synapse's startup that we will begin to wake up destinations +# that have catch-up outstanding. +CATCH_UP_STARTUP_DELAY_SEC = 15 + +# Time (in s) to wait in between waking up each destination, i.e. one destination +# will be woken up every <x> seconds after Synapse's startup until we have woken +# every destination has outstanding catch-up. +CATCH_UP_STARTUP_INTERVAL_SEC = 5 + class FederationSender: def __init__(self, hs: "synapse.server.HomeServer"): @@ -125,6 +134,14 @@ class FederationSender: 1000.0 / hs.config.federation_rr_transactions_per_room_per_second ) + # wake up destinations that have outstanding PDUs to be caught up + self._catchup_after_startup_timer = self.clock.call_later( + CATCH_UP_STARTUP_DELAY_SEC, + run_as_background_process, + "wake_destinations_needing_catchup", + self._wake_destinations_needing_catchup, + ) + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -209,7 +226,7 @@ class FederationSender: logger.debug("Sending %s to %r", event, destinations) if destinations: - self._send_pdu(event, destinations) + await self._send_pdu(event, destinations) now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) @@ -265,7 +282,7 @@ class FederationSender: finally: self._is_processing = False - def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: + async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. @@ -280,6 +297,13 @@ class FederationSender: sent_pdus_destination_dist_total.inc(len(destinations)) sent_pdus_destination_dist_count.inc() + # track the fact that we have a PDU for these destinations, + # to allow us to perform catch-up later on if the remote is unreachable + # for a while. + await self.store.store_destination_rooms_entries( + destinations, pdu.room_id, pdu.internal_metadata.stream_ordering, + ) + for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu) @@ -553,3 +577,37 @@ class FederationSender: # Dummy implementation for case where federation sender isn't offloaded # to a worker. return [], 0, False + + async def _wake_destinations_needing_catchup(self): + """ + Wakes up destinations that need catch-up and are not currently being + backed off from. + + In order to reduce load spikes, adds a delay between each destination. + """ + + last_processed = None # type: Optional[str] + + while True: + destinations_to_wake = await self.store.get_catch_up_outstanding_destinations( + last_processed + ) + + if not destinations_to_wake: + # finished waking all destinations! + self._catchup_after_startup_timer = None + break + + destinations_to_wake = [ + d + for d in destinations_to_wake + if self._federation_shard_config.should_handle(self._instance_name, d) + ] + + for last_processed in destinations_to_wake: + logger.info( + "Destination %s has outstanding catch-up, waking up.", + last_processed, + ) + self.wake_destination(last_processed) + await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index defc228c23..bc99af3fdd 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,7 +15,7 @@ # limitations under the License. import datetime import logging -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast from prometheus_client import Counter @@ -92,6 +92,21 @@ class PerDestinationQueue: self._destination = destination self.transmission_loop_running = False + # True whilst we are sending events that the remote homeserver missed + # because it was unreachable. We start in this state so we can perform + # catch-up at startup. + # New events will only be sent once this is finished, at which point + # _catching_up is flipped to False. + self._catching_up = True # type: bool + + # The stream_ordering of the most recent PDU that was discarded due to + # being in catch-up mode. + self._catchup_last_skipped = 0 # type: int + + # Cache of the last successfully-transmitted stream ordering for this + # destination (we are the only updater so this is safe) + self._last_successful_stream_ordering = None # type: Optional[int] + # a list of pending PDUs self._pending_pdus = [] # type: List[EventBase] @@ -138,7 +153,13 @@ class PerDestinationQueue: Args: pdu: pdu to send """ - self._pending_pdus.append(pdu) + if not self._catching_up or self._last_successful_stream_ordering is None: + # only enqueue the PDU if we are not catching up (False) or do not + # yet know if we have anything to catch up (None) + self._pending_pdus.append(pdu) + else: + self._catchup_last_skipped = pdu.internal_metadata.stream_ordering + self.attempt_new_transaction() def send_presence(self, states: Iterable[UserPresenceState]) -> None: @@ -218,6 +239,13 @@ class PerDestinationQueue: # hence why we throw the result away. await get_retry_limiter(self._destination, self._clock, self._store) + if self._catching_up: + # we potentially need to catch-up first + await self._catch_up_transmission_loop() + if self._catching_up: + # not caught up yet + return + pending_pdus = [] while True: # We have to keep 2 free slots for presence and rr_edus @@ -325,6 +353,17 @@ class PerDestinationQueue: self._last_device_stream_id = device_stream_id self._last_device_list_stream_id = dev_list_id + + if pending_pdus: + # we sent some PDUs and it was successful, so update our + # last_successful_stream_ordering in the destinations table. + final_pdu = pending_pdus[-1] + last_successful_stream_ordering = ( + final_pdu.internal_metadata.stream_ordering + ) + await self._store.set_destination_last_successful_stream_ordering( + self._destination, last_successful_stream_ordering + ) else: break except NotRetryingDestination as e: @@ -340,8 +379,9 @@ class PerDestinationQueue: if e.retry_interval > 60 * 60 * 1000: # we won't retry for another hour! # (this suggests a significant outage) - # We drop pending PDUs and EDUs because otherwise they will + # We drop pending EDUs because otherwise they will # rack up indefinitely. + # (Dropping PDUs is already performed by `_start_catching_up`.) # Note that: # - the EDUs that are being dropped here are those that we can # afford to drop (specifically, only typing notifications, @@ -353,11 +393,12 @@ class PerDestinationQueue: # dropping read receipts is a bit sad but should be solved # through another mechanism, because this is all volatile! - self._pending_pdus = [] self._pending_edus = [] self._pending_edus_keyed = {} self._pending_presence = {} self._pending_rrs = {} + + self._start_catching_up() except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: @@ -367,6 +408,8 @@ class PerDestinationQueue: e.code, e, ) + + self._start_catching_up() except RequestSendFailed as e: logger.warning( "TX [%s] Failed to send transaction: %s", self._destination, e @@ -376,16 +419,96 @@ class PerDestinationQueue: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) + + self._start_catching_up() except Exception: logger.exception("TX [%s] Failed to send transaction", self._destination) for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) + + self._start_catching_up() finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False + async def _catch_up_transmission_loop(self) -> None: + first_catch_up_check = self._last_successful_stream_ordering is None + + if first_catch_up_check: + # first catchup so get last_successful_stream_ordering from database + self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering( + self._destination + ) + + if self._last_successful_stream_ordering is None: + # if it's still None, then this means we don't have the information + # in our database  we haven't successfully sent a PDU to this server + # (at least since the introduction of the feature tracking + # last_successful_stream_ordering). + # Sadly, this means we can't do anything here as we don't know what + # needs catching up — so catching up is futile; let's stop. + self._catching_up = False + return + + # get at most 50 catchup room/PDUs + while True: + event_ids = await self._store.get_catch_up_room_event_ids( + self._destination, self._last_successful_stream_ordering, + ) + + if not event_ids: + # No more events to catch up on, but we can't ignore the chance + # of a race condition, so we check that no new events have been + # skipped due to us being in catch-up mode + + if self._catchup_last_skipped > self._last_successful_stream_ordering: + # another event has been skipped because we were in catch-up mode + continue + + # we are done catching up! + self._catching_up = False + break + + if first_catch_up_check: + # as this is our check for needing catch-up, we may have PDUs in + # the queue from before we *knew* we had to do catch-up, so + # clear those out now. + self._start_catching_up() + + # fetch the relevant events from the event store + # - redacted behaviour of REDACT is fine, since we only send metadata + # of redacted events to the destination. + # - don't need to worry about rejected events as we do not actively + # forward received events over federation. + catchup_pdus = await self._store.get_events_as_list(event_ids) + if not catchup_pdus: + raise AssertionError( + "No events retrieved when we asked for %r. " + "This should not happen." % event_ids + ) + + if logger.isEnabledFor(logging.INFO): + rooms = [p.room_id for p in catchup_pdus] + logger.info("Catching up rooms to %s: %r", self._destination, rooms) + + success = await self._transaction_manager.send_new_transaction( + self._destination, catchup_pdus, [] + ) + + if not success: + return + + sent_transactions_counter.inc() + final_pdu = catchup_pdus[-1] + self._last_successful_stream_ordering = cast( + int, final_pdu.internal_metadata.stream_ordering + ) + await self._store.set_destination_last_successful_stream_ordering( + self._destination, self._last_successful_stream_ordering + ) + def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: return @@ -446,3 +569,12 @@ class PerDestinationQueue: ] return (edus, stream_id) + + def _start_catching_up(self) -> None: + """ + Marks this destination as being in catch-up mode. + + This throws away the PDU queue. + """ + self._catching_up = True + self._pending_pdus = [] diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index c84072ab73..3e07f925e0 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -15,6 +15,8 @@ import logging from typing import TYPE_CHECKING, List +from prometheus_client import Gauge + from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -34,6 +36,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +last_pdu_age_metric = Gauge( + "synapse_federation_last_sent_pdu_age", + "The age (in seconds) of the last PDU successfully sent to the given domain", + labelnames=("server_name",), +) + class TransactionManager: """Helper class which handles building and sending transactions @@ -48,6 +56,10 @@ class TransactionManager: self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() + self._federation_metrics_domains = ( + hs.get_config().federation.federation_metrics_domains + ) + # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) @@ -119,6 +131,9 @@ class TransactionManager: # FIXME (erikj): This is a bit of a hack to make the Pdu age # keys work + # FIXME (richardv): I also believe it no longer works. We (now?) store + # "age_ts" in "unsigned" rather than at the top level. See + # https://github.com/matrix-org/synapse/issues/8429. def json_data_cb(): data = transaction.get_dict() now = int(self.clock.time_msec()) @@ -167,5 +182,12 @@ class TransactionManager: ) success = False + if success and pdus and destination in self._federation_metrics_domains: + last_pdu = pdus[-1] + last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts + last_pdu_age_metric.labels(server_name=destination).set( + last_pdu_age / 1000 + ) + set_tag(tags.ERROR, not success) return success diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cc7e9a973b..3a6b95631e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -68,7 +68,7 @@ class TransportLayerServer(JsonResource): self.clock = hs.get_clock() self.servlet_groups = servlet_groups - super(TransportLayerServer, self).__init__(hs, canonical_json=False) + super().__init__(hs, canonical_json=False) self.authenticator = Authenticator(hs) self.ratelimiter = hs.get_federation_ratelimiter() @@ -376,9 +376,7 @@ class FederationSendServlet(BaseFederationServlet): RATELIMIT = False def __init__(self, handler, server_name, **kwargs): - super(FederationSendServlet, self).__init__( - handler, server_name=server_name, **kwargs - ) + super().__init__(handler, server_name=server_name, **kwargs) self.server_name = server_name # This is when someone is trying to send us a bunch of data. @@ -773,9 +771,7 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): - super(PublicRoomList, self).__init__( - handler, authenticator, ratelimiter, server_name - ) + super().__init__(handler, authenticator, ratelimiter, server_name) self.allow_access = allow_access async def on_GET(self, origin, content, query): diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 3fd7bedf99..de4c94cd3a 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -336,7 +336,7 @@ class GroupsServerWorkerHandler: class GroupsServerHandler(GroupsServerWorkerHandler): def __init__(self, hs): - super(GroupsServerHandler, self).__init__(hs) + super().__init__(hs) # Ensure attestations get renewed hs.get_groups_attestation_renewer() diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index 69650ff221..7294649d71 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -76,7 +76,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou ) -@attr.s +@attr.s(slots=True) @implementer(ICertificateStore) class ErsatzStore: """ diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 918d0e037c..1ce2091b46 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class AdminHandler(BaseHandler): def __init__(self, hs): - super(AdminHandler, self).__init__(hs) + super().__init__(hs) self.storage = hs.get_storage() self.state_store = self.storage.state @@ -125,8 +125,8 @@ class AdminHandler(BaseHandler): else: stream_ordering = room.stream_ordering - from_key = str(RoomStreamToken(0, 0)) - to_key = str(RoomStreamToken(None, stream_ordering)) + from_key = RoomStreamToken(0, 0) + to_key = RoomStreamToken(None, stream_ordering) written_events = set() # Events that we've processed in this room diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 90189869cc..00eae92052 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: } +@attr.s(slots=True) +class SsoLoginExtraAttributes: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib(type=int) + extra_attributes = attr.ib(type=JsonDict) + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -145,7 +154,7 @@ class AuthHandler(BaseHandler): Args: hs (synapse.server.HomeServer): """ - super(AuthHandler, self).__init__(hs) + super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: @@ -239,6 +248,10 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + # A mapping of user ID to extra attributes to include in the login + # response. + self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] + async def validate_user_via_ui_auth( self, requester: Requester, @@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler): registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler): request: The request to complete. client_redirect_url: The URL to which to redirect the user at the end of the process. + extra_attributes: Extra attributes which will be passed to the client + during successful login. Must be JSON serializable. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return - self._complete_sso_login(registered_user_id, request, client_redirect_url) + self._complete_sso_login( + registered_user_id, request, client_redirect_url, extra_attributes + ) def _complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. """ + # Store any extra attributes which will be passed in the login response. + # Note that this is per-user so it may overwrite a previous value, this + # is considered OK since the newest SSO attributes should be most valid. + if extra_attributes: + self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( + self._clock.time_msec(), extra_attributes, + ) + # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( registered_user_id @@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler): ) respond_with_html(request, 200, html) + async def _sso_login_callback(self, login_result: JsonDict) -> None: + """ + A login callback which might add additional attributes to the login response. + + Args: + login_result: The data to be sent to the client. Includes the user + ID and access token. + """ + # Expire attributes before processing. Note that there shouldn't be any + # valid logins that still have extra attributes. + self._expire_sso_extra_attributes() + + extra_attributes = self._extra_attributes.get(login_result["user_id"]) + if extra_attributes: + login_result.update(extra_attributes.extra_attributes) + + def _expire_sso_extra_attributes(self) -> None: + """ + Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid. + """ + # TODO This should match the amount of time the macaroon is valid for. + LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000 + expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME + to_expire = set() + for user_id, data in self._extra_attributes.items(): + if data.creation_time < expire_before: + to_expire.add(user_id) + for user_id in to_expire: + logger.debug("Expiring extra attributes for user %s", user_id) + del self._extra_attributes[user_id] + @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) @@ -1235,7 +1293,7 @@ class AuthHandler(BaseHandler): return urllib.parse.urlunparse(url_parts) -@attr.s +@attr.s(slots=True) class MacaroonGenerator: hs = attr.ib() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 25169157c1..0635ad5708 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" def __init__(self, hs): - super(DeactivateAccountHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 643d71a710..b9d9098104 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, RequestSendFailed, @@ -28,7 +29,7 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( - RoomStreamToken, + StreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -47,7 +48,7 @@ MAX_DEVICE_DISPLAY_NAME_LEN = 100 class DeviceWorkerHandler(BaseHandler): def __init__(self, hs): - super(DeviceWorkerHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.state = hs.get_state_handler() @@ -104,18 +105,14 @@ class DeviceWorkerHandler(BaseHandler): @trace @measure_func("device.get_user_ids_changed") - async def get_user_ids_changed(self, user_id, from_token): + async def get_user_ids_changed(self, user_id: str, from_token: StreamToken): """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. - - Args: - user_id (str) - from_token (StreamToken) """ set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_key = await self.store.get_room_events_max_id() + now_room_key = self.store.get_room_max_token() room_ids = await self.store.get_rooms_for_user(user_id) @@ -142,7 +139,7 @@ class DeviceWorkerHandler(BaseHandler): ) rooms_changed.update(event.room_id for event in member_events) - stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream + stream_ordering = from_token.room_key.stream possibly_changed = set(changed) possibly_left = set() @@ -253,7 +250,7 @@ class DeviceWorkerHandler(BaseHandler): class DeviceHandler(DeviceWorkerHandler): def __init__(self, hs): - super(DeviceHandler, self).__init__(hs) + super().__init__(hs) self.federation_sender = hs.get_federation_sender() @@ -267,6 +264,24 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) + def _check_device_name_length(self, name: str): + """ + Checks whether a device name is longer than the maximum allowed length. + + Args: + name: The name of the device. + + Raises: + SynapseError: if the device name is too long. + """ + if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" + % (MAX_DEVICE_DISPLAY_NAME_LEN,), + errcode=Codes.TOO_LARGE, + ) + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): @@ -284,6 +299,9 @@ class DeviceHandler(DeviceWorkerHandler): Returns: str: device id (generated if none was supplied) """ + + self._check_device_name_length(initial_device_display_name) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -399,12 +417,8 @@ class DeviceHandler(DeviceWorkerHandler): # Reject a new displayname which is too long. new_display_name = content.get("display_name") - if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN: - raise SynapseError( - 400, - "Device display name is too long (max %i)" - % (MAX_DEVICE_DISPLAY_NAME_LEN,), - ) + + self._check_device_name_length(new_display_name) try: await self.store.update_device( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 46826eb784..62aa9a2da8 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): def __init__(self, hs): - super(DirectoryHandler, self).__init__(hs) + super().__init__(hs) self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d629c7c16c..dd40fd1299 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1201,7 +1201,7 @@ def _one_time_keys_match(old_key_json, new_key): return old_key == new_key_copy -@attr.s +@attr.s(slots=True) class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys. """ diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index b05e32f457..539b4fc32e 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -37,11 +37,7 @@ logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(EventStreamHandler, self).__init__(hs) - - self.distributor = hs.get_distributor() - self.distributor.declare("started_user_eventstream") - self.distributor.declare("stopped_user_eventstream") + super().__init__(hs) self.clock = hs.get_clock() @@ -137,8 +133,8 @@ class EventStreamHandler(BaseHandler): chunk = { "chunk": chunks, - "start": tokens[0].to_string(), - "end": tokens[1].to_string(), + "start": await tokens[0].to_string(self.store), + "end": await tokens[1].to_string(self.store), } return chunk @@ -146,7 +142,7 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(EventHandler, self).__init__(hs) + super().__init__(hs) self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 014dab2940..1a8144405a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,7 +21,7 @@ import itertools import logging from collections.abc import Container from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import attr from signedjson.key import decode_verify_key_bytes @@ -69,26 +69,29 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ReplicationStoreRoomOnInviteRestServlet, ) -from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet -from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( JsonDict, MutableStateMap, + PersistedEventPosition, + RoomStreamToken, StateMap, UserID, get_domain_from_id, ) from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True) class _NewEventInfo: """Holds information about a received event, ready for passing to _handle_new_events @@ -116,8 +119,8 @@ class FederationHandler(BaseHandler): rooms. """ - def __init__(self, hs): - super(FederationHandler, self).__init__(hs) + def __init__(self, hs: "HomeServer"): + super().__init__(hs) self.hs = hs @@ -126,11 +129,11 @@ class FederationHandler(BaseHandler): self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() + self._state_resolution_handler = hs.get_state_resolution_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() self.action_generator = hs.get_action_generator() self.is_mine_id = hs.is_mine_id - self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._message_handler = hs.get_message_handler() @@ -141,9 +144,6 @@ class FederationHandler(BaseHandler): self._replication = hs.get_replication_data_handler() self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) - self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( - hs - ) self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs ) @@ -159,8 +159,9 @@ class FederationHandler(BaseHandler): self._device_list_updater = hs.get_device_handler().device_list_updater self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite - # When joining a room we need to queue any events for that room up - self.room_queues = {} + # When joining a room we need to queue any events for that room up. + # For each room, a list of (pdu, origin) tuples. + self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]] self._room_pdu_linearizer = Linearizer("fed_room_pdu") self.third_party_event_rules = hs.get_third_party_event_rules() @@ -285,7 +286,7 @@ class FederationHandler(BaseHandler): raise Exception( "Error fetching missing prev_events for %s: %s" % (event_id, e) - ) + ) from e # Update the set of things we've seen after trying to # fetch the missing stuff @@ -384,8 +385,7 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x room_version = await self.store.get_room_version_id(room_id) - state_map = await resolve_events_with_store( - self.clock, + state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, @@ -704,31 +704,10 @@ class FederationHandler(BaseHandler): logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) try: - context = await self._handle_new_event(origin, event, state=state) + await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has acutally - # joined the room. Don't bother if the user is just - # changing their profile info. - newly_joined = True - - prev_state_ids = await context.get_prev_state_ids() - - prev_state_id = prev_state_ids.get((event.type, event.state_key)) - if prev_state_id: - prev_state = await self.store.get_event( - prev_state_id, allow_none=True - ) - if prev_state and prev_state.membership == Membership.JOIN: - newly_joined = False - - if newly_joined: - user = UserID.from_string(event.state_key) - await self.user_joined_room(user, room_id) - # For encrypted messages we check that we know about the sending device, # if we don't then we mark the device cache for that user as stale. if event.type == EventTypes.Encrypted: @@ -839,6 +818,9 @@ class FederationHandler(BaseHandler): dest, room_id, limit=limit, extremities=extremities ) + if not events: + return [] + # ideally we'd sanity check the events here for excess prev_events etc, # but it's hard to reject events at this point without completely # breaking backfill in the same way that it is currently broken by @@ -923,7 +905,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1265,7 +1248,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1412,15 +1395,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1599,11 +1582,6 @@ class FederationHandler(BaseHandler): event.signatures, ) - if event.type == EventTypes.Member: - if event.content["membership"] == Membership.JOIN: - user = UserID.from_string(event.state_key) - await self.user_joined_room(user, event.room_id) - prev_state_ids = await context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) @@ -1674,7 +1652,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1701,7 +1679,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1949,7 +1929,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1962,6 +1942,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1993,6 +1974,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -2003,6 +1985,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -2017,6 +2000,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2091,17 +2075,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2184,10 +2171,10 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = await self.state_store.get_state_groups( + state_sets_d = await self.state_store.get_state_groups( event.room_id, extrem_ids ) - state_sets = list(state_sets.values()) + state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]] state_sets.append(state) current_states = await self.state_handler.resolve_events( room_version, state_sets, event @@ -2952,6 +2939,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2959,20 +2947,26 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) return result["max_stream_id"] else: - max_stream_id = await self.storage.persistence.persist_events( + assert self.storage.persistence + max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -2983,12 +2977,12 @@ class FederationHandler(BaseHandler): if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - await self._notify_persisted_event(event, max_stream_id) + await self._notify_persisted_event(event, max_stream_token) - return max_stream_id + return max_stream_token.stream async def _notify_persisted_event( - self, event: EventBase, max_stream_id: int + self, event: EventBase, max_stream_token: RoomStreamToken ) -> None: """Checks to see if notifier/pushers should be notified about the event or not. @@ -3014,13 +3008,13 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return - event_stream_id = event.internal_metadata.stream_ordering + event_pos = PersistedEventPosition( + self._instance_name, event.internal_metadata.stream_ordering + ) self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) - async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the server to join the room. @@ -3033,16 +3027,6 @@ class FederationHandler(BaseHandler): else: await self.store.clean_room_for_join(room_id) - async def user_joined_room(self, user: UserID, room_id: str) -> None: - """Called when a new user has joined the room - """ - if self.config.worker_app: - await self._notify_user_membership_change( - room_id=room_id, user_id=user.to_string(), change="joined" - ) - else: - user_joined_room(self.distributor, user, room_id) - async def get_room_complexity( self, remote_room_hosts: List[str], room_id: str ) -> Optional[dict]: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ba47e3a242..489a7b885d 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -240,7 +240,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): def __init__(self, hs): - super(GroupsLocalHandler, self).__init__(hs) + super().__init__(hs) # Ensure attestations get renewed hs.get_groups_attestation_renewer() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0ce6ddfbe4..bc3e9607ca 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,8 +21,6 @@ import logging import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple -from twisted.internet.error import TimeoutError - from synapse.api.errors import ( CodeMessageException, Codes, @@ -30,6 +28,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config.emailconfig import ThreepidBehaviour +from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, Requester from synapse.util import json_decoder @@ -45,7 +44,7 @@ id_server_scheme = "https://" class IdentityHandler(BaseHandler): def __init__(self, hs): - super(IdentityHandler, self).__init__(hs) + super().__init__(hs) self.http_client = SimpleHttpClient(hs) # We create a blacklisting instance of SimpleHttpClient for contacting identity @@ -93,7 +92,7 @@ class IdentityHandler(BaseHandler): try: data = await self.http_client.get_json(url, query_params) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.info( @@ -173,7 +172,7 @@ class IdentityHandler(BaseHandler): if e.code != 404 or not use_v2: logger.error("3PID bind failed with Matrix error: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: data = json_decoder.decode(e.msg) # XXX WAT? @@ -273,7 +272,7 @@ class IdentityHandler(BaseHandler): else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(500, "Failed to contact identity server") - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") await self.store.remove_user_bound_threepid( @@ -419,7 +418,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") async def requestMsisdnToken( @@ -471,7 +470,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") assert self.hs.config.public_baseurl @@ -553,7 +552,7 @@ class IdentityHandler(BaseHandler): id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", body, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) @@ -627,7 +626,7 @@ class IdentityHandler(BaseHandler): # require or validate it. See the following for context: # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950 return data["mxid"] - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except IOError as e: logger.warning("Error from v1 identity server lookup: %s" % (e,)) @@ -655,7 +654,7 @@ class IdentityHandler(BaseHandler): "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), {"access_token": id_access_token}, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") if not isinstance(hash_details, dict): @@ -727,7 +726,7 @@ class IdentityHandler(BaseHandler): }, headers=headers, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except Exception as e: logger.warning("Error when performing a v2 3pid lookup: %s", e) @@ -823,7 +822,7 @@ class IdentityHandler(BaseHandler): invite_config, {"Authorization": create_id_access_token_header(id_access_token)}, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: if e.code != 404: @@ -841,7 +840,7 @@ class IdentityHandler(BaseHandler): data = await self.blacklisting_http_client.post_json_get_json( url, invite_config ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d5ddc583ad..39a85801c1 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamToken, UserID +from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -42,7 +42,7 @@ logger = logging.getLogger(__name__) class InitialSyncHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(InitialSyncHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.state = hs.get_state_handler() self.clock = hs.get_clock() @@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler): now_token = self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = await presence_stream.get_pagination_rows( - user, pagination_config.get_source_config("presence"), None + presence, _ = await presence_stream.get_new_events( + user, from_key=None, include_offline=False ) - receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = await receipt_stream.get_pagination_rows( - user, pagination_config.get_source_config("receipt"), None + joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] + receipt = await self.store.get_linearized_receipts_for_rooms( + joined_rooms, to_key=int(now_token.receipt_key), ) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -168,7 +167,7 @@ class InitialSyncHandler(BaseHandler): self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: - room_end_token = "s%d" % (event.stream_ordering,) + room_end_token = RoomStreamToken(None, event.stream_ordering,) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] ) @@ -204,8 +203,8 @@ class InitialSyncHandler(BaseHandler): messages, time_now=time_now, as_client_event=as_client_event ) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), } d["state"] = await self._event_serializer.serialize_events( @@ -250,7 +249,7 @@ class InitialSyncHandler(BaseHandler): ], "account_data": account_data_events, "receipts": receipt, - "end": now_token.to_string(), + "end": await now_token.to_string(self.store), } return ret @@ -326,7 +325,8 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = await self.store.get_stream_token_for_event(member_event_id) + leave_position = await self.store.get_position_for_event(member_event_id) + stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token @@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": ( await self._event_serializer.serialize_events( @@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": state, "presence": presence, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8a7b4916cd..ee271e85e5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,9 +376,8 @@ class EventCreationHandler: self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -387,8 +386,6 @@ class EventCreationHandler: # This is only used to get at ratelimit function, and maybe_kick_guest_users self.base_handler = BaseHandler(hs) - self.pusher_pool = hs.get_pusherpool() - # We arbitrarily limit concurrent event creation for a room to 5. # This is to stop us from diverging history *too* much. self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") @@ -904,9 +901,10 @@ class EventCreationHandler: try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -974,7 +972,10 @@ class EventCreationHandler: This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer + assert self.storage.persistence is not None + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we @@ -1137,7 +1138,7 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = await self.storage.persistence.persist_event( + event_pos, max_stream_token = await self.storage.persistence.persist_event( event, context=context ) @@ -1145,12 +1146,10 @@ class EventCreationHandler: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) - def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -1162,7 +1161,7 @@ class EventCreationHandler: # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - return event_stream_id + return event_pos.stream async def _bump_active_time(self, user: UserID) -> None: try: @@ -1183,54 +1182,7 @@ class EventCreationHandler: ) for room_id in room_ids: - # For each room we need to find a joined member we can use to send - # the dummy event with. - - latest_event_ids = await self.store.get_prev_events_for_room(room_id) - - members = await self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids - ) - dummy_event_sent = False - for user_id in members: - if not self.hs.is_mine_id(user_id): - continue - requester = create_requester(user_id) - try: - event, context = await self.create_event( - requester, - { - "type": "org.matrix.dummy_event", - "content": {}, - "room_id": room_id, - "sender": user_id, - }, - prev_event_ids=latest_event_ids, - ) - - event.internal_metadata.proactively_send = False - - # Since this is a dummy-event it is OK if it is sent by a - # shadow-banned user. - await self.send_nonmember_event( - requester, - event, - context, - ratelimit=False, - ignore_shadow_ban=True, - ) - dummy_event_sent = True - break - except ConsentNotGivenError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of consent. Will try another user" % (room_id, user_id) - ) - except AuthError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of power. Will try another user" % (room_id, user_id) - ) + dummy_event_sent = await self._send_dummy_event_for_room(room_id) if not dummy_event_sent: # Did not find a valid user in the room, so remove from future attempts @@ -1243,6 +1195,59 @@ class EventCreationHandler: now = self.clock.time_msec() self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now + async def _send_dummy_event_for_room(self, room_id: str) -> bool: + """Attempt to send a dummy event for the given room. + + Args: + room_id: room to try to send an event from + + Returns: + True if a dummy event was successfully sent. False if no user was able + to send an event. + """ + + # For each room we need to find a joined member we can use to send + # the dummy event with. + latest_event_ids = await self.store.get_prev_events_for_room(room_id) + members = await self.state.get_current_users_in_room( + room_id, latest_event_ids=latest_event_ids + ) + for user_id in members: + if not self.hs.is_mine_id(user_id): + continue + requester = create_requester(user_id) + try: + event, context = await self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + + event.internal_metadata.proactively_send = False + + # Since this is a dummy-event it is OK if it is sent by a + # shadow-banned user. + await self.send_nonmember_event( + requester, event, context, ratelimit=False, ignore_shadow_ban=True, + ) + return True + except ConsentNotGivenError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of consent. Will try another user" % (room_id, user_id) + ) + except AuthError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of power. Will try another user" % (room_id, user_id) + ) + return False + def _expire_rooms_to_exclude_from_dummy_event_insertion(self): expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY to_expire = set() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 1b06f3173f..19cd652675 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -37,7 +37,7 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder if TYPE_CHECKING: @@ -114,6 +114,7 @@ class OidcHandler: hs.config.oidc_user_mapping_provider_config ) # type: OidcMappingProvider self._skip_verification = hs.config.oidc_skip_verification # type: bool + self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() @@ -131,10 +132,10 @@ class OidcHandler: def _render_error( self, request, error: str, error_description: Optional[str] = None ) -> None: - """Renders the error template and respond with it. + """Render the error template and respond to the request with it. This is used to show errors to the user. The template of this page can - be found under ``synapse/res/templates/sso_error.html``. + be found under `synapse/res/templates/sso_error.html`. Args: request: The incoming request from the browser. @@ -706,6 +707,15 @@ class OidcHandler: self._render_error(request, "mapping_error", str(e)) return + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. + extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + # and finally complete the login if ui_auth_session_id: await self._auth_handler.complete_sso_ui_auth( @@ -713,7 +723,7 @@ class OidcHandler: ) else: await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url + user_id, request, client_redirect_url, extra_attributes ) def _generate_oidc_session_token( @@ -849,7 +859,8 @@ class OidcHandler: If we don't find the user that way, we should register the user, mapping the localpart and the display name from the UserInfo. - If a user already exists with the mxid we've mapped, raise an exception. + If a user already exists with the mxid we've mapped and allow_existing_users + is disabled, raise an exception. Args: userinfo: an object representing the user @@ -905,21 +916,31 @@ class OidcHandler: localpart = map_username_to_mxid_localpart(attributes["localpart"]) - user_id = UserID(localpart, self._hostname) - if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): - # This mxid is taken - raise MappingException( - "mxid '{}' is already taken".format(user_id.to_string()) + user_id = UserID(localpart, self._hostname).to_string() + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + if self._allow_existing_users: + if len(users) == 1: + registered_user_id = next(iter(users)) + elif user_id in users: + registered_user_id = user_id + else: + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, list(users.keys()) + ) + ) + else: + # This mxid is taken + raise MappingException("mxid '{}' is already taken".format(user_id)) + else: + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id, ) @@ -972,7 +993,7 @@ class OidcMappingProvider(Generic[C]): async def map_user_attributes( self, userinfo: UserInfo, token: Token ) -> UserAttribute: - """Map a ``UserInfo`` objects into user attributes. + """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider @@ -983,6 +1004,18 @@ class OidcMappingProvider(Generic[C]): """ raise NotImplementedError() + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + """Map a `UserInfo` object into additional attributes passed to the client during login. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing additional attributes. Must be JSON serializable. + """ + return {} + # Used to clear out "None" values in templates def jinja_finalize(thing): @@ -997,6 +1030,7 @@ class JinjaOidcMappingConfig: subject_claim = attr.ib() # type: str localpart_template = attr.ib() # type: Template display_name_template = attr.ib() # type: Optional[Template] + extra_attributes = attr.ib() # type: Dict[str, Template] class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1035,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): % (e,) ) + extra_attributes = {} # type Dict[str, Template] + if "extra_attributes" in config: + extra_attributes_config = config.get("extra_attributes") or {} + if not isinstance(extra_attributes_config, dict): + raise ConfigError( + "oidc_config.user_mapping_provider.config.extra_attributes must be a dict" + ) + + for key, value in extra_attributes_config.items(): + try: + extra_attributes[key] = env.from_string(value) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" + % (key, e) + ) + return JinjaOidcMappingConfig( subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + extra_attributes=extra_attributes, ) def get_remote_user_id(self, userinfo: UserInfo) -> str: @@ -1059,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): display_name = None return UserAttribute(localpart=localpart, display_name=display_name) + + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + extras = {} # type: Dict[str, str] + for key, template in self._config.extra_attributes.items(): + try: + extras[key] = template.render(user=userinfo).strip() + except Exception as e: + # Log an error and skip this value (don't break login for this). + logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e)) + return extras diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 6067585f9b..2c2a633938 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -25,7 +25,7 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import Requester, RoomStreamToken +from synapse.types import Requester from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -335,20 +335,16 @@ class PaginationHandler: user_id = requester.user.to_string() if pagin_config.from_token: - room_token = pagin_config.from_token.room_key + from_token = pagin_config.from_token else: - pagin_config.from_token = ( - self.hs.get_event_sources().get_current_token_for_pagination() - ) - room_token = pagin_config.from_token.room_key - - room_token = RoomStreamToken.parse(room_token) + from_token = self.hs.get_event_sources().get_current_token_for_pagination() - pagin_config.from_token = pagin_config.from_token.copy_and_replace( - "room_key", str(room_token) - ) + if pagin_config.limit is None: + # This shouldn't happen as we've set a default limit before this + # gets called. + raise Exception("limit not set") - source_config = pagin_config.get_source_config("room") + room_token = from_token.room_key with await self.pagination_lock.read(room_id): ( @@ -358,7 +354,7 @@ class PaginationHandler: room_id, user_id, allow_departed_users=True ) - if source_config.direction == "b": + if pagin_config.direction == "b": # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: @@ -380,23 +376,31 @@ class PaginationHandler: leave_token = await self.store.get_topological_token_for_event( member_event_id ) - if RoomStreamToken.parse(leave_token).topological < curr_topo: - source_config.from_key = str(leave_token) + assert leave_token.topological is not None + + if leave_token.topological < curr_topo: + from_token = from_token.copy_and_replace( + "room_key", leave_token + ) await self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, curr_topo, limit=source_config.limit, + room_id, curr_topo, limit=pagin_config.limit, ) + to_room_key = None + if pagin_config.to_token: + to_room_key = pagin_config.to_token.room_key + events, next_key = await self.store.paginate_room_events( room_id=room_id, - from_key=source_config.from_key, - to_key=source_config.to_key, - direction=source_config.direction, - limit=source_config.limit, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, event_filter=event_filter, ) - next_token = pagin_config.from_token.copy_and_replace("room_key", next_key) + next_token = from_token.copy_and_replace("room_key", next_key) if events: if event_filter: @@ -409,8 +413,8 @@ class PaginationHandler: if not events: return { "chunk": [], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } state = None @@ -438,8 +442,8 @@ class PaginationHandler: events, time_now, as_client_event=as_client_event ) ), - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } if state: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 91a3aec1cc..1000ac95ff 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1108,9 +1108,6 @@ class PresenceEventSource: def get_current_key(self): return self.store.get_current_presence_token() - async def get_pagination_rows(self, user, pagination_config, key): - return await self.get_new_events(user, from_key=None, include_offline=False) - @cached(num_args=2, cache_context=True) async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 0cb8fad89a..5453e6dfc8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -44,7 +44,7 @@ class BaseProfileHandler(BaseHandler): """ def __init__(self, hs): - super(BaseProfileHandler, self).__init__(hs) + super().__init__(hs) self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -369,7 +369,7 @@ class MasterProfileHandler(BaseProfileHandler): PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): - super(MasterProfileHandler, self).__init__(hs) + super().__init__(hs) assert hs.config.worker_app is None diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index e3b528d271..c32f314a1c 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler(BaseHandler): def __init__(self, hs): - super(ReadMarkerHandler, self).__init__(hs) + super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 2cc6c2eb68..7225923757 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) class ReceiptsHandler(BaseHandler): def __init__(self, hs): - super(ReceiptsHandler, self).__init__(hs) + super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() @@ -142,18 +142,3 @@ class ReceiptEventSource: def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() - - async def get_pagination_rows(self, user, config, key): - to_key = int(config.from_key) - - if config.to_key: - from_key = int(config.to_key) - else: - from_key = None - - room_ids = await self.store.get_rooms_for_user(user.to_string()) - events = await self.store.get_linearized_receipts_for_rooms( - room_ids, from_key=from_key, to_key=to_key - ) - - return (events, to_key) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cde2dbca92..538f4b2a61 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -42,7 +42,7 @@ class RegistrationHandler(BaseHandler): Args: hs (synapse.server.HomeServer): """ - super(RegistrationHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a29305f655..d5f7c78edf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -70,7 +70,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000 class RoomCreationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(RoomCreationHandler, self).__init__(hs) + super().__init__(hs) self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() @@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1075,11 +1077,13 @@ class RoomContextHandler: # the token, which we replace. token = StreamToken.START - results["start"] = token.copy_and_replace( + results["start"] = await token.copy_and_replace( "room_key", results["start"] - ).to_string() + ).to_string(self.store) - results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() + results["end"] = await token.copy_and_replace( + "room_key", results["end"] + ).to_string(self.store) return results @@ -1091,20 +1095,19 @@ class RoomEventSource: async def get_new_events( self, user: UserID, - from_key: str, + from_key: RoomStreamToken, limit: int, room_ids: List[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[EventBase], str]: + ) -> Tuple[List[EventBase], RoomStreamToken]: # We just ignore the key for now. to_key = self.get_current_key() - from_token = RoomStreamToken.parse(from_key) - if from_token.topological: + if from_key.topological: logger.warning("Stream has topological part!!!! %r", from_key) - from_key = "s%s" % (from_token.stream,) + from_key = RoomStreamToken(None, from_key.stream) app_service = self.store.get_app_service_by_user_id(user.to_string()) if app_service: @@ -1139,8 +1142,8 @@ class RoomEventSource: return (events, end_key) - def get_current_key(self) -> str: - return "s%d" % (self.store.get_room_max_stream_ordering(),) + def get_current_key(self) -> RoomStreamToken: + return self.store.get_room_max_token() def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) @@ -1260,10 +1263,10 @@ class RoomShutdownHandler: # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1293,7 +1296,9 @@ class RoomShutdownHandler: # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 5dd7b28391..4a13c8e912 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -38,7 +38,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs): - super(RoomListHandler, self).__init__(hs) + super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache(hs, "room_list") self.remote_response_cache = ResponseCache( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 32b7e323fa..8feba8c90a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -40,7 +40,7 @@ from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer -from synapse.util.distributor import user_joined_room, user_left_room +from synapse.util.distributor import user_left_room from ._base import BaseHandler @@ -51,14 +51,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RoomMemberHandler: +class RoomMemberHandler(metaclass=abc.ABCMeta): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns # ought to be separated out a lot better. - __metaclass__ = abc.ABCMeta - def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() @@ -82,13 +80,6 @@ class RoomMemberHandler: self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, @@ -149,17 +140,6 @@ class RoomMemberHandler: raise NotImplementedError() @abc.abstractmethod - async def _user_joined_room(self, target: UserID, room_id: str) -> None: - """Notifies distributor on master process that the user has joined the - room. - - Args: - target - room_id - """ - raise NotImplementedError() - - @abc.abstractmethod async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. @@ -221,7 +201,6 @@ class RoomMemberHandler: prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - newly_joined = False if event.membership == Membership.JOIN: newly_joined = True if prev_member_event_id: @@ -246,12 +225,7 @@ class RoomMemberHandler: requester, event, context, extra_users=[target], ratelimit=ratelimit, ) - if event.membership == Membership.JOIN and newly_joined: - # Only fire user_joined_room if the user has actually joined the - # room. Don't bother if the user is just changing their profile - # info. - await self._user_joined_room(target, room_id) - elif event.membership == Membership.LEAVE: + if event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: @@ -726,17 +700,7 @@ class RoomMemberHandler: (EventTypes.Member, event.state_key), None ) - if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has actually joined the - # room. Don't bother if the user is just changing their profile - # info. - newly_joined = True - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - newly_joined = prev_member_event.membership != Membership.JOIN - if newly_joined: - await self._user_joined_room(target_user, room_id) - elif event.membership == Membership.LEAVE: + if event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: @@ -1002,10 +966,9 @@ class RoomMemberHandler: class RoomMemberMasterHandler(RoomMemberHandler): def __init__(self, hs): - super(RoomMemberMasterHandler, self).__init__(hs) + super().__init__(hs) self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") async def _is_remote_room_too_complex( @@ -1085,7 +1048,6 @@ class RoomMemberMasterHandler(RoomMemberHandler): event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) - await self._user_joined_room(user, room_id) # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. @@ -1228,11 +1190,6 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) return event.event_id, stream_id - async def _user_joined_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_joined_room - """ - user_joined_room(self.distributor, target, room_id) - async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 897338fd54..f2e88f6a5b 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class RoomMemberWorkerHandler(RoomMemberHandler): def __init__(self, hs): - super(RoomMemberWorkerHandler, self).__init__(hs) + super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) self._remote_reject_client = ReplRejectInvite.make_client(hs) @@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - await self._user_joined_room(user, room_id) - return ret["event_id"], ret["stream_id"] async def remote_reject_invite( @@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) return ret["event_id"], ret["stream_id"] - async def _user_joined_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_joined_room - """ - await self._notify_change_client( - user_id=target.to_string(), room_id=room_id, change="joined" - ) - async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 66b063f991..285c481a96 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -21,9 +21,10 @@ import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.saml2_config import SamlAttributeRequirement +from synapse.http.server import respond_with_html from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -41,7 +42,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -@attr.s +class MappingException(Exception): + """Used to catch errors when mapping the SAML2 response to a user.""" + + +@attr.s(slots=True) class Saml2SessionData: """Data we track about SAML2 sessions""" @@ -68,6 +73,7 @@ class SamlHandler: hs.config.saml2_grandfathered_mxid_source_attribute ) self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements + self._error_template = hs.config.sso_error_template # plugin to do custom mapping from saml response to mxid self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( @@ -84,6 +90,25 @@ class SamlHandler: # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) + def _render_error( + self, request, error: str, error_description: Optional[str] = None + ) -> None: + """Render the error template and respond to the request with it. + + This is used to show errors to the user. The template of this page can + be found under `synapse/res/templates/sso_error.html`. + + Args: + request: The incoming request from the browser. + We'll respond with an HTML page describing the error. + error: A technical identifier for this error. + error_description: A human-readable description of the error. + """ + html = self._error_template.render( + error=error, error_description=error_description + ) + respond_with_html(request, 400, html) + def handle_redirect_request( self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None ) -> bytes: @@ -134,49 +159,6 @@ class SamlHandler: # the dict. self.expire_sessions() - # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") - ip_address = self.hs.get_ip_from_request(request) - - user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state, user_agent, ip_address - ) - - # Complete the interactive auth session or the login. - if current_session and current_session.ui_auth_session_id: - await self._auth_handler.complete_sso_ui_auth( - user_id, current_session.ui_auth_session_id, request - ) - - else: - await self._auth_handler.complete_sso_login(user_id, request, relay_state) - - async def _map_saml_response_to_user( - self, - resp_bytes: str, - client_redirect_url: str, - user_agent: str, - ip_address: str, - ) -> Tuple[str, Optional[Saml2SessionData]]: - """ - Given a sample response, retrieve the cached session and user for it. - - Args: - resp_bytes: The SAML response. - client_redirect_url: The redirect URL passed in by the client. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. - - Returns: - Tuple of the user ID and SAML session associated with this response. - - Raises: - SynapseError if there was a problem with the response. - RedirectException: some mapping providers may raise this if they need - to redirect to an interstitial page. - """ try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -189,12 +171,23 @@ class SamlHandler: # in the (user-visible) exception message, so let's log the exception here # so we can track down the session IDs later. logger.warning(str(e)) - raise SynapseError(400, "Unexpected SAML2 login.") + self._render_error( + request, "unsolicited_response", "Unexpected SAML2 login." + ) + return except Exception as e: - raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,)) + self._render_error( + request, + "invalid_response", + "Unable to parse SAML2 response: %s." % (e,), + ) + return if saml2_auth.not_signed: - raise SynapseError(400, "SAML2 response was not signed.") + self._render_error( + request, "unsigned_respond", "SAML2 response was not signed." + ) + return logger.debug("SAML2 response: %s", saml2_auth.origxml) for assertion in saml2_auth.assertions: @@ -213,15 +206,73 @@ class SamlHandler: saml2_auth.in_response_to, None ) + # Ensure that the attributes of the logged in user meet the required + # attributes. for requirement in self._saml2_attribute_requirements: - _check_attribute_requirement(saml2_auth.ava, requirement) + if not _check_attribute_requirement(saml2_auth.ava, requirement): + self._render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return + + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + + # Call the mapper to register/login the user + try: + user_id = await self._map_saml_response_to_user( + saml2_auth, relay_state, user_agent, ip_address + ) + except MappingException as e: + logger.exception("Could not map user") + self._render_error(request, "mapping_error", str(e)) + return + + # Complete the interactive auth session or the login. + if current_session and current_session.ui_auth_session_id: + await self._auth_handler.complete_sso_ui_auth( + user_id, current_session.ui_auth_session_id, request + ) + + else: + await self._auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user( + self, + saml2_auth: saml2.response.AuthnResponse, + client_redirect_url: str, + user_agent: str, + ip_address: str, + ) -> str: + """ + Given a SAML response, retrieve the user ID for it and possibly register the user. + + Args: + saml2_auth: The parsed SAML2 response. + client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + + Returns: + The user ID associated with this response. + + Raises: + MappingException if there was a problem mapping the response to a user. + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. + """ remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) if not remote_user_id: - raise Exception("Failed to extract remote user id from SAML response") + raise MappingException( + "Failed to extract remote user id from SAML response" + ) with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user @@ -235,7 +286,7 @@ class SamlHandler: ) if registered_user_id is not None: logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id, current_session + return registered_user_id # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid @@ -260,7 +311,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id, current_session + return registered_user_id # Map saml response to user attributes using the configured mapping provider for i in range(1000): @@ -277,7 +328,7 @@ class SamlHandler: localpart = attribute_dict.get("mxid_localpart") if not localpart: - raise Exception( + raise MappingException( "Error parsing SAML2 response: SAML mapping provider plugin " "did not return a mxid_localpart value" ) @@ -294,8 +345,8 @@ class SamlHandler: else: # Unable to generate a username in 1000 iterations # Break and return error to the user - raise SynapseError( - 500, "Unable to generate a Matrix ID from the SAML response" + raise MappingException( + "Unable to generate a Matrix ID from the SAML response" ) logger.info("Mapped SAML user to local part %s", localpart) @@ -310,7 +361,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id, current_session + return registered_user_id def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime @@ -323,11 +374,11 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] -def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): +def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool: values = ava.get(req.attribute, []) for v in values: if v == req.value: - return + return True logger.info( "SAML2 attribute %s did not match required value '%s' (was '%s')", @@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): req.value, values, ) - raise AuthError(403, "You are not authorized to log in here.") + return False DOT_REPLACE_PATTERN = re.compile( @@ -390,7 +441,7 @@ class DefaultSamlMappingProvider: return saml_response.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "'uid' not in SAML2 response") + raise MappingException("'uid' not in SAML2 response") def saml_response_to_user_attributes( self, diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index d58f9788c5..e9402e6e2e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): def __init__(self, hs): - super(SearchHandler, self).__init__(hs) + super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -362,13 +362,13 @@ class SearchHandler(BaseHandler): self.storage, user.to_string(), res["events_after"] ) - res["start"] = now_token.copy_and_replace( + res["start"] = await now_token.copy_and_replace( "room_key", res["start"] - ).to_string() + ).to_string(self.store) - res["end"] = now_token.copy_and_replace( + res["end"] = await now_token.copy_and_replace( "room_key", res["end"] - ).to_string() + ).to_string(self.store) if include_profile: senders = { diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 4d245b618b..a5d67f828f 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" def __init__(self, hs): - super(SetPasswordHandler, self).__init__(hs) + super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._password_policy_handler = hs.get_password_policy_handler() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e2ddb628ff..bfe2583002 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -89,14 +89,12 @@ class TimelineBatch: events = attr.ib(type=List[EventBase]) limited = attr.ib(bool) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ return bool(self.events) - __bool__ = __nonzero__ # python3 - # We can't freeze this class, because we need to update it after it's instantiated to # update its unread count. This is because we calculate the unread count for a room only @@ -114,7 +112,7 @@ class JoinedSyncResult: summary = attr.ib(type=Optional[JsonDict]) unread_count = attr.ib(type=int) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ @@ -127,8 +125,6 @@ class JoinedSyncResult: # else in the result, we don't need to send it. ) - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class ArchivedSyncResult: @@ -137,26 +133,22 @@ class ArchivedSyncResult: state = attr.ib(type=StateMap[EventBase]) account_data = attr.ib(type=List[JsonDict]) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ return bool(self.timeline or self.state or self.account_data) - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class InvitedSyncResult: room_id = attr.ib(type=str) invite = attr.ib(type=EventBase) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Invited rooms should always be reported to the client""" return True - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class GroupsSyncResult: @@ -164,11 +156,9 @@ class GroupsSyncResult: invite = attr.ib(type=JsonDict) leave = attr.ib(type=JsonDict) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: return bool(self.join or self.invite or self.leave) - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class DeviceLists: @@ -181,13 +171,11 @@ class DeviceLists: changed = attr.ib(type=Collection[str]) left = attr.ib(type=Collection[str]) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: return bool(self.changed or self.left) - __bool__ = __nonzero__ # python3 - -@attr.s +@attr.s(slots=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined and left room IDs since last sync. @@ -227,7 +215,7 @@ class SyncResult: device_one_time_keys_count = attr.ib(type=JsonDict) groups = attr.ib(type=Optional[GroupsSyncResult]) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if the notifier needs to wait for more events when polling for events. @@ -243,8 +231,6 @@ class SyncResult: or self.groups ) - __bool__ = __nonzero__ # python3 - class SyncHandler: def __init__(self, hs: "HomeServer"): @@ -378,7 +364,7 @@ class SyncHandler: sync_config = sync_result_builder.sync_config with Measure(self.clock, "ephemeral_by_room"): - typing_key = since_token.typing_key if since_token else "0" + typing_key = since_token.typing_key if since_token else 0 room_ids = sync_result_builder.joined_room_ids @@ -402,7 +388,7 @@ class SyncHandler: event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) - receipt_key = since_token.receipt_key if since_token else "0" + receipt_key = since_token.receipt_key if since_token else 0 receipt_source = self.event_sources.sources["receipt"] receipts, receipt_key = await receipt_source.get_new_events( @@ -981,7 +967,7 @@ class SyncHandler: raise NotImplementedError() else: joined_room_ids = await self.get_rooms_for_user_at( - user_id, now_token.room_stream_id + user_id, now_token.room_key ) sync_result_builder = SyncResultBuilder( sync_config, @@ -1310,12 +1296,11 @@ class SyncHandler: presence_source = self.event_sources.sources["presence"] since_token = sync_result_builder.since_token + presence_key = None + include_offline = False if since_token and not sync_result_builder.full_state: presence_key = since_token.presence_key include_offline = True - else: - presence_key = None - include_offline = False presence, presence_key = await presence_source.get_new_events( user=user, @@ -1323,6 +1308,7 @@ class SyncHandler: is_guest=sync_config.is_guest, include_offline=include_offline, ) + assert presence_key sync_result_builder.now_token = now_token.copy_and_replace( "presence_key", presence_key ) @@ -1485,7 +1471,7 @@ class SyncHandler: if rooms_changed: return True - stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream + stream_id = since_token.room_key.stream for room_id in sync_result_builder.joined_room_ids: if self.store.has_room_changed_since(room_id, stream_id): return True @@ -1609,16 +1595,24 @@ class SyncHandler: if leave_events: leave_event = leave_events[-1] - leave_stream_token = await self.store.get_stream_token_for_event( + leave_position = await self.store.get_position_for_event( leave_event.event_id ) - leave_token = since_token.copy_and_replace( - "room_key", leave_stream_token - ) - if since_token and since_token.is_after(leave_token): + # If the leave event happened before the since token then we + # bail. + if since_token and not leave_position.persisted_after( + since_token.room_key + ): continue + # We can safely convert the position of the leave event into a + # stream token as it'll only be used in the context of this + # room. (c.f. the docstring of `to_room_stream_token`). + leave_token = since_token.copy_and_replace( + "room_key", leave_position.to_room_stream_token() + ) + # If this is an out of band message, like a remote invite # rejection, we include it in the recents batch. Otherwise, we # let _load_filtered_recents handle fetching the correct @@ -1751,7 +1745,7 @@ class SyncHandler: continue leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) + "room_key", RoomStreamToken(None, event.stream_ordering) ) room_entries.append( RoomSyncResultBuilder( @@ -1930,7 +1924,7 @@ class SyncHandler: raise Exception("Unrecognized rtype: %r", room_builder.rtype) async def get_rooms_for_user_at( - self, user_id: str, stream_ordering: int + self, user_id: str, room_key: RoomStreamToken ) -> FrozenSet[str]: """Get set of joined rooms for a user at the given stream ordering. @@ -1956,15 +1950,15 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. - for room_id, membership_stream_ordering in joined_rooms: - if membership_stream_ordering <= stream_ordering: + for room_id, event_pos in joined_rooms: + if not event_pos.persisted_after(room_key): joined_room_ids.add(room_id) continue logger.info("User joined room after current token: %s", room_id) extrems = await self.store.get_forward_extremeties_for_room( - room_id, stream_ordering + room_id, event_pos.stream ) users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: @@ -2038,7 +2032,7 @@ def _calculate_state( return {event_id_to_key[e]: e for e in state_ids} -@attr.s +@attr.s(slots=True) class SyncResultBuilder: """Used to help build up a new SyncResult for a user @@ -2074,7 +2068,7 @@ class SyncResultBuilder: to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) -@attr.s +@attr.s(slots=True) class RoomSyncResultBuilder: """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index e21f8dbc58..79393c8829 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -37,7 +37,7 @@ class UserDirectoryHandler(StateDeltasHandler): """ def __init__(self, hs): - super(UserDirectoryHandler, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.state = hs.get_state_handler() diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 3880ce0d94..59b01b812c 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -16,8 +16,6 @@ import re from twisted.internet import task -from twisted.internet.defer import CancelledError -from twisted.python import failure from twisted.web.client import FileBodyProducer from synapse.api.errors import SynapseError @@ -26,19 +24,8 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" - def __init__(self): - super(RequestTimedOutError, self).__init__(504, "Timed out") - - -def cancelled_to_request_timed_out_error(value, timeout): - """Turns CancelledErrors into RequestTimedOutErrors. - - For use with async.add_timeout_to_deferred - """ - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise RequestTimedOutError() - return value + def __init__(self, msg): + super().__init__(504, msg) ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") diff --git a/synapse/http/client.py b/synapse/http/client.py index 13fcab3378..8324632cb6 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,10 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import urllib from io import BytesIO +from typing import ( + Any, + BinaryIO, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import treq from canonicaljson import encode_canonical_json @@ -26,7 +37,7 @@ from zope.interface import implementer, provider from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, ssl +from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, @@ -34,16 +45,18 @@ from twisted.internet.interfaces import ( from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, HTTPConnectionPool, readBody +from twisted.web.client import ( + Agent, + HTTPConnectionPool, + ResponseNeverReceived, + readBody, +) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError -from synapse.http import ( - QuieterFileBodyProducer, - cancelled_to_request_timed_out_error, - redact_uri, -) +from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags @@ -57,6 +70,19 @@ incoming_responses_counter = Counter( "synapse_http_client_responses", "", ["method", "code"] ) +# the type of the headers list, to be passed to the t.w.h.Headers. +# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so +# we simplify. +RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]] + +# the value actually has to be a List, but List is invariant so we can't specify that +# the entries can either be Lists or bytes. +RawHeaderValue = Sequence[Union[str, bytes]] + +# the type of the query params, to be passed into `urlencode` +QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] +QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] + def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): """ @@ -285,16 +311,27 @@ class SimpleHttpClient: ip_blacklist=self._ip_blacklist, ) - async def request(self, method, uri, data=None, headers=None): + async def request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: """ Args: - method (str): HTTP method to use. - uri (str): URI to query. - data (bytes): Data to send in the request body, if applicable. - headers (t.w.http_headers.Headers): Request headers. + method: HTTP method to use. + uri: URI to query. + data: Data to send in the request body, if applicable. + headers: Request headers. + + Returns: + Response object, once the headers have been read. + + Raises: + RequestTimedOutError if the request times out before the headers are read + """ - # A small wrapper around self.agent.request() so we can easily attach - # counters to it outgoing_requests_counter.labels(method).inc() # log request but strip `access_token` (AS requests for example include this) @@ -323,13 +360,17 @@ class SimpleHttpClient: data=body_producer, headers=headers, **self._extra_treq_args - ) + ) # type: defer.Deferred + + # we use our own timeout mechanism rather than treq's as a workaround + # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( - request_deferred, - 60, - self.hs.get_reactor(), - cancelled_to_request_timed_out_error, + request_deferred, 60, self.hs.get_reactor(), ) + + # turn timeouts into RequestTimedOutErrors + request_deferred.addErrback(_timeout_to_request_timed_out_error) + response = await make_deferred_yieldable(request_deferred) incoming_responses_counter.labels(method, response.code).inc() @@ -353,18 +394,26 @@ class SimpleHttpClient: set_tag("error_reason", e.args[0]) raise - async def post_urlencoded_get_json(self, uri, args={}, headers=None): + async def post_urlencoded_get_json( + self, + uri: str, + args: Mapping[str, Union[str, List[str]]] = {}, + headers: Optional[RawHeaders] = None, + ) -> Any: """ Args: - uri (str): - args (dict[str, str|List[str]]): query params - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: uri to query + args: parameters to be url-encoded in the body + headers: a map from header name to a list of values for that header Returns: - object: parsed json + parsed json Raises: + RequestTimedOutError: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException: On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -398,19 +447,24 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def post_json_get_json(self, uri, post_json, headers=None): + async def post_json_get_json( + self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None + ) -> Any: """ Args: - uri (str): - post_json (object): - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: URI to query. + post_json: request body, to be encoded as json + headers: a map from header name to a list of values for that header Returns: - object: parsed json + parsed json Raises: + RequestTimedOutError: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException: On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -440,21 +494,22 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def get_json(self, uri, args={}, headers=None): - """ Gets some json from the given URI. + async def get_json( + self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None, + ) -> Any: + """Gets some json from the given URI. Args: - uri (str): The URI to request, not including query parameters - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + args: A dictionary used to create query string + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the - HTTP body as JSON. + Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: + RequestTimedOutError: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -466,22 +521,27 @@ class SimpleHttpClient: body = await self.get_raw(uri, args, headers=headers) return json_decoder.decode(body.decode("utf-8")) - async def put_json(self, uri, json_body, args={}, headers=None): - """ Puts some json to the given URI. + async def put_json( + self, + uri: str, + json_body: Any, + args: QueryParams = {}, + headers: RawHeaders = None, + ) -> Any: + """Puts some json to the given URI. Args: - uri (str): The URI to request, not including query parameters - json_body (dict): The JSON to put in the HTTP body, - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + json_body: The JSON to put in the HTTP body, + args: A dictionary used to create query strings + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the - HTTP body as JSON. + Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: + RequestTimedOutError: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -513,21 +573,23 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def get_raw(self, uri, args={}, headers=None): - """ Gets raw text from the given URI. + async def get_raw( + self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None + ) -> bytes: + """Gets raw text from the given URI. Args: - uri (str): The URI to request, not including query parameters - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + args: A dictionary used to create query strings + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the + Succeeds when we get a 2xx HTTP response, with the HTTP body as bytes. Raises: + RequestTimedOutError: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException on a non-2xx HTTP response. """ if len(args): @@ -552,16 +614,29 @@ class SimpleHttpClient: # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. - async def get_file(self, url, output_stream, max_size=None, headers=None): + async def get_file( + self, + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: """GETs a file from a given URL Args: - url (str): The URL to GET - output_stream (file): File to write the response body to. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + url: The URL to GET + output_stream: File to write the response body to. + headers: A map from header name to a list of values for that header Returns: - A (int,dict,string,int) tuple of the file length, dict of the response + A tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. + + Raises: + RequestTimedOutError: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + + SynapseError: if the response is not a 2xx, the remote file is too large, or + another exception happens during the download. """ actual_headers = {b"User-Agent": [self.user_agent]} @@ -609,6 +684,18 @@ class SimpleHttpClient: ) +def _timeout_to_request_timed_out_error(f: Failure): + if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError): + # The TCP connection has its own timeout (set by the 'connectTimeout' param + # on the Agent), which raises twisted_error.TimeoutError exception. + raise RequestTimedOutError("Timeout connecting to remote server") + elif f.check(defer.TimeoutError, ResponseNeverReceived): + # this one means that we hit our overall timeout on the request + raise RequestTimedOutError("Timeout waiting for response from remote server") + + return f + + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index e6f067ca29..a306faa267 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -311,7 +311,7 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: return cache_controls -@attr.s() +@attr.s(slots=True) class _FetchWellKnownFailure(Exception): # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be # a temporary failure. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 5eaf3151ce..c23a4d7c0c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -76,7 +76,7 @@ MAXINT = sys.maxsize _next_id = 1 -@attr.s(frozen=True) +@attr.s(slots=True, frozen=True) class MatrixFederationRequest: method = attr.ib() """HTTP method @@ -171,7 +171,7 @@ async def _handle_json_response( d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = await make_deferred_yieldable(d) - except TimeoutError as e: + except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", request.txn_id, @@ -473,8 +473,6 @@ class MatrixFederationHttpClient: ) response = await request_deferred - except TimeoutError as e: - raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: @@ -657,10 +655,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. backoff_on_404 (bool): True if we should count a 404 response as @@ -706,8 +708,13 @@ class MatrixFederationHttpClient: timeout=timeout, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body @@ -736,10 +743,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -803,10 +814,14 @@ class MatrixFederationHttpClient: args (dict|None): A dictionary used to create query strings, defaults to None. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -842,8 +857,13 @@ class MatrixFederationHttpClient: timeout=timeout, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body @@ -867,10 +887,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -902,8 +926,13 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 332da02a8d..e32d3f43e0 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -44,8 +44,11 @@ class ProxyAgent(_AgentBase): `BrowserLikePolicyForHTTPS`, so unless you have special requirements you can leave this as-is. - connectTimeout (float): The amount of time that this Agent will wait - for the peer to accept a connection. + connectTimeout (Optional[float]): The amount of time that this Agent will wait + for the peer to accept a connection, in seconds. If 'None', + HostnameEndpoint's default (30s) will be used. + + This is used for connections to both proxies and destination servers. bindAddress (bytes): The local address for client sockets to bind to. @@ -108,6 +111,15 @@ class ProxyAgent(_AgentBase): Returns: Deferred[IResponse]: completes when the header of the response has been received (regardless of the response status code). + + Can fail with: + SchemeNotSupported: if the uri is not http or https + + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. + + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 22598e02d2..ca0c774cc5 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -65,6 +65,11 @@ except Exception: return None +# a hook which can be set during testing to assert that we aren't abusing logcontexts. +def logcontext_error(msg: str): + logger.warning(msg) + + # get an id for the current thread. # # threading.get_ident doesn't actually return an OS-level tid, and annoyingly, @@ -217,11 +222,9 @@ class _Sentinel: def record_event_fetch(self, event_count): pass - def __nonzero__(self): + def __bool__(self): return False - __bool__ = __nonzero__ # python3 - SENTINEL_CONTEXT = _Sentinel() @@ -332,10 +335,9 @@ class LoggingContext: """Enters this logging context into thread local storage""" old_context = set_current_context(self) if self.previous_context != old_context: - logger.warning( - "Expected previous context %r, found %r", - self.previous_context, - old_context, + logcontext_error( + "Expected previous context %r, found %r" + % (self.previous_context, old_context,) ) return self @@ -348,10 +350,10 @@ class LoggingContext: current = set_current_context(self.previous_context) if current is not self: if current is SENTINEL_CONTEXT: - logger.warning("Expected logging context %s was lost", self) + logcontext_error("Expected logging context %s was lost" % (self,)) else: - logger.warning( - "Expected logging context %s but found %s", self, current + logcontext_error( + "Expected logging context %s but found %s" % (self, current) ) # the fact that we are here suggests that the caller thinks that everything @@ -389,16 +391,16 @@ class LoggingContext: support getrusuage. """ if get_thread_id() != self.main_thread: - logger.warning("Started logcontext %s on different thread", self) + logcontext_error("Started logcontext %s on different thread" % (self,)) return if self.finished: - logger.warning("Re-starting finished log context %s", self) + logcontext_error("Re-starting finished log context %s" % (self,)) # If we haven't already started record the thread resource usage so # far if self.usage_start: - logger.warning("Re-starting already-active log context %s", self) + logcontext_error("Re-starting already-active log context %s" % (self,)) else: self.usage_start = rusage @@ -416,7 +418,7 @@ class LoggingContext: try: if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) + logcontext_error("Stopped logcontext %s on different thread" % (self,)) return if not rusage: @@ -424,9 +426,9 @@ class LoggingContext: # Record the cpu used since we started if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without recording a start rusage", - self, + logcontext_error( + "Called stop on logcontext %s without recording a start rusage" + % (self,) ) return @@ -586,14 +588,13 @@ class PreserveLoggingContext: if context != self._new_context: if not context: - logger.warning( - "Expected logging context %s was lost", self._new_context + logcontext_error( + "Expected logging context %s was lost" % (self._new_context,) ) else: - logger.warning( - "Expected logging context %s but found %s", - self._new_context, - context, + logcontext_error( + "Expected logging context %s but found %s" + % (self._new_context, context,) ) diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index d736ad5b9b..11f60a77f7 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -30,7 +30,7 @@ class LogFormatter(logging.Formatter): """ def __init__(self, *args, **kwargs): - super(LogFormatter, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def formatException(self, ei): sio = StringIO() diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 7df0aa197d..e58850faff 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -509,7 +509,7 @@ def start_active_span_from_edu( ] # For some reason jaeger decided not to support the visualization of multiple parent - # spans or explicitely show references. I include the span context as a tag here as + # spans or explicitly show references. I include the span context as a tag here as # an aid to people debugging but it's really not an ideal solution. references += _references diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index 026854b4c7..7b9c657456 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -107,7 +107,7 @@ class _LogContextScope(Scope): finish_on_close (Boolean): if True finish the span when the scope is closed """ - super(_LogContextScope, self).__init__(manager, span) + super().__init__(manager, span) self.logcontext = logcontext self._finish_on_close = finish_on_close self._enter_logcontext = enter_logcontext @@ -120,9 +120,9 @@ class _LogContextScope(Scope): def __exit__(self, type, value, traceback): if type == twisted.internet.defer._DefGen_Return: - super(_LogContextScope, self).__exit__(None, None, None) + super().__exit__(None, None, None) else: - super(_LogContextScope, self).__exit__(type, value, traceback) + super().__exit__(type, value, traceback) if self._enter_logcontext: self.logcontext.__exit__(type, value, traceback) else: # the logcontext existed before the creation of the scope diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index fea774e2e5..becf66dd86 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -29,11 +29,11 @@ def _log_debug_as_f(f, msg, msg_args): lineno = f.__code__.co_firstlineno pathname = f.__code__.co_filename - record = logging.LogRecord( + record = logger.makeRecord( name=name, level=logging.DEBUG, - pathname=pathname, - lineno=lineno, + fn=pathname, + lno=lineno, msg=msg, args=msg_args, exc_info=None, diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 2643380d9e..b8d2a8e8a9 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -15,6 +15,7 @@ import functools import gc +import itertools import logging import os import platform @@ -27,8 +28,8 @@ from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import ( REGISTRY, CounterMetricFamily, + GaugeHistogramMetricFamily, GaugeMetricFamily, - HistogramMetricFamily, ) from twisted.internet import reactor @@ -46,7 +47,7 @@ logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" running_on_pypy = platform.python_implementation() == "PyPy" -all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]] +all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") @@ -59,7 +60,7 @@ class RegistryProxy: yield metric -@attr.s(hash=True) +@attr.s(slots=True, hash=True) class LaterGauge: name = attr.ib(type=str) @@ -205,63 +206,83 @@ class InFlightGauge: all_gauges[self.name] = self -@attr.s(hash=True) -class BucketCollector: - """ - Like a Histogram, but allows buckets to be point-in-time instead of - incrementally added to. +class GaugeBucketCollector: + """Like a Histogram, but the buckets are Gauges which are updated atomically. - Args: - name (str): Base name of metric to be exported to Prometheus. - data_collector (callable -> dict): A synchronous callable that - returns a dict mapping bucket to number of items in the - bucket. If these buckets are not the same as the buckets - given to this class, they will be remapped into them. - buckets (list[float]): List of floats/ints of the buckets to - give to Prometheus. +Inf is ignored, if given. + The data is updated by calling `update_data` with an iterable of measurements. + We assume that the data is updated less frequently than it is reported to + Prometheus, and optimise for that case. """ - name = attr.ib() - data_collector = attr.ib() - buckets = attr.ib() + __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric") - def collect(self): + def __init__( + self, + name: str, + documentation: str, + buckets: Iterable[float], + registry=REGISTRY, + ): + """ + Args: + name: base name of metric to be exported to Prometheus. (a _bucket suffix + will be added.) + documentation: help text for the metric + buckets: The top bounds of the buckets to report + registry: metric registry to register with + """ + self._name = name + self._documentation = documentation - # Fetch the data -- this must be synchronous! - data = self.data_collector() + # the tops of the buckets + self._bucket_bounds = [float(b) for b in buckets] + if self._bucket_bounds != sorted(self._bucket_bounds): + raise ValueError("Buckets not in sorted order") - buckets = {} # type: Dict[float, int] + if self._bucket_bounds[-1] != float("inf"): + self._bucket_bounds.append(float("inf")) - res = [] - for x in data.keys(): - for i, bound in enumerate(self.buckets): - if x <= bound: - buckets[bound] = buckets.get(bound, 0) + data[x] + self._metric = self._values_to_metric([]) + registry.register(self) - for i in self.buckets: - res.append([str(i), buckets.get(i, 0)]) + def collect(self): + yield self._metric - res.append(["+Inf", sum(data.values())]) + def update_data(self, values: Iterable[float]): + """Update the data to be reported by the metric - metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) + The existing data is cleared, and each measurement in the input is assigned + to the relevant bucket. + """ + self._metric = self._values_to_metric(values) + + def _values_to_metric(self, values: Iterable[float]) -> GaugeHistogramMetricFamily: + total = 0.0 + bucket_values = [0 for _ in self._bucket_bounds] + + for v in values: + # assign each value to a bucket + for i, bound in enumerate(self._bucket_bounds): + if v <= bound: + bucket_values[i] += 1 + break + + # ... and increment the sum + total += v + + # now, aggregate the bucket values so that they count the number of entries in + # that bucket or below. + accumulated_values = itertools.accumulate(bucket_values) + + return GaugeHistogramMetricFamily( + self._name, + self._documentation, + buckets=list( + zip((str(b) for b in self._bucket_bounds), accumulated_values) + ), + gsum_value=total, ) - yield metric - - def __attrs_post_init__(self): - self.buckets = [float(x) for x in self.buckets if x != "+Inf"] - if self.buckets != sorted(self.buckets): - raise ValueError("Buckets not sorted") - - self.buckets = tuple(self.buckets) - - if self.name in all_gauges.keys(): - logger.warning("%s already registered, reregistering" % (self.name,)) - REGISTRY.unregister(all_gauges.pop(self.name)) - - REGISTRY.register(self) - all_gauges[self.name] = self # diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 4304c60d56..734271e765 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -24,9 +24,9 @@ expect, and the newer "best practice" version of the up-to-date official client. import math import threading -from collections import namedtuple from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn +from typing import Dict, List from urllib.parse import parse_qs, urlparse from prometheus_client import REGISTRY @@ -35,14 +35,6 @@ from twisted.web.resource import Resource from synapse.util import caches -try: - from prometheus_client.samples import Sample -except ImportError: - Sample = namedtuple( # type: ignore[no-redef] # noqa - "Sample", ["name", "labels", "value", "timestamp", "exemplar"] - ) - - CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") @@ -93,17 +85,6 @@ def sample_line(line, name): ) -def nameify_sample(sample): - """ - If we get a prometheus_client<0.4.0 sample as a tuple, transform it into a - namedtuple which has the names we expect. - """ - if not isinstance(sample, Sample): - sample = Sample(*sample, None, None) - - return sample - - def generate_latest(registry, emit_help=False): # Trigger the cache metrics to be rescraped, which updates the common @@ -144,16 +125,33 @@ def generate_latest(registry, emit_help=False): ) ) output.append("# TYPE {0} {1}\n".format(mname, mtype)) - for sample in map(nameify_sample, metric.samples): - # Get rid of the OpenMetrics specific samples + + om_samples = {} # type: Dict[str, List[str]] + for s in metric.samples: for suffix in ["_created", "_gsum", "_gcount"]: - if sample.name.endswith(suffix): + if s.name == metric.name + suffix: + # OpenMetrics specific sample, put in a gauge at the end. + # (these come from gaugehistograms which don't get renamed, + # so no need to faff with mnewname) + om_samples.setdefault(suffix, []).append(sample_line(s, s.name)) break else: - newname = sample.name.replace(mnewname, mname) + newname = s.name.replace(mnewname, mname) if ":" in newname and newname.endswith("_total"): newname = newname[: -len("_total")] - output.append(sample_line(sample, newname)) + output.append(sample_line(s, newname)) + + for suffix, lines in sorted(om_samples.items()): + if emit_help: + output.append( + "# HELP {0}{1} {2}\n".format( + metric.name, + suffix, + metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), + ) + ) + output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix)) + output.extend(lines) # Get rid of the weird colon things while we're at it if mtype == "counter": @@ -172,16 +170,16 @@ def generate_latest(registry, emit_help=False): ) ) output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) - for sample in map(nameify_sample, metric.samples): - # Get rid of the OpenMetrics specific samples + + for s in metric.samples: + # Get rid of the OpenMetrics specific samples (we should already have + # dealt with them above anyway.) for suffix in ["_created", "_gsum", "_gcount"]: - if sample.name.endswith(suffix): + if s.name == metric.name + suffix: break else: output.append( - sample_line( - sample, sample.name.replace(":total", "").replace(":", "_") - ) + sample_line(s, s.name.replace(":total", "").replace(":", "_")) ) return "".join(output).encode("utf-8") diff --git a/synapse/notifier.py b/synapse/notifier.py index b7f4041306..59415f6f88 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -42,7 +42,13 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig -from synapse.types import Collection, StreamToken, UserID +from synapse.types import ( + Collection, + PersistedEventPosition, + RoomStreamToken, + StreamToken, + UserID, +) from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -112,7 +118,9 @@ class _NotifierUserStream: with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - def notify(self, stream_key: str, stream_id: int, time_now_ms: int): + def notify( + self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int, + ): """Notify any listeners for this user of a new event from an event source. Args: @@ -155,18 +163,16 @@ class _NotifierUserStream: """ # Immediately wake up stream if something has already since happened # since their last token. - if self.last_notified_token.is_after(token): + if self.last_notified_token != token: return _NotificationListener(defer.succeed(self.current_token)) else: return _NotificationListener(self.notify_deferred.observe()) class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): - def __nonzero__(self): + def __bool__(self): return bool(self.events) - __bool__ = __nonzero__ # python3 - class Notifier: """ This class is responsible for notifying any listeners when there are @@ -187,7 +193,7 @@ class Notifier: self.store = hs.get_datastore() self.pending_new_room_events = ( [] - ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]] + ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -198,6 +204,7 @@ class Notifier: self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() + self._pusher_pool = hs.get_pusherpool() self.federation_sender = None if hs.should_send_federation(): @@ -245,9 +252,9 @@ class Notifier: def on_new_room_event( self, event: EventBase, - room_stream_id: int, - max_room_stream_id: int, - extra_users: Collection[Union[str, UserID]] = [], + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, + extra_users: Collection[UserID] = [], ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -260,61 +267,79 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append((room_stream_id, event, extra_users)) - self._notify_pending_new_room_events(max_room_stream_id) + self.pending_new_room_events.append((event_pos, event, extra_users)) + self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id: int): + def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id: The highest stream_id below which all + max_room_stream_token: The highest stream_id below which all events have been persisted. """ pending = self.pending_new_room_events self.pending_new_room_events = [] - for room_stream_id, event, extra_users in pending: - if room_stream_id > max_room_stream_id: - self.pending_new_room_events.append( - (room_stream_id, event, extra_users) - ) + + users = set() # type: Set[UserID] + rooms = set() # type: Set[str] + + for event_pos, event, extra_users in pending: + if event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append((event_pos, event, extra_users)) else: - self._on_new_room_event(event, room_stream_id, extra_users) + if ( + event.type == EventTypes.Member + and event.membership == Membership.JOIN + ): + self._user_joined_room(event.state_key, event.room_id) + + users.update(extra_users) + rooms.add(event.room_id) + + if users or rooms: + self.on_new_event( + "room_key", max_room_stream_token, users=users, rooms=rooms, + ) + self._on_updated_room_token(max_room_stream_token) + + def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken): + """Poke services that might care that the room position has been + updated. + """ - def _on_new_room_event( - self, - event: EventBase, - room_stream_id: int, - extra_users: Collection[Union[str, UserID]] = [], - ): - """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( - "notify_app_services", self._notify_app_services, room_stream_id + "_notify_app_services", self._notify_app_services, max_room_stream_token ) - if self.federation_sender: - self.federation_sender.notify_new_events(room_stream_id) - - if event.type == EventTypes.Member and event.membership == Membership.JOIN: - self._user_joined_room(event.state_key, event.room_id) - - self.on_new_event( - "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] + run_as_background_process( + "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token ) - async def _notify_app_services(self, room_stream_id: int): + if self.federation_sender: + self.federation_sender.notify_new_events(max_room_stream_token.stream) + + async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: - await self.appservice_handler.notify_interested_services(room_stream_id) + await self.appservice_handler.notify_interested_services( + max_room_stream_token.stream + ) except Exception: logger.exception("Error notifying application services of event") + async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): + try: + await self._pusher_pool.on_new_notifications(max_room_stream_token.stream) + except Exception: + logger.exception("Error pusher pool of event") + def on_new_event( self, stream_key: str, - new_token: int, - users: Collection[Union[str, UserID]] = [], + new_token: Union[int, RoomStreamToken], + users: Collection[UserID] = [], rooms: Collection[str] = [], ): """ Used to inform listeners that something has happened event wise. @@ -432,8 +457,9 @@ class Notifier: If explicit_room_id is set, that room will be polled for events only if it is world readable or the user has joined the room. """ - from_token = pagination_config.from_token - if not from_token: + if pagination_config.from_token: + from_token = pagination_config.from_token + else: from_token = self.event_sources.get_current_token() limit = pagination_config.limit @@ -444,7 +470,7 @@ class Notifier: async def check_for_updates( before_token: StreamToken, after_token: StreamToken ) -> EventStreamResult: - if not after_token.is_after(before_token): + if after_token == before_token: return EventStreamResult([], (from_token, from_token)) events = [] # type: List[EventBase] diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index edf45dc599..5a437f9810 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -16,4 +16,4 @@ class PusherConfigException(Exception): def __init__(self, msg): - super(PusherConfigException, self).__init__(msg) + super().__init__(msg) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index b7ea4438e0..28bd8ab748 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -91,7 +91,7 @@ class EmailPusher: pass self.timed_call = None - def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + def on_new_notifications(self, max_stream_ordering): if self.max_stream_ordering: self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index f21fa9b659..26706bf3e1 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -114,7 +114,7 @@ class HttpPusher: if should_check_for_notifs: self._start_processing() - def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + def on_new_notifications(self, max_stream_ordering): self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering or 0 ) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 6c57854018..455a1acb46 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -123,7 +123,7 @@ class Mailer: params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( self.hs.config.public_baseurl - + "_matrix/client/unstable/password_reset/email/submit_token?%s" + + "_synapse/client/password_reset/email/submit_token?%s" % urllib.parse.urlencode(params) ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 3c3262a88c..76150e117b 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -60,10 +60,18 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + self._account_validity = hs.config.account_validity + # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() + # Record the last stream ID that we were poked about so we can get + # changes since then. We set this to the current max stream ID on + # startup as every individual pusher will have checked for changes on + # startup. + self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] @@ -178,20 +186,35 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, max_stream_id: int): if not self.pushers: # nothing to do here. return + if max_stream_id < self._last_room_stream_id_seen: + # Nothing to do + return + + prev_stream_id = self._last_room_stream_id_seen + self._last_room_stream_id_seen = max_stream_id + try: users_affected = await self.store.get_push_action_users_in_range( - min_stream_id, max_stream_id + prev_stream_id, max_stream_id ) for u in users_affected: + # Don't push if the user account has expired + if self._account_validity.enabled: + expired = await self.store.is_account_expired( + u, self.clock.time_msec() + ) + if expired: + continue + if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_notifications(min_stream_id, max_stream_id) + p.on_new_notifications(max_stream_id) except Exception: logger.exception("Exception in pusher on_new_notifications") @@ -209,6 +232,14 @@ class PusherPool: ) for u in users_affected: + # Don't push if the user account has expired + if self._account_validity.enabled: + expired = await self.store.is_account_expired( + u, self.clock.time_msec() + ) + if expired: + continue + if u in self.pushers: for p in self.pushers[u].values(): p.on_new_receipts(min_stream_id, max_stream_id) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2d995ec456..0ddead8a0f 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -37,13 +37,16 @@ logger = logging.getLogger(__name__) # installed when that optional dependency requirement is specified. It is passed # to setup() as extras_require in setup.py # +# Note that these both represent runtime dependencies (and the versions +# installed are checked at runtime). +# # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. REQUIREMENTS = [ "jsonschema>=2.5.1", "frozendict>=1", "unpaddedbase64>=1.1.0", - "canonicaljson>=1.3.0", + "canonicaljson>=1.4.0", # we use the type definitions added in signedjson 1.1. "signedjson>=1.1.0", "pynacl>=1.2.1", @@ -65,7 +68,11 @@ REQUIREMENTS = [ "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", - "prometheus_client>=0.0.18,<0.9.0", + # we use GaugeHistogramMetric, which was added in prom-client 0.4.0. + # prom-client has a history of breaking backwards compatibility between + # minor versions (https://github.com/prometheus/client_python/issues/317), + # so we also pin the minor version. + "prometheus_client>=0.4.0,<0.9.0", # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note: # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33 # is out in November.) @@ -92,12 +99,6 @@ CONDITIONAL_REQUIREMENTS = { "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], - # Dependencies which are exclusively required by unit test code. This is - # NOT a list of all modules that are necessary to run the unit tests. - # Tests assume that all optional dependencies are installed. - # - # parameterized_class decorator was introduced in parameterized 0.7.0 - "test": ["mock>=2.0", "parameterized>=0.7.0"], "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], @@ -110,6 +111,7 @@ ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): # Exclude systemd as it's a system-based requirement. + # Exclude lint as it's a dev-based requirement. if name not in ["systemd"]: ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index ba16f22c91..64edadb624 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -20,20 +20,30 @@ import urllib from inspect import signature from typing import Dict, List, Tuple -from synapse.api.errors import ( - CodeMessageException, - HttpResponseException, - RequestSendFailed, - SynapseError, -) +from prometheus_client import Counter, Gauge + +from synapse.api.errors import HttpResponseException, SynapseError +from synapse.http import RequestTimedOutError from synapse.logging.opentracing import inject_active_span_byte_dict, trace from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) +_pending_outgoing_requests = Gauge( + "synapse_pending_outgoing_replication_requests", + "Number of active outgoing replication requests, by replication method name", + ["name"], +) + +_outgoing_request_counter = Counter( + "synapse_outgoing_replication_requests", + "Number of outgoing replication requests, by replication method name and result", + ["name", "code"], +) + -class ReplicationEndpoint: +class ReplicationEndpoint(metaclass=abc.ABCMeta): """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` @@ -72,8 +82,6 @@ class ReplicationEndpoint: is received. """ - __metaclass__ = abc.ABCMeta - NAME = abc.abstractproperty() # type: str # type: ignore PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore METHOD = "POST" @@ -140,7 +148,10 @@ class ReplicationEndpoint: instance_map = hs.config.worker.instance_map + outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) + @trace(opname="outgoing_replication_request") + @outgoing_gauge.track_inprogress() async def send_request(instance_name="master", **kwargs): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") @@ -195,23 +206,26 @@ class ReplicationEndpoint: try: result = await request_func(uri, data, headers=headers) break - except CodeMessageException as e: - if e.code != 504 or not cls.RETRY_ON_TIMEOUT: + except RequestTimedOutError: + if not cls.RETRY_ON_TIMEOUT: raise - logger.warning("%s request timed out", cls.NAME) + logger.warning("%s request timed out; retrying", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. await clock.sleep(1) except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError - # on the master process that we should send to the client. (And + # on the main process that we should send to the client. (And # importantly, not stack traces everywhere) + _outgoing_request_counter.labels(cls.NAME, e.code).inc() raise e.to_synapse_error() - except RequestSendFailed as e: - raise SynapseError(502, "Failed to talk to master") from e + except Exception as e: + _outgoing_request_counter.labels(cls.NAME, "ERR").inc() + raise SynapseError(502, "Failed to talk to main process") from e + _outgoing_request_counter.labels(cls.NAME, 200).inc() return result return send_request diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 20f3ba76c0..807b85d2e1 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -53,7 +53,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): CACHE = False def __init__(self, hs): - super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs) + super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater self.store = hs.get_datastore() diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 6b56315148..5393b9a9e7 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -57,7 +57,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): PATH_ARGS = () def __init__(self, hs): - super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.storage = hs.get_storage() @@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} @@ -144,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): PATH_ARGS = ("edu_type",) def __init__(self, hs): - super(ReplicationFederationSendEduRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -187,7 +193,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): CACHE = False def __init__(self, hs): - super(ReplicationGetQueryRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -230,7 +236,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id",) def __init__(self, hs): - super(ReplicationCleanRoomRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index fb326bb869..4c81e2d784 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -32,7 +32,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(RegisterDeviceReplicationServlet, self).__init__(hs) + super().__init__(hs) self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 741329ab5f..30680baee8 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID -from synapse.util.distributor import user_joined_room, user_left_room +from synapse.util.distributor import user_left_room if TYPE_CHECKING: from synapse.server import HomeServer @@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): - super(ReplicationRemoteJoinRestServlet, self).__init__(hs) + super().__init__(hs) self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() @@ -107,7 +107,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): PATH_ARGS = ("invite_event_id",) def __init__(self, hs: "HomeServer"): - super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -168,7 +168,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): CACHE = False # No point caching as should return instantly. def __init__(self, hs): - super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs) + super().__init__(hs) self.registeration_handler = hs.get_registration_handler() self.store = hs.get_datastore() @@ -181,9 +181,9 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): Args: room_id (str) user_id (str) - change (str): Either "joined" or "left" + change (str): "left" """ - assert change in ("joined", "left") + assert change == "left" return {} @@ -192,9 +192,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): user = UserID.from_string(user_id) - if change == "joined": - user_joined_room(self.distributor, user, room_id) - elif change == "left": + if change == "left": user_left_room(self.distributor, user, room_id) else: raise Exception("Unrecognized change: %r", change) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index a02b27474d..7b12ec9060 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -29,7 +29,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(ReplicationRegisterServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -104,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(ReplicationPostRegisterActionsServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index f13d452426..9a3a694d5d 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -52,7 +52,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): PATH_ARGS = ("event_id",) def __init__(self, hs): - super(ReplicationSendEventRestServlet, self).__init__(hs) + super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 60f2e1245f..d0089fe06c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -26,16 +26,18 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(BaseSlavedStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, + stream_name="caches", instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bb66ba9b80..4268565fc8 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -34,7 +34,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved ], ) - super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index a6fdedde63..1f8dafe7ea 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -22,7 +22,7 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedClientIpStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 533d927701..5b045bed02 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_inbox", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 3b788c9625..e0d86240dd 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedDeviceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index da1cc836cf..fbffe6d85c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -56,7 +56,7 @@ class SlavedEventStore( BaseSlavedStore, ): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedEventStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 2562b6fc38..6a23252861 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -21,7 +21,7 @@ from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedFilteringStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 567b4a5cc1..30955bcbfe 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 025f6f6be8..55620c03d8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -25,7 +25,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedPresenceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 9da218bfe8..c418730ba8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedPusherStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 5c2986e050..6195917376 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -30,7 +30,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 80ae803ad9..109ac6bea1 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -23,7 +23,7 @@ from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d6ecf5b327..e165429cad 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,6 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) +from synapse.types import PersistedEventPosition, UserID from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -98,7 +99,6 @@ class ReplicationDataHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.pusher_pool = hs.get_pusherpool() self.notifier = hs.get_notifier() self._reactor = hs.get_reactor() self._clock = hs.get_clock() @@ -148,13 +148,15 @@ class ReplicationDataHandler: if event.rejected_reason: continue - extra_users = () # type: Tuple[str, ...] + extra_users = () # type: Tuple[UserID, ...] if event.type == EventTypes.Member: - extra_users = (event.state_key,) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event(event, token, max_token, extra_users) + extra_users = (UserID.from_string(event.state_key),) - await self.pusher_pool.on_new_notifications(token, token) + max_token = self.store.get_room_max_token() + event_pos = PersistedEventPosition(instance_name, token) + self.notifier.on_new_room_event( + event, event_pos, max_token, extra_users + ) # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1c303f3a46..b323841f73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 04d894fb3d..687984e7a8 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -93,7 +93,7 @@ class ReplicationStreamer: """ if not self.command_handler.connected(): # Don't bother if nothing is listening. We still need to advance - # the stream tokens otherwise they'll fall beihind forever + # the stream tokens otherwise they'll fall behind forever for stream in self.streams: stream.discard_updates_and_advance() return diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 682d47f402..54dccd15a6 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -345,7 +345,7 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() - super(PushRulesStream, self).__init__( + super().__init__( hs.get_instance_name(), self._current_token, self.store.get_all_push_rule_updates, @@ -383,7 +383,7 @@ class CachesStream(Stream): the cache on the workers """ - @attr.s + @attr.s(slots=True) class CachesStreamRow: """Stream to inform workers they should invalidate their cache. @@ -441,7 +441,7 @@ class DeviceListsStream(Stream): told about a device update. """ - @attr.s + @attr.s(slots=True) class DeviceListsStreamRow: entity = attr.ib(type=str) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f929fc3954..ccc7ca30d8 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html new file mode 100644 index 0000000000..baf4633142 --- /dev/null +++ b/synapse/res/templates/auth_success.html @@ -0,0 +1,21 @@ +<html> +<head> +<title>Success!</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +<script> +if (window.onAuthDone) { + window.onAuthDone(); +} else if (window.opener && window.opener.postMessage) { + window.opener.postMessage("authDone", "*"); +} +</script> +</head> +<body> + <div> + <p>Thank you</p> + <p>You may now close this window and return to the application</p> + </div> +</body> +</html> diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html new file mode 100644 index 0000000000..def4b5162b --- /dev/null +++ b/synapse/res/templates/password_reset_confirmation.html @@ -0,0 +1,16 @@ +<html> +<head></head> +<body> +<!--Use a hidden form to resubmit the information necessary to reset the password--> +<form method="post"> + <input type="hidden" name="sid" value="{{ sid }}"> + <input type="hidden" name="token" value="{{ token }}"> + <input type="hidden" name="client_secret" value="{{ client_secret }}"> + + <p>You have requested to <strong>reset your Matrix account password</strong>. Click the link below to confirm this action. <br /><br /> + If you did not mean to do this, please close this page and your password will not be changed.</p> + <p><button type="submit">Confirm changing my password</button></p> +</form> +</body> +</html> + diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html new file mode 100644 index 0000000000..63944dc608 --- /dev/null +++ b/synapse/res/templates/recaptcha.html @@ -0,0 +1,38 @@ +<html> +<head> +<title>Authentication</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<script src="https://www.recaptcha.net/recaptcha/api.js" + async defer></script> +<script src="//code.jquery.com/jquery-1.11.2.min.js"></script> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +<script> +function captchaDone() { + $('#registrationForm').submit(); +} +</script> +</head> +<body> +<form id="registrationForm" method="post" action="{{ myurl }}"> + <div> + <p> + Hello! We need to prevent computer programs and other automated + things from creating accounts on this server. + </p> + <p> + Please verify that you're not a robot. + </p> + <input type="hidden" name="session" value="{{ session }}" /> + <div class="g-recaptcha" + data-sitekey="{{ sitekey }}" + data-callback="captchaDone"> + </div> + <noscript> + <input type="submit" value="All Done" /> + </noscript> + </div> + </div> +</form> +</body> +</html> diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html deleted file mode 100644 index 01cd9bdaf3..0000000000 --- a/synapse/res/templates/saml_error.html +++ /dev/null @@ -1,52 +0,0 @@ -<!DOCTYPE html> -<html lang="en"> -<head> - <meta charset="UTF-8"> - <title>SSO login error</title> -</head> -<body> -{# a 403 means we have actively rejected their login #} -{% if code == 403 %} - <p>You are not allowed to log in here.</p> -{% else %} - <p> - There was an error during authentication: - </p> - <div id="errormsg" style="margin:20px 80px">{{ msg }}</div> - <p> - If you are seeing this page after clicking a link sent to you via email, make - sure you only click the confirmation link once, and that you open the - validation link in the same client you're logging in from. - </p> - <p> - Try logging in again from your Matrix client and if the problem persists - please contact the server's administrator. - </p> - - <script type="text/javascript"> - // Error handling to support Auth0 errors that we might get through a GET request - // to the validation endpoint. If an error is provided, it's either going to be - // located in the query string or in a query string-like URI fragment. - // We try to locate the error from any of these two locations, but if we can't - // we just don't print anything specific. - let searchStr = ""; - if (window.location.search) { - // window.location.searchParams isn't always defined when - // window.location.search is, so it's more reliable to parse the latter. - searchStr = window.location.search; - } else if (window.location.hash) { - // Replace the # with a ? so that URLSearchParams does the right thing and - // doesn't parse the first parameter incorrectly. - searchStr = window.location.hash.replace("#", "?"); - } - - // We might end up with no error in the URL, so we need to check if we have one - // to print one. - let errorDesc = new URLSearchParams(searchStr).get("error_description") - if (errorDesc) { - document.getElementById("errormsg").innerText = errorDesc; - } - </script> -{% endif %} -</body> -</html> diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index 43a211386b..944bc9c9ca 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -5,14 +5,49 @@ <title>SSO error</title> </head> <body> - <p>Oops! Something went wrong during authentication.</p> +{# If an error of unauthorised is returned it means we have actively rejected their login #} +{% if error == "unauthorised" %} + <p>You are not allowed to log in here.</p> +{% else %} + <p> + There was an error during authentication: + </p> + <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div> + <p> + If you are seeing this page after clicking a link sent to you via email, make + sure you only click the confirmation link once, and that you open the + validation link in the same client you're logging in from. + </p> <p> Try logging in again from your Matrix client and if the problem persists please contact the server's administrator. </p> <p>Error: <code>{{ error }}</code></p> - {% if error_description %} - <pre><code>{{ error_description }}</code></pre> - {% endif %} + + <script type="text/javascript"> + // Error handling to support Auth0 errors that we might get through a GET request + // to the validation endpoint. If an error is provided, it's either going to be + // located in the query string or in a query string-like URI fragment. + // We try to locate the error from any of these two locations, but if we can't + // we just don't print anything specific. + let searchStr = ""; + if (window.location.search) { + // window.location.searchParams isn't always defined when + // window.location.search is, so it's more reliable to parse the latter. + searchStr = window.location.search; + } else if (window.location.hash) { + // Replace the # with a ? so that URLSearchParams does the right thing and + // doesn't parse the first parameter incorrectly. + searchStr = window.location.hash.replace("#", "?"); + } + + // We might end up with no error in the URL, so we need to check if we have one + // to print one. + let errorDesc = new URLSearchParams(searchStr).get("error_description") + if (errorDesc) { + document.getElementById("errormsg").innerText = errorDesc; + } + </script> +{% endif %} </body> </html> diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html new file mode 100644 index 0000000000..dfef9897ee --- /dev/null +++ b/synapse/res/templates/terms.html @@ -0,0 +1,20 @@ +<html> +<head> +<title>Authentication</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +</head> +<body> +<form id="registrationForm" method="post" action="{{ myurl }}"> + <div> + <p> + Please click the button below if you agree to the + <a href="{{ terms_url }}">privacy policy of this homeserver.</a> + </p> + <input type="hidden" name="session" value="{{ session }}" /> + <input type="submit" value="Agree" /> + </div> +</form> +</body> +</html> diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 87f927890c..40f5c32db2 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -13,8 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.rest.admin from synapse.http.server import JsonResource +from synapse.rest import admin from synapse.rest.client import versions from synapse.rest.client.v1 import ( directory, @@ -123,9 +123,7 @@ class ClientRestResource(JsonResource): password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin - synapse.rest.admin.register_servlets_for_client_rest_resource( - hs, client_resource - ) + admin.register_servlets_for_client_rest_resource(hs, client_resource) # unstable shared_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 1c88c93f38..57cac22252 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -16,13 +16,13 @@ import logging import platform -import re import synapse from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.admin._base import ( + admin_patterns, assert_requester_is_admin, historical_admin_path_patterns, ) @@ -31,6 +31,7 @@ from synapse.rest.admin.devices import ( DeviceRestServlet, DevicesRestServlet, ) +from synapse.rest.admin.event_reports import EventReportsRestServlet from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -49,6 +50,7 @@ from synapse.rest.admin.users import ( ResetPasswordRestServlet, SearchUsersRestServlet, UserAdminServlet, + UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, UsersRestServlet, @@ -61,7 +63,7 @@ logger = logging.getLogger(__name__) class VersionServlet(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),) + PATTERNS = admin_patterns("/server_version$") def __init__(self, hs): self.res = { @@ -107,7 +109,8 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - token = await self.store.get_topological_token_for_event(event_id) + room_token = await self.store.get_topological_token_for_event(event_id) + token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: @@ -209,11 +212,13 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UserMembershipRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server) + EventReportsRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index d82eaf5e38..db9fea263a 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -44,7 +44,7 @@ def historical_admin_path_patterns(path_regex): ] -def admin_patterns(path_regex: str): +def admin_patterns(path_regex: str, version: str = "v1"): """Returns the list of patterns for an admin endpoint Args: @@ -54,7 +54,7 @@ def admin_patterns(path_regex: str): Returns: A list of regex patterns. """ - admin_prefix = "^/_synapse/admin/v1" + admin_prefix = "^/_synapse/admin/" + version patterns = [re.compile(admin_prefix + path_regex)] return patterns diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 8d32677339..a163863322 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import re from synapse.api.errors import NotFoundError, SynapseError from synapse.http.servlet import ( @@ -21,7 +20,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_json_object_from_request, ) -from synapse.rest.admin._base import assert_requester_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import UserID logger = logging.getLogger(__name__) @@ -32,14 +31,12 @@ class DeviceRestServlet(RestServlet): Get, update or delete the given user's device """ - PATTERNS = ( - re.compile( - "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$" - ), + PATTERNS = admin_patterns( + "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2" ) def __init__(self, hs): - super(DeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -98,7 +95,7 @@ class DevicesRestServlet(RestServlet): Retrieve the given user's devices """ - PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") def __init__(self, hs): """ @@ -131,9 +128,7 @@ class DeleteDevicesRestServlet(RestServlet): key which lists the device_ids to delete. """ - PATTERNS = ( - re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"), - ) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") def __init__(self, hs): self.hs = hs diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py new file mode 100644 index 0000000000..5b8d0594cd --- /dev/null +++ b/synapse/rest/admin/event_reports.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin + +logger = logging.getLogger(__name__) + + +class EventReportsRestServlet(RestServlet): + """ + List all reported events that are known to the homeserver. Results are returned + in a dictionary containing report information. Supports pagination. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/event_reports + returns: + 200 OK with list of reports if success otherwise an error. + + Args: + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `dir` can be used to define the order of results. + The parameter `user_id` can be used to filter by user id. + The parameter `room_id` can be used to filter by room id. + Returns: + A list of reported events and an integer representing the total number of + reported events that exist given this query + """ + + PATTERNS = admin_patterns("/event_reports$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request): + await assert_requester_is_admin(self.auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + direction = parse_string(request, "dir", default="b") + user_id = parse_string(request, "user_id") + room_id = parse_string(request, "room_id") + + if start < 0: + raise SynapseError( + 400, + "The start parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "The limit parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + event_reports, total = await self.store.get_event_reports_paginate( + start, limit, direction, user_id, room_id + ) + ret = {"event_reports": event_reports, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(event_reports) + + return 200, ret diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py index f474066542..8b7bb6d44e 100644 --- a/synapse/rest/admin/purge_room_servlet.py +++ b/synapse/rest/admin/purge_room_servlet.py @@ -12,14 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re - from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) from synapse.rest.admin import assert_requester_is_admin +from synapse.rest.admin._base import admin_patterns class PurgeRoomServlet(RestServlet): @@ -35,7 +34,7 @@ class PurgeRoomServlet(RestServlet): {} """ - PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),) + PATTERNS = admin_patterns("/purge_room$") def __init__(self, hs): """ diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 6e9a874121..375d055445 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re - from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -22,6 +20,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.rest.admin import assert_requester_is_admin +from synapse.rest.admin._base import admin_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import UserID @@ -56,13 +55,13 @@ class SendServerNoticeServlet(RestServlet): self.snm = hs.get_server_notices_manager() def register(self, json_resource): - PATTERN = "^/_synapse/admin/v1/send_server_notice" + PATTERN = "/send_server_notice" json_resource.register_paths( - "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__ + "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ ) json_resource.register_paths( "PUT", - (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),), + admin_patterns(PATTERN + "/(?P<txn_id>[^/]*)$"), self.on_PUT, self.__class__.__name__, ) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f3e77da850..20dc1d0e05 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -15,7 +15,6 @@ import hashlib import hmac import logging -import re from http import HTTPStatus from synapse.api.constants import UserTypes @@ -29,6 +28,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.admin._base import ( + admin_patterns, assert_requester_is_admin, assert_user_is_admin, historical_admin_path_patterns, @@ -60,7 +60,7 @@ class UsersRestServlet(RestServlet): class UsersRestServletV2(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),) + PATTERNS = admin_patterns("/users$", "v2") """Get request to list all local users. This needs user to have administrator access in Synapse. @@ -105,7 +105,7 @@ class UsersRestServletV2(RestServlet): class UserRestServletV2(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. @@ -642,7 +642,7 @@ class UserAdminServlet(RestServlet): {} """ - PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") def __init__(self, hs): self.hs = hs @@ -683,3 +683,29 @@ class UserAdminServlet(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) return 200, {} + + +class UserMembershipRestServlet(RestServlet): + """ + Get room list of an user. + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") + + def __init__(self, hs): + self.is_mine = hs.is_mine + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only lookup local users") + + room_ids = await self.store.get_rooms_for_user(user_id) + if not room_ids: + raise NotFoundError("User not found") + + ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} + return 200, ret diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index b210015173..faabeeb91c 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet): PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet): PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): ) def __init__(self, hs): - super(ClientAppserviceDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 25effd0261..1ecb77aa26 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -30,9 +30,10 @@ class EventStreamRestServlet(RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 def __init__(self, hs): - super(EventStreamRestServlet, self).__init__() + super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet): if b"room_id" in request.args: room_id = request.args[b"room_id"][0].decode("ascii") - pagin_config = PaginationConfig.from_request(request) + pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in request.args: try: @@ -74,7 +75,7 @@ class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) def __init__(self, hs): - super(EventRestServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 910b3b4eeb..91da0ee573 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -24,14 +24,15 @@ class InitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/initialSync$", v1=True) def __init__(self, hs): - super(InitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a14618ac84..3d1693d7ac 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.appservice import ApplicationService from synapse.handlers.auth import ( convert_client_dict_legacy_fields_to_identifier, login_id_phone_to_thirdparty, @@ -44,9 +45,10 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" + APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" def __init__(self, hs): - super(LoginRestServlet, self).__init__() + super().__init__() self.hs = hs # JWT configuration variables. @@ -61,6 +63,8 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth = hs.get_auth() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -116,8 +120,12 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) + try: - if self.jwt_enabled and ( + if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + appservice = self.auth.get_appservice_by_req(request) + result = await self._do_appservice_login(login_submission, appservice) + elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): @@ -134,6 +142,33 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result + def _get_qualified_user_id(self, identifier): + if identifier["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + if "user" not in identifier: + raise SynapseError(400, "User identifier is missing 'user' key") + + if identifier["user"].startswith("@"): + return identifier["user"] + else: + return UserID(identifier["user"], self.hs.hostname).to_string() + + async def _do_appservice_login( + self, login_submission: JsonDict, appservice: ApplicationService + ): + logger.info( + "Got appservice login request with identifier: %r", + login_submission.get("identifier"), + ) + + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) + qualified_user_id = self._get_qualified_user_id(identifier) + + if not appservice.is_interested_in_user(qualified_user_id): + raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) + + return await self._complete_login(qualified_user_id, login_submission) + async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins @@ -219,15 +254,7 @@ class LoginRestServlet(RestServlet): # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. - if identifier["type"] != "m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - qualified_user_id = identifier["user"] - else: - qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() + qualified_user_id = self._get_qualified_user_id(identifier) # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( @@ -255,9 +282,7 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[ - Callable[[Dict[str, str]], Awaitable[Dict[str, str]]] - ] = None, + callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to @@ -270,12 +295,12 @@ class LoginRestServlet(RestServlet): Args: user_id: ID of the user to register. login_submission: Dictionary of login information. - callback: Callback function to run after registration. + callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. Returns: - result: Dictionary of account information after successful registration. + result: Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in @@ -310,14 +335,24 @@ class LoginRestServlet(RestServlet): return result async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + """ + Handle the final stage of SSO login. + + Args: + login_submission: The JSON request body. + + Returns: + The body of the JSON response. + """ token = login_submission["token"] auth_handler = self.auth_handler user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = await self._complete_login(user_id, login_submission) - return result + return await self._complete_login( + user_id, login_submission, self.auth_handler._sso_login_callback + ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission.get("token", None) @@ -400,7 +435,7 @@ class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) def __init__(self, hs): - super(CasTicketServlet, self).__init__() + super().__init__() self._cas_handler = hs.get_cas_handler() async def on_GET(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index b0c30b65be..f792b50cdc 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet): PATTERNS = client_patterns("/logout$", v1=True) def __init__(self, hs): - super(LogoutRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet): PATTERNS = client_patterns("/logout/all$", v1=True) def __init__(self, hs): - super(LogoutAllRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 970fdd5834..79d8e3057f 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True) def __init__(self, hs): - super(PresenceStatusRestServlet, self).__init__() + super().__init__() self.hs = hs self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e7fe50ed72..b686cd671f 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -25,7 +25,7 @@ class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True) def __init__(self, hs): - super(ProfileDisplaynameRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() @@ -73,7 +73,7 @@ class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) def __init__(self, hs): - super(ProfileAvatarURLRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() @@ -124,7 +124,7 @@ class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) def __init__(self, hs): - super(ProfileRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index e781a3bcf4..f9eecb7cf5 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -38,7 +38,7 @@ class PushRuleRestServlet(RestServlet): ) def __init__(self, hs): - super(PushRuleRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet): self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) async def set_rule_attr(self, user_id, spec, val): + if spec["attr"] not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec["rule_id"] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] @@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet): # This should *actually* take a dict, but many clients pass # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) return await self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val + user_id, namespaced_rule_id, val, is_default_rule ) elif spec["attr"] == "actions": actions = val.get("actions") diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5f65cb7d83..28dabf1c7a 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) def __init__(self, hs): - super(PushersRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) def __init__(self, hs): - super(PushersSetRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet): SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" def __init__(self, hs): - super(PushersRemoveRestServlet, self).__init__() + super().__init__() self.hs = hs self.notifier = hs.get_notifier() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 84baf3d59b..b63389e5fe 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -57,7 +57,7 @@ logger = logging.getLogger(__name__) class TransactionRestServlet(RestServlet): def __init__(self, hs): - super(TransactionRestServlet, self).__init__() + super().__init__() self.txns = HttpTransactionCache(hs) @@ -65,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here def __init__(self, hs): - super(RoomCreateRestServlet, self).__init__(hs) + super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() @@ -111,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomStateEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() @@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomSendEventRestServlet, self).__init__(hs) + super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -280,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(TransactionRestServlet): def __init__(self, hs): - super(JoinRoomAliasServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -343,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) def __init__(self, hs): - super(PublicRoomListRestServlet, self).__init__(hs) + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() @@ -448,9 +448,10 @@ class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True) def __init__(self, hs): - super(RoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet): if at_token_string is None: at_token = None else: - at_token = StreamToken.from_string(at_token_string) + at_token = await StreamToken.from_string(self.store, at_token_string) # let you filter down on particular memberships. # XXX: this may not be the best shape for this API - we could pass in a filter @@ -499,7 +500,7 @@ class JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True) def __init__(self, hs): - super(JoinedRoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -518,13 +519,16 @@ class RoomMessageListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True) def __init__(self, hs): - super(RoomMessageListRestServlet, self).__init__() + super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request, default_limit=10) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) as_client_event = b"raw" not in request.args filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: @@ -557,7 +561,7 @@ class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True) def __init__(self, hs): - super(RoomStateRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -577,13 +581,14 @@ class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True) def __init__(self, hs): - super(RoomInitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) @@ -596,7 +601,7 @@ class RoomEventServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() @@ -628,7 +633,7 @@ class RoomEventContextServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventContextServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -675,7 +680,7 @@ class RoomEventContextServlet(RestServlet): class RoomForgetRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomForgetRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -701,7 +706,7 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomMembershipRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -792,7 +797,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomRedactEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -841,7 +846,7 @@ class RoomTypingRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomTypingRestServlet, self).__init__() + super().__init__() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() @@ -914,7 +919,7 @@ class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) def __init__(self, hs): - super(SearchRestServlet, self).__init__() + super().__init__() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -935,7 +940,7 @@ class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) def __init__(self, hs): - super(JoinedRoomsRestServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 50277c6cf6..b8d491ca5c 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet): PATTERNS = client_patterns("/voip/turnServer$", v1=True) def __init__(self, hs): - super(VoipRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a206b75541..86d3d86fad 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -17,6 +17,11 @@ import logging import random from http import HTTPStatus +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -47,7 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") def __init__(self, hs): - super(EmailPasswordRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.datastore = hs.get_datastore() self.config = hs.config @@ -91,6 +96,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -137,86 +146,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): return 200, ret -class PasswordResetSubmitTokenServlet(RestServlet): - """Handles 3PID validation token submission""" - - PATTERNS = client_patterns( - "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super(PasswordResetSubmitTokenServlet, self).__init__() - self.hs = hs - self.auth = hs.get_auth() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self._failure_email_template = ( - self.config.email_password_reset_template_failure_html - ) - - async def on_GET(self, request, medium): - # We currently only handle threepid token submissions for email - if medium != "email": - raise SynapseError( - 400, "This medium is currently not supported for password resets" - ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Password reset emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Email-based password resets are disabled on this server" - ) - - sid = parse_string(request, "sid", required=True) - token = parse_string(request, "token", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None - - # Otherwise show the success template - html = self.config.email_password_reset_template_success_html_content - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html = self._failure_email_template.render(**template_vars) - - respond_with_html(request, status_code, html) - - class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") def __init__(self, hs): - super(PasswordRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -342,7 +276,7 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") def __init__(self, hs): - super(DeactivateAccountRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -361,7 +295,7 @@ class DeactivateAccountRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) - # allow ASes to dectivate their own users + # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase @@ -390,7 +324,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") def __init__(self, hs): - super(EmailThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.identity_handler = hs.get_handlers().identity_handler @@ -439,6 +373,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: @@ -484,7 +422,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs): self.hs = hs - super(MsisdnThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler @@ -510,6 +448,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: @@ -596,15 +538,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None # Otherwise show the success template html = self.config.email_add_threepid_template_success_html_content @@ -665,7 +602,7 @@ class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") def __init__(self, hs): - super(ThreepidRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -721,7 +658,7 @@ class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") def __init__(self, hs): - super(ThreepidAddRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -772,7 +709,7 @@ class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") def __init__(self, hs): - super(ThreepidBindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -801,7 +738,7 @@ class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") def __init__(self, hs): - super(ThreepidUnbindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -832,7 +769,7 @@ class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") def __init__(self, hs): - super(ThreepidDeleteRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -868,11 +805,50 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} +def assert_valid_next_link(hs: "HomeServer", next_link: str): + """ + Raises a SynapseError if a given next_link value is invalid + + next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config + option is either empty or contains a domain that matches the one in the given next_link + + Args: + hs: The homeserver object + next_link: The next_link value given by the client + + Raises: + SynapseError: If the next_link is invalid + """ + valid = True + + # Parse the contents of the URL + next_link_parsed = urlparse(next_link) + + # Scheme must not point to the local drive + if next_link_parsed.scheme == "file": + valid = False + + # If the domain whitelist is set, the domain must be in it + if ( + valid + and hs.config.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + ): + valid = False + + if not valid: + raise SynapseError( + 400, + "'next_link' domain not included in whitelist, or not http(s)", + errcode=Codes.INVALID_PARAM, + ) + + class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") def __init__(self, hs): - super(WhoamiRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() async def on_GET(self, request): @@ -883,7 +859,6 @@ class WhoamiRestServlet(RestServlet): def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) - PasswordResetSubmitTokenServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index c1d4cd0caf..87a5b1b86b 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -34,7 +34,7 @@ class AccountDataServlet(RestServlet): ) def __init__(self, hs): - super(AccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -86,7 +86,7 @@ class RoomAccountDataServlet(RestServlet): ) def __init__(self, hs): - super(RoomAccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index d06336ceea..bd7f9ae203 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValidityRenewServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() @@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValiditySendMailServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8e585e9153..5fbfae5991 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -25,94 +25,6 @@ from ._base import client_patterns logger = logging.getLogger(__name__) -RECAPTCHA_TEMPLATE = """ -<html> -<head> -<title>Authentication</title> -<meta name='viewport' content='width=device-width, initial-scale=1, - user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> -<script src="https://www.recaptcha.net/recaptcha/api.js" - async defer></script> -<script src="//code.jquery.com/jquery-1.11.2.min.js"></script> -<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> -<script> -function captchaDone() { - $('#registrationForm').submit(); -} -</script> -</head> -<body> -<form id="registrationForm" method="post" action="%(myurl)s"> - <div> - <p> - Hello! We need to prevent computer programs and other automated - things from creating accounts on this server. - </p> - <p> - Please verify that you're not a robot. - </p> - <input type="hidden" name="session" value="%(session)s" /> - <div class="g-recaptcha" - data-sitekey="%(sitekey)s" - data-callback="captchaDone"> - </div> - <noscript> - <input type="submit" value="All Done" /> - </noscript> - </div> - </div> -</form> -</body> -</html> -""" - -TERMS_TEMPLATE = """ -<html> -<head> -<title>Authentication</title> -<meta name='viewport' content='width=device-width, initial-scale=1, - user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> -<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> -</head> -<body> -<form id="registrationForm" method="post" action="%(myurl)s"> - <div> - <p> - Please click the button below if you agree to the - <a href="%(terms_url)s">privacy policy of this homeserver.</a> - </p> - <input type="hidden" name="session" value="%(session)s" /> - <input type="submit" value="Agree" /> - </div> -</form> -</body> -</html> -""" - -SUCCESS_TEMPLATE = """ -<html> -<head> -<title>Success!</title> -<meta name='viewport' content='width=device-width, initial-scale=1, - user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> -<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> -<script> -if (window.onAuthDone) { - window.onAuthDone(); -} else if (window.opener && window.opener.postMessage) { - window.opener.postMessage("authDone", "*"); -} -</script> -</head> -<body> - <div> - <p>Thank you</p> - <p>You may now close this window and return to the application</p> - </div> -</body> -</html> -""" - class AuthRestServlet(RestServlet): """ @@ -124,7 +36,7 @@ class AuthRestServlet(RestServlet): PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): - super(AuthRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -145,26 +57,30 @@ class AuthRestServlet(RestServlet): self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url + self.recaptcha_template = hs.config.recaptcha_template + self.terms_template = hs.config.terms_template + self.success_template = hs.config.fallback_success_template + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") if stagetype == LoginType.RECAPTCHA: - html = RECAPTCHA_TEMPLATE % { - "session": session, - "myurl": "%s/r0/auth/%s/fallback/web" + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - "sitekey": self.hs.config.recaptcha_public_key, - } + sitekey=self.hs.config.recaptcha_public_key, + ) elif stagetype == LoginType.TERMS: - html = TERMS_TEMPLATE % { - "session": session, - "terms_url": "%s_matrix/consent?v=%s" + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), - "myurl": "%s/r0/auth/%s/fallback/web" + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), - } + ) elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to @@ -222,14 +138,14 @@ class AuthRestServlet(RestServlet): ) if success: - html = SUCCESS_TEMPLATE + html = self.success_template.render() else: - html = RECAPTCHA_TEMPLATE % { - "session": session, - "myurl": "%s/r0/auth/%s/fallback/web" + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - "sitekey": self.hs.config.recaptcha_public_key, - } + sitekey=self.hs.config.recaptcha_public_key, + ) elif stagetype == LoginType.TERMS: authdict = {"session": session} @@ -238,18 +154,18 @@ class AuthRestServlet(RestServlet): ) if success: - html = SUCCESS_TEMPLATE + html = self.success_template.render() else: - html = TERMS_TEMPLATE % { - "session": session, - "terms_url": "%s_matrix/consent?v=%s" + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - "myurl": "%s/r0/auth/%s/fallback/web" + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), - } + ) elif stagetype == LoginType.SSO: # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index fe9d019c44..76879ac559 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(CapabilitiesRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index c0714fcfb1..7e174de692 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -35,7 +35,7 @@ class DevicesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -57,7 +57,7 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): - super(DeleteDevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -102,7 +102,7 @@ class DeviceRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index b28da017cd..7cc692643b 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") def __init__(self, hs): - super(GetFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() @@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter") def __init__(self, hs): - super(CreateFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 075afdd32b..75215a3779 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -32,7 +32,7 @@ class GroupServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") def __init__(self, hs): - super(GroupServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -66,7 +66,7 @@ class GroupSummaryServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") def __init__(self, hs): - super(GroupSummaryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -97,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryRoomsCatServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -137,7 +137,7 @@ class GroupCategoryServlet(RestServlet): ) def __init__(self, hs): - super(GroupCategoryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -181,7 +181,7 @@ class GroupCategoriesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$") def __init__(self, hs): - super(GroupCategoriesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -204,7 +204,7 @@ class GroupRoleServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$") def __init__(self, hs): - super(GroupRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -248,7 +248,7 @@ class GroupRolesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$") def __init__(self, hs): - super(GroupRolesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -279,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryUsersRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -317,7 +317,7 @@ class GroupRoomServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") def __init__(self, hs): - super(GroupRoomServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -343,7 +343,7 @@ class GroupUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") def __init__(self, hs): - super(GroupUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -366,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") def __init__(self, hs): - super(GroupInvitedUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -389,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") def __init__(self, hs): - super(GroupSettingJoinPolicyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() @@ -413,7 +413,7 @@ class GroupCreateServlet(RestServlet): PATTERNS = client_patterns("/create_group$") def __init__(self, hs): - super(GroupCreateServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -444,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -481,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsConfigServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -507,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -536,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersKickServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -585,7 +585,7 @@ class GroupSelfLeaveServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$") def __init__(self, hs): - super(GroupSelfLeaveServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -609,7 +609,7 @@ class GroupSelfJoinServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$") def __init__(self, hs): - super(GroupSelfJoinServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -633,7 +633,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$") def __init__(self, hs): - super(GroupSelfAcceptInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -657,7 +657,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$") def __init__(self, hs): - super(GroupSelfUpdatePublicityServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -680,7 +680,7 @@ class PublicisedGroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$") def __init__(self, hs): - super(PublicisedGroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -701,7 +701,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): - super(PublicisedGroupsForUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -725,7 +725,7 @@ class GroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): - super(GroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 24bb090822..55c4606569 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(KeyUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -147,7 +147,7 @@ class KeyQueryServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyQueryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -177,9 +177,10 @@ class KeyChangesServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyChangesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet): # changes after the "to" as well as before. set_tag("to", parse_string(request, "to")) - from_token = StreamToken.from_string(from_token_string) + from_token = await StreamToken.from_string(self.store, from_token_string) user_id = requester.user.to_string() @@ -222,7 +223,7 @@ class OneTimeKeyServlet(RestServlet): PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): - super(OneTimeKeyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -250,7 +251,7 @@ class SigningKeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SigningKeyUploadServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -308,7 +309,7 @@ class SignaturesUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SignaturesUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index aa911d75ee..87063ec8b1 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") def __init__(self, hs): - super(NotificationsServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index 6ae9a5a8e9..5b996e2d63 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet): EXPIRES_MS = 3600 * 1000 def __init__(self, hs): - super(IdTokenServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py index 968403cca4..68b27ff23a 100644 --- a/synapse/rest/client/v2_alpha/password_policy.py +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(PasswordPolicyServlet, self).__init__() + super().__init__() self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index 67cbc37312..55c6688f52 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") def __init__(self, hs): - super(ReadMarkerRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 92555bd4a9..6f7246a394 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet): ) def __init__(self, hs): - super(ReceiptRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c589dd6c78..ec8ef9bf88 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -76,7 +76,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(EmailRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.config = hs.config @@ -174,7 +174,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(MsisdnRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -249,7 +249,7 @@ class RegistrationSubmitTokenServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegistrationSubmitTokenServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.config = hs.config @@ -319,7 +319,7 @@ class UsernameAvailabilityRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UsernameAvailabilityRestServlet, self).__init__() + super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( @@ -363,7 +363,7 @@ class RegisterRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegisterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -431,11 +431,14 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) - if isinstance(desired_username, str): - result = await self._do_appservice_registration( - desired_username, access_token, body - ) - return 200, result # we throw for non 200 responses + if not isinstance(desired_username, str): + raise SynapseError(400, "Desired Username is missing or not a string") + + result = await self._do_appservice_registration( + desired_username, access_token, body + ) + + return 200, result # == Normal User Registration == (everyone else) if not self._registration_enabled: diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index e29f49f7f5..18c75738f8 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -61,7 +61,7 @@ class RelationSendServlet(RestServlet): ) def __init__(self, hs): - super(RelationSendServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) @@ -138,7 +138,7 @@ class RelationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -233,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() @@ -311,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationGroupPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e15927c4ea..215d619ca1 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$") def __init__(self, hs): - super(ReportEventRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 59529707df..53de97923f 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysNewVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index 39a5518614..bf030e0ff4 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -53,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomUpgradeRestServlet, self).__init__() + super().__init__() self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index db829f3098..bc4f43639a 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SendToDeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py index 2492634dac..c866d5151c 100644 --- a/synapse/rest/client/v2_alpha/shared_rooms.py +++ b/synapse/rest/client/v2_alpha/shared_rooms.py @@ -34,7 +34,7 @@ class UserSharedRoomsServlet(RestServlet): ) def __init__(self, hs): - super(UserSharedRoomsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.user_directory_active = hs.config.update_user_directory diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a0b00135e1..6779df952f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -74,9 +74,10 @@ class SyncRestServlet(RestServlet): ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): - super(SyncRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() + self.store = hs.get_datastore() self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() @@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet): device_id=device_id, ) + since_token = None if since is not None: - since_token = StreamToken.from_string(since) - else: - since_token = None + since_token = await StreamToken.from_string(self.store, since) # send any outstanding server notices to the user. await self._server_notices_sender.on_user_syncing(user.to_string()) @@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), + "next_batch": await sync_result.next_batch.to_string(self.store), } @staticmethod @@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet): result = { "timeline": { "events": serialized_timeline, - "prev_batch": room.timeline.prev_batch.to_string(), + "prev_batch": await room.timeline.prev_batch.to_string(self.store), "limited": room.timeline.limited, }, "state": {"events": serialized_state}, diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index a3f12e8a77..bf3a79db44 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -31,7 +31,7 @@ class TagListServlet(RestServlet): PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags") def __init__(self, hs): - super(TagListServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -56,7 +56,7 @@ class TagServlet(RestServlet): ) def __init__(self, hs): - super(TagServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 23709960ad..0c127a1b5f 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocols") def __init__(self, hs): - super(ThirdPartyProtocolsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): - super(ThirdPartyProtocolServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): - super(ThirdPartyUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): - super(ThirdPartyLocationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 83f3b6b70a..79317c74ba 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet): PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): - super(TokenRefreshRestServlet, self).__init__() + super().__init__() async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index bef91a2d3e..ad598cefe0 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -31,7 +31,7 @@ class UserDirectorySearchRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UserDirectorySearchRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c560edbc59..d24a199318 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -29,7 +29,7 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def __init__(self, hs): - super(VersionsRestServlet, self).__init__() + super().__init__() self.config = hs.config # Calculate these once since they shouldn't change after start-up. diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 5db7f81c2d..f843f02454 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -35,7 +35,7 @@ class RemoteKey(DirectServeJsonResource): Supports individual GET APIs and a bulk query POST API. - Requsts: + Requests: GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1 diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d2826374a7..7447eeaebe 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -80,7 +80,7 @@ class MediaFilePaths: self, server_name, file_id, width, height, content_type, method ): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( "remote_thumbnail", server_name, @@ -92,6 +92,23 @@ class MediaFilePaths: remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + # Legacy path that was used to store thumbnails previously. + # Should be removed after some time, when most of the thumbnails are stored + # using the new path. + def remote_media_thumbnail_rel_legacy( + self, server_name, file_id, width, height, content_type + ): + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + return os.path.join( + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], + file_name, + ) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( self.base_path, diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 9a1b7779f7..e1192b47cd 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -53,7 +53,7 @@ from .media_storage import MediaStorage from .preview_url_resource import PreviewUrlResource from .storage_provider import StorageProviderWrapper from .thumbnail_resource import ThumbnailResource -from .thumbnailer import Thumbnailer +from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource logger = logging.getLogger(__name__) @@ -139,7 +139,7 @@ class MediaRepository: async def create_content( self, media_type: str, - upload_name: str, + upload_name: Optional[str], content: IO, content_length: int, auth_user: str, @@ -147,8 +147,8 @@ class MediaRepository: """Store uploaded content for a local user and return the mxc URL Args: - media_type: The content type of the file - upload_name: The name of the file + media_type: The content type of the file. + upload_name: The name of the file, if provided. content: A file like object that is the content to store content_length: The length of the content auth_user: The user_id of the uploader @@ -156,6 +156,7 @@ class MediaRepository: Returns: The mxc url of the stored content """ + media_id = random_string(24) file_info = FileInfo(server_name=None, file_id=media_id) @@ -460,13 +461,30 @@ class MediaRepository: return t_byte_source async def generate_local_exact_thumbnail( - self, media_id, t_width, t_height, t_method, t_type, url_cache - ): + self, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + url_cache: str, + ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, @@ -506,14 +524,36 @@ class MediaRepository: return output_path + # Could not generate thumbnail. + return None + async def generate_remote_exact_thumbnail( - self, server_name, file_id, media_id, t_width, t_height, t_method, t_type - ): + self, + server_name: str, + file_id: str, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=False) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, @@ -559,6 +599,9 @@ class MediaRepository: return output_path + # Could not generate thumbnail. + return None + async def _generate_thumbnails( self, server_name: Optional[str], @@ -590,7 +633,18 @@ class MediaRepository: FileInfo(server_name, file_id, url_cache=url_cache) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", + media_id, + server_name, + media_type, + e, + ) + return None + m_width = thumbnailer.width m_height = thumbnailer.height diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 3a352b5631..a9586fb0b7 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -141,17 +141,34 @@ class MediaStorage: Returns: Returns a Responder if the file was found, otherwise None. """ + paths = [self._file_info_to_path(file_info)] - path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return FileResponder(open(local_path, "rb")) + # fallback for remote thumbnails with no method in the filename + if file_info.thumbnail and file_info.server_name: + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + ) + ) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: - res = await provider.fetch(path, file_info) # type: Any - if res: - logger.debug("Streaming %s from %s", path, provider) - return res + for path in paths: + res = await provider.fetch(path, file_info) # type: Any + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) return None @@ -170,6 +187,20 @@ class MediaStorage: if os.path.exists(local_path): return local_path + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + dirname = os.path.dirname(local_path) if not os.path.exists(dirname): os.makedirs(dirname) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index cd8c246594..dce6c4d168 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -102,7 +102,7 @@ for endpoint, globs in _oembed_globs.items(): _oembed_patterns[re.compile(pattern)] = endpoint -@attr.s +@attr.s(slots=True) class OEmbedResult: # Either HTML content or URL must be provided. html = attr.ib(type=Optional[str]) @@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url, user): + async def _download_url(self, url: str, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) # If this URL can be accessed via oEmbed, use that instead. - url_to_download = url + url_to_download = url # type: Optional[str] oembed_url = self._get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. @@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: we should calculate a proper expiration based on the # Cache-Control and Expire headers. But for now, assume 1 hour. expires = ONE_HOUR - etag = headers["ETag"][0] if "ETag" in headers else None + etag = ( + headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + ) else: - html_bytes = oembed_result.html.encode("utf-8") # type: ignore + # we can only get here if we did an oembed request and have an oembed_result.html + assert oembed_result.html is not None + assert oembed_url is not None + + html_bytes = oembed_result.html.encode("utf-8") with self.media_storage.store_into_file(file_info) as (f, fname, finish): f.write(html_bytes) await finish() diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index a83535b97b..30421b663a 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,6 +16,7 @@ import logging +from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string @@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource): await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") - respond_404(request) + raise SynapseError(400, "Failed to generate thumbnail.") async def _select_or_generate_remote_thumbnail( self, @@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource): await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") - respond_404(request) + raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( self, request, server_name, media_id, width, height, method, m_type diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index d681bf7bf0..32a8e4f960 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -15,7 +15,7 @@ import logging from io import BytesIO -from PIL import Image as Image +from PIL import Image logger = logging.getLogger(__name__) @@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = { } +class ThumbnailError(Exception): + """An error occurred generating a thumbnail.""" + + class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} def __init__(self, input_path): - self.image = Image.open(input_path) + try: + self.image = Image.open(input_path) + except OSError as e: + # If an error occurs opening the image, a thumbnail won't be able to + # be generated. + raise ThumbnailError from e + self.width, self.height = self.image.size self.transpose_method = None try: @@ -73,7 +83,7 @@ class Thumbnailer: Args: max_width: The largest possible width. - max_height: The larget possible height. + max_height: The largest possible height. """ if max_width * self.height < max_height * self.width: @@ -107,7 +117,7 @@ class Thumbnailer: Args: max_width: The largest possible width. - max_height: The larget possible height. + max_height: The largest possible height. Returns: BytesIO: the bytes of the encoded image ready to be written to disk diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 3ebf7a68e6..d76f7389e1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -63,6 +63,10 @@ class UploadResource(DirectServeJsonResource): msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 ) + # If the name is falsey (e.g. an empty byte string) ensure it is None. + else: + upload_name = None + headers = request.requestHeaders if headers.hasHeader(b"Content-Type"): diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index c10188a5d7..f6668fb5e3 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,10 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.python import failure -from synapse.api.errors import SynapseError -from synapse.http.server import DirectServeHtmlResource, return_html_error +from synapse.http.server import DirectServeHtmlResource class SAML2ResponseResource(DirectServeHtmlResource): @@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource): def __init__(self, hs): super().__init__() self._saml_handler = hs.get_saml_handler() - self._error_html_template = hs.config.saml2.saml2_error_html_template async def _async_render_GET(self, request): # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should # try to authenticate again. - f = failure.Failure( - SynapseError(400, "Unexpected GET request on /saml2/authn_response") + self._saml_handler._render_error( + request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" ) - return_html_error(f, request, self._error_html_template) async def _async_render_POST(self, request): - try: - await self._saml_handler.handle_saml_response(request) - except Exception: - f = failure.Failure() - return_html_error(f, request, self._error_html_template) + await self._saml_handler.handle_saml_response(request) diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py new file mode 100644 index 0000000000..c0b733488b --- /dev/null +++ b/synapse/rest/synapse/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py new file mode 100644 index 0000000000..c0b733488b --- /dev/null +++ b/synapse/rest/synapse/client/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py new file mode 100644 index 0000000000..9e4fbc0cbd --- /dev/null +++ b/synapse/rest/synapse/client/password_reset.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request + +from synapse.api.errors import ThreepidValidationError +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.http.server import DirectServeHtmlResource +from synapse.http.servlet import parse_string +from synapse.util.stringutils import assert_valid_client_secret + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class PasswordResetSubmitTokenResource(DirectServeHtmlResource): + """Handles 3PID validation token submission + + This resource gets mounted under /_synapse/client/password_reset/email/submit_token + """ + + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + """ + Args: + hs: server + """ + super().__init__() + + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + self._local_threepid_handling_disabled_due_to_email_config = ( + hs.config.local_threepid_handling_disabled_due_to_email_config + ) + self._confirmation_email_template = ( + hs.config.email_password_reset_template_confirmation_html + ) + self._email_password_reset_template_success_html = ( + hs.config.email_password_reset_template_success_html_content + ) + self._failure_email_template = ( + hs.config.email_password_reset_template_failure_html + ) + + # This resource should not be mounted if threepid behaviour is not LOCAL + assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL + + async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + + # Show a confirmation page, just in case someone accidentally clicked this link when + # they didn't mean to + template_vars = { + "sid": sid, + "token": token, + "client_secret": client_secret, + } + return ( + 200, + self._confirmation_email_template.render(**template_vars).encode("utf-8"), + ) + + async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]: + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warning( + "Not redirecting to next_link as it is a local file: address" + ) + else: + next_link_bytes = next_link.encode("utf-8") + request.setHeader("Location", next_link_bytes) + return ( + 302, + ( + b'You are being redirected to <a src="%s">%s</a>.' + % (next_link_bytes, next_link_bytes) + ), + ) + + # Otherwise show the success template + html_bytes = self._email_password_reset_template_success_html.encode( + "utf-8" + ) + status_code = 200 + except ThreepidValidationError as e: + status_code = e.code + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html_bytes = self._failure_email_template.render(**template_vars).encode( + "utf-8" + ) + + return status_code, html_bytes diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index c7e3015b5d..31082bb16a 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -13,43 +13,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import heapq import logging -from collections import namedtuple +from collections import defaultdict, namedtuple from typing import ( + Any, Awaitable, + Callable, + DefaultDict, Dict, Iterable, List, Optional, Sequence, Set, + Tuple, Union, - cast, overload, ) import attr from frozendict import frozendict -from prometheus_client import Histogram +from prometheus_client import Counter, Histogram from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.logging.context import ContextResourceUsage from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import Collection, MutableStateMap, StateMap -from synapse.util import Clock +from synapse.types import Collection, StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) - +metrics_logger = logging.getLogger("synapse.state.metrics") # Metrics for number of state groups involved in a resolution. state_groups_histogram = Histogram( @@ -449,19 +452,44 @@ class StateHandler: state_map = {ev.event_id: ev for st in state_sets for ev in st} - with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, - event.room_id, - room_version, - state_set_ids, - event_map=state_map, - state_res_store=StateResolutionStore(self.store), - ) + new_state = await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_set_ids, + event_map=state_map, + state_res_store=StateResolutionStore(self.store), + ) return {key: state_map[ev_id] for key, ev_id in new_state.items()} +@attr.s(slots=True) +class _StateResMetrics: + """Keeps track of some usage metrics about state res.""" + + # System and User CPU time, in seconds + cpu_time = attr.ib(type=float, default=0.0) + + # time spent on database transactions (excluding scheduling time). This roughly + # corresponds to the amount of work done on the db server, excluding event fetches. + db_time = attr.ib(type=float, default=0.0) + + # number of events fetched from the db. + db_events = attr.ib(type=int, default=0) + + +_biggest_room_by_cpu_counter = Counter( + "synapse_state_res_cpu_for_biggest_room_seconds", + "CPU time spent performing state resolution for the single most expensive " + "room for state resolution", +) +_biggest_room_by_db_counter = Counter( + "synapse_state_res_db_for_biggest_room_seconds", + "Database time spent performing state resolution for the single most " + "expensive room for state resolution", +) + + class StateResolutionHandler: """Responsible for doing state conflict resolution. @@ -472,10 +500,9 @@ class StateResolutionHandler: def __init__(self, hs): self.clock = hs.get_clock() - # dict of set of event_ids -> _StateCacheEntry. - self._state_cache = None self.resolve_linearizer = Linearizer(name="state_resolve_lock") + # dict of set of event_ids -> _StateCacheEntry. self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, @@ -485,6 +512,17 @@ class StateResolutionHandler: reset_expiry_on_get=True, ) + # + # stuff for tracking time spent on state-res by room + # + + # tracks the amount of work done on state res per room + self._state_res_metrics = defaultdict( + _StateResMetrics + ) # type: DefaultDict[str, _StateResMetrics] + + self.clock.looping_call(self._report_metrics, 120 * 1000) + @log_function async def resolve_state_groups( self, @@ -519,57 +557,26 @@ class StateResolutionHandler: Returns: The resolved state """ - logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) - group_names = frozenset(state_groups_ids.keys()) with (await self.resolve_linearizer.queue(group_names)): - if self._state_cache is not None: - cache = self._state_cache.get(group_names, None) - if cache: - return cache + cache = self._state_cache.get(group_names, None) + if cache: + return cache logger.info( - "Resolving state for %s with %d groups", room_id, len(state_groups_ids) + "Resolving state for %s with groups %s", room_id, list(group_names), ) state_groups_histogram.observe(len(state_groups_ids)) - # start by assuming we won't have any conflicted state, and build up the new - # state map by iterating through the state groups. If we discover a conflict, - # we give up and instead use `resolve_events_with_store`. - # - # XXX: is this actually worthwhile, or should we just let - # resolve_events_with_store do it? - new_state = {} # type: MutableStateMap[str] - conflicted_state = False - for st in state_groups_ids.values(): - for key, e_id in st.items(): - if key in new_state: - conflicted_state = True - break - new_state[key] = e_id - if conflicted_state: - break - - if conflicted_state: - logger.info("Resolving conflicted state for %r", room_id) - with Measure(self.clock, "state._resolve_events"): - # resolve_events_with_store returns a StateMap, but we can - # treat it as a MutableStateMap as it is above. It isn't - # actually mutated anymore (and is frozen in - # _make_state_cache_entry below). - new_state = cast( - MutableStateMap, - await resolve_events_with_store( - self.clock, - room_id, - room_version, - list(state_groups_ids.values()), - event_map=event_map, - state_res_store=state_res_store, - ), - ) + new_state = await self.resolve_events_with_store( + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ) # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -579,11 +586,118 @@ class StateResolutionHandler: with Measure(self.clock, "state.create_group_ids"): cache = _make_state_cache_entry(new_state, state_groups_ids) - if self._state_cache is not None: - self._state_cache[group_names] = cache + self._state_cache[group_names] = cache return cache + async def resolve_events_with_store( + self, + room_id: str, + room_version: str, + state_sets: Sequence[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", + ) -> StateMap[str]: + """ + Args: + room_id: the room we are working in + + room_version: Version of the room + + state_sets: List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map: + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_res_store. + + state_res_store: a place to fetch events from + + Returns: + a map from (type, state_key) to event_id. + """ + try: + with Measure(self.clock, "state._resolve_events") as m: + v = KNOWN_ROOM_VERSIONS[room_version] + if v.state_res == StateResolutionVersions.V1: + return await v1.resolve_events_with_store( + room_id, state_sets, event_map, state_res_store.get_events + ) + else: + return await v2.resolve_events_with_store( + self.clock, + room_id, + room_version, + state_sets, + event_map, + state_res_store, + ) + finally: + self._record_state_res_metrics(room_id, m.get_resource_usage()) + + def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage): + room_metrics = self._state_res_metrics[room_id] + room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime + room_metrics.db_time += rusage.db_txn_duration_sec + room_metrics.db_events += rusage.evt_db_fetch_count + + def _report_metrics(self): + if not self._state_res_metrics: + # no state res has happened since the last iteration: don't bother logging. + return + + self._report_biggest( + lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter, + ) + + self._report_biggest( + lambda i: i.db_time, "DB time", _biggest_room_by_db_counter, + ) + + self._state_res_metrics.clear() + + def _report_biggest( + self, + extract_key: Callable[[_StateResMetrics], Any], + metric_name: str, + prometheus_counter_metric: Counter, + ) -> None: + """Report metrics on the biggest rooms for state res + + Args: + extract_key: a callable which, given a _StateResMetrics, extracts a single + metric to sort by. + metric_name: the name of the metric we have extracted, for the log line + prometheus_counter_metric: a prometheus metric recording the sum of the + the extracted metric + """ + n_to_log = 10 + if not metrics_logger.isEnabledFor(logging.DEBUG): + # only need the most expensive if we don't have debug logging, which + # allows nlargest() to degrade to max() + n_to_log = 1 + + items = self._state_res_metrics.items() + + # log the N biggest rooms + biggest = heapq.nlargest( + n_to_log, items, key=lambda i: extract_key(i[1]) + ) # type: List[Tuple[str, _StateResMetrics]] + metrics_logger.debug( + "%i biggest rooms for state-res by %s: %s", + len(biggest), + metric_name, + ["%s (%gs)" % (r, extract_key(m)) for (r, m) in biggest], + ) + + # report info on the single biggest to prometheus + _, biggest_metrics = biggest[0] + prometheus_counter_metric.inc(extract_key(biggest_metrics)) + def _make_state_cache_entry( new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] @@ -637,48 +751,7 @@ def _make_state_cache_entry( ) -def resolve_events_with_store( - clock: Clock, - room_id: str, - room_version: str, - state_sets: Sequence[StateMap[str]], - event_map: Optional[Dict[str, EventBase]], - state_res_store: "StateResolutionStore", -) -> Awaitable[StateMap[str]]: - """ - Args: - room_id: the room we are working in - - room_version: Version of the room - - state_sets: List of dicts of (type, state_key) -> event_id, - which are the different state groups to resolve. - - event_map: - a dict from event_id to event, for any events that we happen to - have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing - events will be requested via state_map_factory. - - If None, all events will be fetched via state_res_store. - - state_res_store: a place to fetch events from - - Returns: - a map from (type, state_key) to event_id. - """ - v = KNOWN_ROOM_VERSIONS[room_version] - if v.state_res == StateResolutionVersions.V1: - return v1.resolve_events_with_store( - room_id, state_sets, event_map, state_res_store.get_events - ) - else: - return v2.resolve_events_with_store( - clock, room_id, room_version, state_sets, event_map, state_res_store - ) - - -@attr.s +@attr.s(slots=True) class StateResolutionStore: """Interface that allows state resolution algorithms to access the database in well defined way. diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 8e5d78f6f7..bbff3c8d5b 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -47,6 +47,9 @@ class Storage: # interfaces. self.main = stores.main - self.persistence = EventsPersistenceStorage(hs, stores) self.purge_events = PurgeEventsStorage(hs, stores) self.state = StateGroupStorage(hs, stores) + + self.persistence = None + if stores.persist_events: + self.persistence = EventsPersistenceStorage(hs, stores) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ed8a9bffb1..6116191b16 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -403,6 +403,24 @@ class DatabasePool: *args: Any, **kwargs: Any ) -> R: + """Start a new database transaction with the given connection. + + Note: The given func may be called multiple times under certain + failure modes. This is normally fine when in a standard transaction, + but care must be taken if the connection is in `autocommit` mode that + the function will correctly handle being aborted and retried half way + through its execution. + + Args: + conn + desc + after_callbacks + exception_callbacks + func + *args + **kwargs + """ + start = monotonic_time() txn_id = self._TXN_ID @@ -508,7 +526,12 @@ class DatabasePool: sql_txn_timer.labels(desc).observe(duration) async def runInteraction( - self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + desc: str, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Starts a transaction on the database and runs a given function @@ -518,6 +541,18 @@ class DatabasePool: database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transactions + that are only a single query. + + Currently, this is only implemented for Postgres. SQLite will still + run the function inside a transaction. + + WARNING: This means that if func fails half way through then + the changes will *not* be rolled back. `func` may also get + called multiple times if the transaction is retried, so must + correctly handle that case. + args: positional args to pass to `func` kwargs: named args to pass to `func` @@ -538,6 +573,7 @@ class DatabasePool: exception_callbacks, func, *args, + db_autocommit=db_autocommit, **kwargs ) @@ -551,7 +587,11 @@ class DatabasePool: return cast(R, result) async def runWithConnection( - self, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. @@ -560,6 +600,9 @@ class DatabasePool: database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. args: positional args to pass to `func` + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transaction + that are only a single query. Currently only affects postgres. kwargs: named args to pass to `func` Returns: @@ -575,6 +618,13 @@ class DatabasePool: start_time = monotonic_time() def inner_func(conn, *args, **kwargs): + # We shouldn't be in a transaction. If we are then something + # somewhere hasn't committed after doing work. (This is likely only + # possible during startup, as `run*` will ensure changes are + # committed/rolled back before putting the connection back in the + # pool). + assert not self.engine.in_transaction(conn) + with LoggingContext("runWithConnection", parent_context) as context: sched_duration_sec = monotonic_time() - start_time sql_scheduling_timer.observe(sched_duration_sec) @@ -584,7 +634,14 @@ class DatabasePool: logger.debug("Reconnecting closed database connection") conn.reconnect() - return func(conn, *args, **kwargs) + try: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, True) + + return func(conn, *args, **kwargs) + finally: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, False) return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) @@ -952,7 +1009,7 @@ class DatabasePool: key_names: Collection[str], key_values: Collection[Iterable[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[str]], + value_values: Iterable[Iterable[Any]], ) -> None: """ Upsert, many times. @@ -981,7 +1038,7 @@ class DatabasePool: key_names: Iterable[str], key_values: Collection[Iterable[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[str]], + value_values: Iterable[Iterable[Any]], ) -> None: """ Upsert, many times, but without native UPSERT support or batching. diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 985b12df91..aa5d490624 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -75,7 +75,7 @@ class Databases: # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 2ae2fbd5d7..0cb12f4c61 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,19 +160,25 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, - instance_name="master", + stream_name="caches", + instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) else: self._cache_id_gen = None - super(DataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 4436b1a83d..ef81d73573 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -29,22 +29,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. +class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -315,7 +313,7 @@ class AccountDataStore(AccountDataWorkerStore): ], ) - super(AccountDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream @@ -341,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -389,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 454c0bc50c..85f6b1e3fd 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c2fc847fbc..239c7a949c 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "user_ips_device_index", @@ -358,7 +358,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): name="client_ip_last_seen", keylen=4, max_entries=50000 ) - super(ClientIpStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 0044433110..d42faa3f1f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_inbox_stream_index", @@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index add4e3ea0e..fdf394c612 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with await self._device_list_id_gen.get_next() as stream_id: + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore): } async def get_users_whose_devices_changed( - self, from_key: str, user_ids: Iterable[str] + self, from_key: int, user_ids: Iterable[str] ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. @@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore): Returns: The set of user_ids whose devices have changed since `from_key` """ - from_key = int(from_key) # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. @@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore): ) async def get_users_whose_signatures_changed( - self, user_id: str, from_key: str + self, user_id: str, from_key: int ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. @@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: A set of user IDs with updated signatures. """ - from_key = int(from_key) + if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): sql = """ SELECT DISTINCT user_ids FROM user_signature_stream @@ -702,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_lists_stream_idx", @@ -827,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. @@ -1094,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1109,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index fba3098ea2..22e1ed15d0 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from synapse.handlers.e2e_keys import SignatureListItem -@attr.s +@attr.s(slots=True) class DeviceKeyLookupResult: """The type returned by get_e2e_device_keys_and_signatures""" @@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key (dict): the key data """ - with await self._cross_signing_id_gen.get_next() as stream_id: + async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0b69aa6a94..6d3689c09e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm @@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventFederationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 5233ed83e2..62f1738732 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -68,7 +68,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -661,7 +661,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, @@ -969,7 +969,7 @@ def _action_has_highlight(actions): return False -@attr.s +@attr.s(slots=True) class _EventPushSummary: """Summary of pending event push actions for a given user in a given room. Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b3d27a2ee7..18def01f50 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,7 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -32,7 +32,7 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter @@ -97,18 +97,21 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id # Ideally we'd move these ID gens here, unfortunately some other ID # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator - self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator + self._backfill_id_gen = ( + self.store._backfill_id_gen + ) # type: MultiWriterIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -153,15 +156,15 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = await self._backfill_id_gen.get_next_mult( + stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = await self._stream_id_gen.get_next_mult( + stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream @@ -213,7 +216,7 @@ class PersistEventsStore: Returns: Filtered event ids """ - results = [] + results = [] # type: List[str] def _get_events_which_are_prevs_txn(txn, batch): sql = """ @@ -631,7 +634,9 @@ class PersistEventsStore: ) @classmethod - def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + def _filter_events_and_contexts_for_duplicates( + cls, events_and_contexts: List[Tuple[EventBase, EventContext]] + ) -> List[Tuple[EventBase, EventContext]]: """Ensure that we don't have the same event twice. Pick the earliest non-outlier if there is one, else the earliest one. @@ -641,7 +646,9 @@ class PersistEventsStore: Returns: list[(EventBase, EventContext)]: filtered list """ - new_events_and_contexts = OrderedDict() + new_events_and_contexts = ( + OrderedDict() + ) # type: OrderedDict[str, Tuple[EventBase, EventContext]] for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) if prev_event_context: @@ -655,7 +662,12 @@ class PersistEventsStore: new_events_and_contexts[event.event_id] = (event, context) return list(new_events_and_contexts.values()) - def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + def _update_room_depths_txn( + self, + txn, + events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, + ): """Update min_depth for each room Args: @@ -664,7 +676,7 @@ class PersistEventsStore: we are persisting backfilled (bool): True if the events were backfilled """ - depth_updates = {} + depth_updates = {} # type: Dict[str, int] for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) @@ -800,6 +812,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, @@ -1095,6 +1108,10 @@ class PersistEventsStore: def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ + + def str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) else None + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", @@ -1105,8 +1122,8 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), + "display_name": str_or_none(event.content.get("displayname")), + "avatar_url": str_or_none(event.content.get("avatar_url")), } for event in events ], @@ -1436,7 +1453,7 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index e53c6373a8..5e4af2eb51 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7a73cc3d8..f95679ebc4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import division - import itertools import logging import threading @@ -42,7 +40,8 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -76,29 +75,60 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsWorkerStore, self).__init__(database, db_conn, hs) - - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + super().__init__(database, db_conn, hs) + + if isinstance(database.engine, PostgresEngine): + # If we're using Postgres than we can use `MultiWriterIdGenerator` + # regardless of whether this process writes to the streams or not. + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="events", + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="backfill", + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_backfill_stream_seq", + positive=False, + writers=hs.config.worker.writers.events, ) else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 1cbf31f52d..f724f45494 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1275,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with await self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 86557d5512..cc538c5c10 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -17,12 +17,14 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool +BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( + "media_repository_drop_index_wo_method" +) + class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__( - database, db_conn, hs - ) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -32,12 +34,65 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): where_clause="url_cache IS NOT NULL", ) + # The following the updates add the method to the unique constraint of + # the thumbnail databases. That fixes an issue, where thumbnails of the + # same resolution, but different methods could overwrite one another. + # This can happen with custom thumbnail configs or with dynamic thumbnailing. + self.db_pool.updates.register_background_index_update( + update_name="local_media_repository_thumbnails_method_idx", + index_name="local_media_repository_thumbn_media_id_width_height_method_key", + table="local_media_repository_thumbnails", + columns=[ + "media_id", + "thumbnail_width", + "thumbnail_height", + "thumbnail_type", + "thumbnail_method", + ], + unique=True, + ) + + self.db_pool.updates.register_background_index_update( + update_name="remote_media_repository_thumbnails_method_idx", + index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key", + table="remote_media_cache_thumbnails", + columns=[ + "media_origin", + "media_id", + "thumbnail_width", + "thumbnail_height", + "thumbnail_type", + "thumbnail_method", + ], + unique=True, + ) + + self.db_pool.updates.register_background_update_handler( + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD, + self._drop_media_index_without_method, + ) + + async def _drop_media_index_without_method(self, progress, batch_size): + def f(txn): + txn.execute( + "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key" + ) + txn.execute( + "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key" + ) + + await self.db_pool.runInteraction("drop_media_indices_without_method", f) + await self.db_pool.updates._end_background_update( + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD + ) + return 1 + class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" def __init__(self, database: DatabasePool, db_conn, hs): - super(MediaRepositoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 686052bd83..92099f95ce 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,10 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import typing -from collections import Counter -from synapse.metrics import BucketCollector +from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -23,6 +21,26 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +# Collect metrics on the number of forward extremities that exist. +_extremities_collecter = GaugeBucketCollector( + "synapse_forward_extremities", + "Number of rooms on the server with the given number of forward extremities" + " or fewer", + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500], +) + +# we also expose metrics on the "number of excess extremity events", which is +# (E-1)*N, where E is the number of extremities and N is the number of state +# events in the room. This is an approximation to the number of state events +# we could remove from state resolution by reducing the graph to a single +# forward extremity. +_excess_state_events_collecter = GaugeBucketCollector( + "synapse_excess_extremity_events", + "Number of rooms on the server with the given number of excess extremity " + "events, or fewer", + buckets=[0] + [1 << n for n in range(12)], +) + class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """Functions to pull various metrics from the DB, for e.g. phone home @@ -32,18 +50,6 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = ( - Counter() - ) # type: typing.Counter[int] - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - # Read the extrems every 60 minutes def read_forward_extremities(): # run as a background process to make sure that the database transactions @@ -58,14 +64,25 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def fetch(txn): txn.execute( """ - select count(*) c from event_forward_extremities - group by room_id + SELECT t1.c, t2.c + FROM ( + SELECT room_id, COUNT(*) c FROM event_forward_extremities + GROUP BY room_id + ) t1 LEFT JOIN ( + SELECT room_id, COUNT(*) c FROM current_state_events + GROUP BY room_id + ) t2 ON t1.room_id = t2.room_id """ ) return txn.fetchall() res = await self.db_pool.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = Counter([x[0] for x in res]) + + _extremities_collecter.update_data(x[0] for x in res) + + _excess_state_events_collecter.update_data( + (x[0] - 1) * x[1] for x in res if x[1] + ) async def count_daily_messages(self): """ diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 1d793d3deb..e93aad33cd 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -28,7 +28,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -41,7 +41,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + # Exclude app service users + sql = """ + SELECT COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users + ON monthly_active_users.user_id=users.name + WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); + """ txn.execute(sql) (count,) = txn.fetchone() return count @@ -120,7 +127,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_stats_only = hs.config.mau_stats_only diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c9f655dfb7..dbbb99cb95 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = await self._presence_id_gen.get_next_mult( + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index ea833829ae..ecfc6717b3 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): The set of state groups that are referenced by deleted events. """ + parsed_token = await RoomStreamToken.parse(self, token) + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, - token, + parsed_token, delete_local_events, ) - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - + def _purge_history_txn(self, txn, room_id, token, delete_local_events): # Tables that should be pruned: # event_auth # event_backward_extremities @@ -69,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # room_depth # state_groups # state_groups_state + # destination_rooms # we will build a temporary table listing the events so that we don't # have to keep shovelling the list back and forth across the @@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # and finally, the tables with an index on room_id (or no useful index) for table in ( "current_state_events", + "destination_rooms", "event_backward_extremities", "event_forward_extremities", "event_json", diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 0de802a86b..711d5aa23d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -13,11 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import abc import logging from typing import List, Tuple, Union +from synapse.api.errors import NotFoundError, StoreError from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder @@ -60,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False): return rules +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, ReceiptsWorkerStore, @@ -67,17 +70,14 @@ class PushRulesWorkerStore( RoomMemberWorkerStore, EventsWorkerStore, SQLBaseStore, + metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: self._push_rules_stream_id_gen = StreamIdGenerator( @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -540,6 +540,25 @@ class PushRuleStore(PushRulesWorkerStore): }, ) + # ensure we have a push_rules_enable row + # enabledness defaults to true + if isinstance(self.database_engine, PostgresEngine): + sql = """ + INSERT INTO push_rules_enable (id, user_name, rule_id, enabled) + VALUES (?, ?, ?, ?) + ON CONFLICT DO NOTHING + """ + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled) + VALUES (?, ?, ?, ?) + """ + else: + raise RuntimeError("Unknown database engine") + + new_enable_id = self._push_rules_enable_id_gen.get_next() + txn.execute(sql, (new_enable_id, user_id, rule_id, 1)) + async def delete_push_rule(self, user_id: str, rule_id: str) -> None: """ Delete a push rule. Args specify the row to be deleted and can be @@ -552,6 +571,12 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + # we don't use simple_delete_one_txn because that would fail if the + # user did not have a push_rule_enable row. + self.db_pool.simple_delete_txn( + txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id} + ) + self.db_pool.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -560,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -570,10 +595,29 @@ class PushRuleStore(PushRulesWorkerStore): event_stream_ordering, ) - async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: - with await self._push_rules_stream_id_gen.get_next() as stream_id: - event_stream_ordering = self._stream_id_gen.get_current_token() + async def set_push_rule_enabled( + self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool + ) -> None: + """ + Sets the `enabled` state of a push rule. + Args: + user_id: the user ID of the user who wishes to enable/disable the rule + e.g. '@tina:example.org' + rule_id: the full rule ID of the rule to be enabled/disabled + e.g. 'global/override/.m.rule.roomnotif' + or 'global/override/myCustomRule' + enabled: True if the rule is to be enabled, False if it is to be + disabled + is_default_rule: True if and only if this is a server-default rule. + This skips the check for existence (as only user-created rules + are always stored in the database `push_rules` table). + + Raises: + NotFoundError if the rule does not exist. + """ + async with self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, @@ -582,12 +626,47 @@ class PushRuleStore(PushRulesWorkerStore): user_id, rule_id, enabled, + is_default_rule, ) def _set_push_rule_enabled_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + enabled, + is_default_rule, ): new_id = self._push_rules_enable_id_gen.get_next() + + if not is_default_rule: + # first check it exists; we need to lock for key share so that a + # transaction that deletes the push rule will conflict with this one. + # We also need a push_rule_enable row to exist for every push_rules + # row, otherwise it is possible to simultaneously delete a push rule + # (that has no _enable row) and enable it, resulting in a dangling + # _enable row. To solve this: we either need to use SERIALISABLE or + # ensure we always have a push_rule_enable row for every push_rule + # row. We chose the latter. + for_key_share = "FOR KEY SHARE" + if not isinstance(self.database_engine, PostgresEngine): + # For key share is not applicable/available on SQLite + for_key_share = "" + sql = ( + """ + SELECT 1 FROM push_rules + WHERE user_name = ? AND rule_id = ? + %s + """ + % for_key_share + ) + txn.execute(sql, (user_id, rule_id)) + if txn.fetchone() is None: + # needed to set NOT_FOUND code. + raise NotFoundError("Push rule does not exist.") + self.db_pool.simple_upsert_txn( txn, "push_rules_enable", @@ -606,8 +685,30 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_actions( - self, user_id, rule_id, actions, is_default_rule + self, + user_id: str, + rule_id: str, + actions: List[Union[dict, str]], + is_default_rule: bool, ) -> None: + """ + Sets the `actions` state of a push rule. + + Will throw NotFoundError if the rule does not exist; the Code for this + is NOT_FOUND. + + Args: + user_id: the user ID of the user who wishes to enable/disable the rule + e.g. '@tina:example.org' + rule_id: the full rule ID of the rule to be enabled/disabled + e.g. 'global/override/.m.rule.roomnotif' + or 'global/override/myCustomRule' + actions: A list of actions (each action being a dict or string), + e.g. ["notify", {"set_tweak": "highlight", "value": false}] + is_default_rule: True if and only if this is a server-default rule. + This skips the check for existence (as only user-created rules + are always stored in the database `push_rules` table). + """ actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): @@ -629,12 +730,19 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self.db_pool.simple_update_one_txn( - txn, - "push_rules", - {"user_name": user_id, "rule_id": rule_id}, - {"actions": actions_json}, - ) + try: + self.db_pool.simple_update_one_txn( + txn, + "push_rules", + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, + ) + except StoreError as serr: + if serr.code == 404: + # this sets the NOT_FOUND error Code + raise NotFoundError("Push rule does not exist") + else: + raise self._insert_push_rules_update_txn( txn, @@ -646,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c388468273..df8609b97b 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 4a0d5a320e..c79ddff680 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -31,17 +31,15 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class ReceiptsWorkerStore(SQLBaseStore): +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. +class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_max_receipt_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -388,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore): db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -526,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - with await self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 01f20c03c2..a83df7759d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -36,11 +36,14 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() + # Note: we don't check this sequence for consistency as we'd have to + # call `find_max_generated_user_id_localpart` each time, which is + # expensive if there are many entries. self._user_id_seq = build_sequence_generator( database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) @@ -116,6 +119,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_expiration_ts_for_user", ) + async def is_account_expired(self, user_id: str, current_ts: int) -> bool: + """ + Returns whether an user account is expired. + + Args: + user_id: The user's ID + current_ts: The current timestamp + + Returns: + Whether the user account has expired + """ + expiration_ts = await self.get_expiration_ts_for_user(user_id) + return expiration_ts is not None and current_ts >= expiration_ts + async def set_account_validity_for_user( self, user_id: str, @@ -379,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -387,7 +404,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", @@ -764,7 +781,7 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -892,7 +909,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 717df97301..3c7630857f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -69,7 +69,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -104,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore): curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, rooms.creator, state.encryption, state.is_federatable AS federatable, rooms.is_public AS public, state.join_rules, state.guest_access, - state.history_visibility, curr.current_state_events AS state_events + state.history_visibility, curr.current_state_events AS state_events, + state.avatar, state.topic FROM rooms LEFT JOIN room_stats_state state USING (room_id) LEFT JOIN room_stats_current curr USING (room_id) @@ -862,7 +863,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -1073,7 +1074,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -1136,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1203,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1283,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, @@ -1327,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: str = "b", + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + event_reports: json list of event reports + count: total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn(txn): + filters = [] + args = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.reason, + er.content, + events.sender, + room_aliases.room_alias, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN room_aliases + ON room_aliases.room_id = er.room_id + JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? + """.format( + where_clause=where_clause, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + event_reports = self.db_pool.cursor_to_dict(txn) + + if count > 0: + for row in event_reports: + try: + row["content"] = db_to_json(row["content"]) + row["event_json"] = db_to_json(row["event_json"]) + except Exception: + continue + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 91a8b43da3..86ffe2479e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set @@ -37,7 +36,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -55,7 +54,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # for rooms the server is participating in. if self._current_state_events_membership_up_to_date: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ else: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN room_memberships AS m USING (room_id, event_id) INNER JOIN events AS e USING (room_id, event_id) @@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + return frozenset( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + for room_id, instance, stream_id in txn + ) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] @@ -819,7 +823,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -973,7 +977,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql index 5e29c1da19..ccf287971c 100644 --- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql @@ -13,7 +13,7 @@ * limitations under the License. */ --- room_id and topoligical_ordering are denormalised from the events table in order to +-- room_id and topological_ordering are denormalised from the events table in order to -- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres new file mode 100644 index 0000000000..b64926e9c9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres @@ -0,0 +1,33 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This adds the method to the unique key constraint of the thumbnail databases. + * Otherwise you can't have a scaled and a cropped thumbnail with the same + * resolution, which happens quite often with dynamic thumbnailing. + * This is the postgres specific migration modifying the table with a background + * migration. + */ + +-- add new index that includes method to local media +INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_thumbnails_method_idx', '{}'); + +-- add new index that includes method to remote media +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); + +-- drop old index +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); + diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite new file mode 100644 index 0000000000..1d0c04b53a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite @@ -0,0 +1,44 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This adds the method to the unique key constraint of the thumbnail databases. + * Otherwise you can't have a scaled and a cropped thumbnail with the same + * resolution, which happens quite often with dynamic thumbnailing. + * This is a sqlite specific migration, since sqlite can't modify the unique + * constraint of a table without recreating it. + */ + +CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) ); + +INSERT INTO local_media_repository_thumbnails_new + SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length + FROM local_media_repository_thumbnails; + +DROP TABLE local_media_repository_thumbnails; + +ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails; + +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); + + + +CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) ); + +INSERT INTO remote_media_cache_thumbnails_new + SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id + FROM remote_media_cache_thumbnails; + +DROP TABLE remote_media_cache_thumbnails; + +ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails; diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql new file mode 100644 index 0000000000..847aebd85e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql @@ -0,0 +1,28 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules. + We ignore rules that are server-default rules because they are not defined + in the `push_rules` table. +**/ + +DELETE FROM push_rules_enable WHERE + rule_id NOT LIKE 'global/%/.m.rule.%' + AND NOT EXISTS ( + SELECT 1 FROM push_rules + WHERE push_rules.user_name = push_rules_enable.user_name + AND push_rules.rule_id = push_rules_enable.rule_id + ); diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql new file mode 100644 index 0000000000..98ff76d709 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres new file mode 100644 index 0000000000..c31f9af82a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -0,0 +1,28 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS events_stream_seq; + +SELECT setval('events_stream_seq', ( + SELECT COALESCE(MAX(stream_ordering), 1) FROM events +)); + +CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; + +-- If the server has never backfilled a room then doing `-MIN(...)` will give +-- a negative result, hence why we do `GREATEST(...)` +SELECT setval('events_backfill_stream_seq', ( + SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events +)); diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql new file mode 100644 index 0000000000..ebfbed7925 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql @@ -0,0 +1,42 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- This schema delta alters the schema to enable 'catching up' remote homeservers +-- after there has been a connectivity problem for any reason. + +-- This stores, for each (destination, room) pair, the stream_ordering of the +-- latest event for that destination. +CREATE TABLE IF NOT EXISTS destination_rooms ( + -- the destination in question. + destination TEXT NOT NULL REFERENCES destinations (destination), + -- the ID of the room in question + room_id TEXT NOT NULL REFERENCES rooms (room_id), + -- the stream_ordering of the event + stream_ordering BIGINT NOT NULL, + PRIMARY KEY (destination, room_id) + -- We don't declare a foreign key on stream_ordering here because that'd mean + -- we'd need to either maintain an index (expensive) or do a table scan of + -- destination_rooms whenever we delete an event (also potentially expensive). + -- In addition to that, a foreign key on stream_ordering would be redundant + -- as this row doesn't need to refer to a specific event; if the event gets + -- deleted then it doesn't affect the validity of the stream_ordering here. +); + +-- This index is needed to make it so that a deletion of a room (in the rooms +-- table) can be efficient, as otherwise a table scan would need to be performed +-- to check that no destination_rooms rows point to the room to be deleted. +-- Also: it makes it efficient to delete all the entries for a given room ID, +-- such as when purging a room. +CREATE INDEX IF NOT EXISTS destination_rooms_room_id + ON destination_rooms (room_id); diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql new file mode 100644 index 0000000000..55f5d0f732 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- This delta file fixes a regression introduced by 58/12room_stats.sql, removing the hacky +-- populate_stats_process_rooms_2 background job and restores the functionality under the +-- original name. +-- See https://github.com/matrix-org/synapse/issues/8238 for details + +DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms'; +UPDATE background_updates SET update_name = 'populate_stats_process_rooms' + WHERE update_name = 'populate_stats_process_rooms_2'; diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql new file mode 100644 index 0000000000..a67aa5e500 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql @@ -0,0 +1,21 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This column tracks the stream_ordering of the event that was most recently +-- successfully transmitted to the destination. +-- A value of NULL means that we have not sent an event successfully yet +-- (at least, not since the introduction of this column). +ALTER TABLE destinations + ADD COLUMN last_successful_stream_ordering BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql new file mode 100644 index 0000000000..985fd949a2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f01cf2fd02..e34fce6281 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 5c6168e301..3c1e33819b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: """Get the room_version of a given room @@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" def __init__(self, database: DatabasePool, db_conn, hs): - super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 55a250ef06..5beb302be3 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -61,7 +61,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(StatsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() @@ -74,9 +74,6 @@ class StatsStore(StateDeltasStore): "populate_stats_process_rooms", self._populate_stats_process_rooms ) self.db_pool.updates.register_background_update_handler( - "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 - ) - self.db_pool.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves @@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) async def _populate_stats_process_rooms(self, progress, batch_size): - """ - This was a background update which regenerated statistics for rooms. - - It has been replaced by StatsStore._populate_stats_process_rooms_2. This background - job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure - someone upgrading from <v1.0.0, this background task has been turned into a no-op - so that the potentially expensive task is not run twice. - - Further context: https://github.com/matrix-org/synapse/pull/7977 - """ - await self.db_pool.updates._end_background_update( - "populate_stats_process_rooms" - ) - return 1 - - async def _populate_stats_process_rooms_2(self, progress, batch_size): - """ - This is a background update which regenerates statistics for rooms. - - It replaces StatsStore._populate_stats_process_rooms. See its docstring for the - reasoning. - """ + """This is a background update which regenerates statistics for rooms.""" if not self.stats_enabled: await self.db_pool.updates._end_background_update( - "populate_stats_process_rooms_2" + "populate_stats_process_rooms" ) return 1 @@ -189,13 +165,13 @@ class StatsStore(StateDeltasStore): return [r for r, in txn] rooms_to_work_on = await self.db_pool.runInteraction( - "populate_stats_rooms_2_get_batch", _get_next_batch + "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: await self.db_pool.updates._end_background_update( - "populate_stats_process_rooms_2" + "populate_stats_process_rooms" ) return 1 @@ -204,9 +180,9 @@ class StatsStore(StateDeltasStore): progress["last_room_id"] = room_id await self.db_pool.runInteraction( - "_populate_stats_process_rooms_2", + "_populate_stats_process_rooms", self.db_pool.updates._background_update_progress_txn, - "populate_stats_process_rooms_2", + "populate_stats_process_rooms", progress, ) @@ -234,6 +210,7 @@ class StatsStore(StateDeltasStore): * topic * avatar * canonical_alias + * guest_access A is_federatable key can also be included with a boolean value. @@ -258,6 +235,7 @@ class StatsStore(StateDeltasStore): "topic", "avatar", "canonical_alias", + "guest_access", ): field = fields.get(col, sentinel) if field is not sentinel and (not isinstance(field, str) or "\0" in field): diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index db20a3db30..37249f1e3f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -35,7 +35,6 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. """ - import abc import logging from collections import namedtuple @@ -54,7 +53,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import Collection, RoomStreamToken +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -79,8 +78,8 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], - from_token: Optional[Tuple[int, int]], - to_token: Optional[Tuple[int, int]], + from_token: Optional[Tuple[Optional[int], int]], + to_token: Optional[Tuple[Optional[int], int]], engine: BaseDatabaseEngine, ) -> str: """Creates an SQL expression to bound the columns by the pagination @@ -259,16 +258,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: return " AND ".join(clauses), args -class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_room_max_stream_ordering` and `get_room_min_stream_ordering` which can be called in the initializer. """ - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - super(StreamWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() self._send_federation = hs.should_send_federation() @@ -307,14 +304,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() + def get_room_max_token(self) -> RoomStreamToken: + return RoomStreamToken(None, self.get_room_max_stream_ordering()) + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], - from_key: str, - to_key: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, limit: int = 0, order: str = "DESC", - ) -> Dict[str, Tuple[List[EventBase], str]]: + ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]: """Get new room events in stream ordering since `from_key`. Args: @@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): - list of recent events in the room - stream ordering key for the start of the chunk of events returned. """ - from_id = RoomStreamToken.parse_stream_token(from_key).stream - - room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) + room_ids = self._events_stream_cache.get_entities_changed( + room_ids, from_key.stream + ) if not room_ids: return {} @@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return results def get_rooms_that_changed( - self, room_ids: Collection[str], from_key: str + self, room_ids: Collection[str], from_key: RoomStreamToken ) -> Set[str]: """Given a list of rooms and a token, return rooms where there may have been changes. - - Args: - room_ids - from_key: The room_key portion of a StreamToken """ - from_id = RoomStreamToken.parse_stream_token(from_key).stream + from_id = from_key.stream return { room_id for room_id in room_ids @@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_room_events_stream_for_room( self, room_id: str, - from_key: str, - to_key: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, limit: int = 0, order: str = "DESC", - ) -> Tuple[List[EventBase], str]: + ) -> Tuple[List[EventBase], RoomStreamToken]: """Get new room events in stream ordering since `from_key`. Args: @@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if from_key == to_key: return [], from_key - from_id = RoomStreamToken.parse_stream_token(from_key).stream - to_id = RoomStreamToken.parse_stream_token(to_key).stream + from_id = from_key.stream + to_id = to_key.stream has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) @@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ret.reverse() if rows: - key = "s%d" % min(r.stream_ordering for r in rows) + key = RoomStreamToken(None, min(r.stream_ordering for r in rows)) else: # Assume we didn't get anything because there was nothing to # get. @@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key async def get_membership_changes_for_user( - self, user_id: str, from_key: str, to_key: str + self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = RoomStreamToken.parse_stream_token(from_key).stream - to_id = RoomStreamToken.parse_stream_token(to_key).stream + from_id = from_key.stream + to_id = to_key.stream if from_key == to_key: return [] @@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret async def get_recent_events_for_room( - self, room_id: str, limit: int, end_token: str - ) -> Tuple[List[EventBase], str]: + self, room_id: str, limit: int, end_token: RoomStreamToken + ) -> Tuple[List[EventBase], RoomStreamToken]: """Get the most recent events in the room in topological ordering. Args: @@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) async def get_recent_event_ids_for_room( - self, room_id: str, limit: int, end_token: str - ) -> Tuple[List[_EventDictReturn], str]: + self, room_id: str, limit: int, end_token: RoomStreamToken + ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: """Get the most recent events in the room in topological ordering. Args: @@ -535,8 +531,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if limit == 0: return [], end_token - end_token = RoomStreamToken.parse(end_token) - rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, @@ -619,26 +613,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): allow_none=allow_none, ) - async def get_stream_token_for_event(self, event_id: str) -> str: - """The stream token for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A "s%d" stream token. + async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: + """Get the persisted position for an event """ - stream_id = await self.get_stream_id_for_event(event_id) - return "s%d" % (stream_id,) + row = await self.db_pool.simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "instance_name"), + desc="get_position_for_event", + ) - async def get_topological_token_for_event(self, event_id: str) -> str: + return PersistedEventPosition( + row["instance_name"] or "master", row["stream_ordering"] + ) + + async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event Args: event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A "t%d-%d" topological token. + A `RoomStreamToken` topological token. """ row = await self.db_pool.simple_select_one( table="events", @@ -646,7 +642,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + return RoomStreamToken(row["topological_ordering"], row["stream_ordering"]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -695,8 +691,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): else: topo = None internal = event.internal_metadata - internal.before = str(RoomStreamToken(topo, stream - 1)) - internal.after = str(RoomStreamToken(topo, stream)) + internal.before = RoomStreamToken(topo, stream - 1) + internal.after = RoomStreamToken(topo, stream) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( @@ -951,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): direction: str = "b", limit: int = -1, event_filter: Optional[Filter] = None, - ) -> Tuple[List[_EventDictReturn], str]: + ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: """Returns list of events before or after a given token. Args: @@ -986,8 +982,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token, - to_token=to_token, + from_token=from_token.as_tuple(), + to_token=to_token.as_tuple() if to_token else None, engine=self.database_engine, ) @@ -1051,17 +1047,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token - return rows, str(next_token) + return rows, next_token async def paginate_room_events( self, room_id: str, - from_key: str, - to_key: Optional[str] = None, + from_key: RoomStreamToken, + to_key: Optional[RoomStreamToken] = None, direction: str = "b", limit: int = -1, event_filter: Optional[Filter] = None, - ) -> Tuple[List[EventBase], str]: + ) -> Tuple[List[EventBase], RoomStreamToken]: """Returns list of events before or after a given token. Args: @@ -1080,10 +1076,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and `to_key`). """ - from_key = RoomStreamToken.parse(from_key) - if to_key: - to_key = RoomStreamToken.parse(to_key) - rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 96ffe26cc9..9f120d3cb6 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 5b31aab700..97aed1500e 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,13 +15,14 @@ import logging from collections import namedtuple -from typing import Optional, Tuple +from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache @@ -47,7 +48,7 @@ class TransactionStore(SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(TransactionStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) @@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore): allow_none=True, ) - if result and result["retry_last_ts"] > 0: + # check we have a row and retry_last_ts is not null or zero + # (retry_last_ts can't be negative) + if result and result["retry_last_ts"]: return result else: return None @@ -215,6 +218,7 @@ class TransactionStore(SQLBaseStore): retry_interval = EXCLUDED.retry_interval WHERE EXCLUDED.retry_interval = 0 + OR destinations.retry_interval IS NULL OR destinations.retry_interval < EXCLUDED.retry_interval """ @@ -246,7 +250,11 @@ class TransactionStore(SQLBaseStore): "retry_interval": retry_interval, }, ) - elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + elif ( + retry_interval == 0 + or prev_row["retry_interval"] is None + or prev_row["retry_interval"] < retry_interval + ): self.db_pool.simple_update_one_txn( txn, "destinations", @@ -273,3 +281,196 @@ class TransactionStore(SQLBaseStore): await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn ) + + async def store_destination_rooms_entries( + self, destinations: Iterable[str], room_id: str, stream_ordering: int, + ) -> None: + """ + Updates or creates `destination_rooms` entries in batch for a single event. + + Args: + destinations: list of destinations + room_id: the room_id of the event + stream_ordering: the stream_ordering of the event + """ + + return await self.db_pool.runInteraction( + "store_destination_rooms_entries", + self._store_destination_rooms_entries_txn, + destinations, + room_id, + stream_ordering, + ) + + def _store_destination_rooms_entries_txn( + self, + txn: LoggingTransaction, + destinations: Iterable[str], + room_id: str, + stream_ordering: int, + ) -> None: + + # ensure we have a `destinations` row for this destination, as there is + # a foreign key constraint. + if isinstance(self.database_engine, PostgresEngine): + q = """ + INSERT INTO destinations (destination) + VALUES (?) + ON CONFLICT DO NOTHING; + """ + elif isinstance(self.database_engine, Sqlite3Engine): + q = """ + INSERT OR IGNORE INTO destinations (destination) + VALUES (?); + """ + else: + raise RuntimeError("Unknown database engine") + + txn.execute_batch(q, ((destination,) for destination in destinations)) + + rows = [(destination, room_id) for destination in destinations] + + self.db_pool.simple_upsert_many_txn( + txn, + "destination_rooms", + ["destination", "room_id"], + rows, + ["stream_ordering"], + [(stream_ordering,)] * len(rows), + ) + + async def get_destination_last_successful_stream_ordering( + self, destination: str + ) -> Optional[int]: + """ + Gets the stream ordering of the PDU most-recently successfully sent + to the specified destination, or None if this information has not been + tracked yet. + + Args: + destination: the destination to query + """ + return await self.db_pool.simple_select_one_onecol( + "destinations", + {"destination": destination}, + "last_successful_stream_ordering", + allow_none=True, + desc="get_last_successful_stream_ordering", + ) + + async def set_destination_last_successful_stream_ordering( + self, destination: str, last_successful_stream_ordering: int + ) -> None: + """ + Marks that we have successfully sent the PDUs up to and including the + one specified. + + Args: + destination: the destination we have successfully sent to + last_successful_stream_ordering: the stream_ordering of the most + recent successfully-sent PDU + """ + return await self.db_pool.simple_upsert( + "destinations", + keyvalues={"destination": destination}, + values={"last_successful_stream_ordering": last_successful_stream_ordering}, + desc="set_last_successful_stream_ordering", + ) + + async def get_catch_up_room_event_ids( + self, destination: str, last_successful_stream_ordering: int, + ) -> List[str]: + """ + Returns at most 50 event IDs and their corresponding stream_orderings + that correspond to the oldest events that have not yet been sent to + the destination. + + Args: + destination: the destination in question + last_successful_stream_ordering: the stream_ordering of the + most-recently successfully-transmitted event to the destination + + Returns: + list of event_ids + """ + return await self.db_pool.runInteraction( + "get_catch_up_room_event_ids", + self._get_catch_up_room_event_ids_txn, + destination, + last_successful_stream_ordering, + ) + + @staticmethod + def _get_catch_up_room_event_ids_txn( + txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int, + ) -> List[str]: + q = """ + SELECT event_id FROM destination_rooms + JOIN events USING (stream_ordering) + WHERE destination = ? + AND stream_ordering > ? + ORDER BY stream_ordering + LIMIT 50 + """ + txn.execute( + q, (destination, last_successful_stream_ordering), + ) + event_ids = [row[0] for row in txn] + return event_ids + + async def get_catch_up_outstanding_destinations( + self, after_destination: Optional[str] + ) -> List[str]: + """ + Gets at most 25 destinations which have outstanding PDUs to be caught up, + and are not being backed off from + Args: + after_destination: + If provided, all destinations must be lexicographically greater + than this one. + + Returns: + list of up to 25 destinations with outstanding catch-up. + These are the lexicographically first destinations which are + lexicographically greater than after_destination (if provided). + """ + time = self.hs.get_clock().time_msec() + + return await self.db_pool.runInteraction( + "get_catch_up_outstanding_destinations", + self._get_catch_up_outstanding_destinations_txn, + time, + after_destination, + ) + + @staticmethod + def _get_catch_up_outstanding_destinations_txn( + txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str] + ) -> List[str]: + q = """ + SELECT destination FROM destinations + WHERE destination IN ( + SELECT destination FROM destination_rooms + WHERE destination_rooms.stream_ordering > + destinations.last_successful_stream_ordering + ) + AND destination > ? + AND ( + retry_last_ts IS NULL OR + retry_last_ts + retry_interval < ? + ) + ORDER BY destination + LIMIT 25 + """ + txn.execute( + q, + ( + # everything is lexicographically greater than "" so this gives + # us the first batch of up to 25. + after_destination or "", + now_time_ms, + ), + ) + + destinations = [row[0] for row in txn] + return destinations diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index b89668d561..3b9211a6d2 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -23,7 +23,7 @@ from synapse.types import JsonDict from synapse.util import json_encoder, stringutils -@attr.s +@attr.s(slots=True) class UIAuthSessionData: session_id = attr.ib(type=str) # The dictionary from the client root level, not the 'auth' key. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f2f9a5799a..5a390ff2f6 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 2f7c95fc74..f9575b1f1f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore): return # They are there, delete them. - self.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "erased_users", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 139085b672..acb24e33af 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" def __init__(self, database: DatabasePool, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e924f1ca3b..0e31cc811a 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_seq_gen = build_sequence_generator( self.database_engine, get_max_state_group_txn, "state_group_id_seq" ) + self._state_group_seq_gen.check_consistency( + db_conn, table="state_groups", id_column="id" + ) @cached(max_entries=10000, iterable=True) async def get_state_group_delta(self, state_group): @@ -205,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 908cbc79e3..d6d632dc10 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -97,3 +97,20 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): """Gets a string giving the server version. For example: '3.22.0' """ ... + + @abc.abstractmethod + def in_transaction(self, conn: Connection) -> bool: + """Whether the connection is currently in a transaction. + """ + ... + + @abc.abstractmethod + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + """Attempt to set the connections autocommit mode. + + When True queries are run outside of transactions. + + Note: This has no effect on SQLite3, so callers still need to + commit/rollback the connections. + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index ff39281f85..7719ac32f7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -15,7 +15,8 @@ import logging -from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.types import Connection logger = logging.getLogger(__name__) @@ -119,6 +120,7 @@ class PostgresEngine(BaseDatabaseEngine): cursor.execute("SET synchronous_commit TO OFF") cursor.close() + db_conn.commit() @property def can_native_upsert(self): @@ -171,3 +173,9 @@ class PostgresEngine(BaseDatabaseEngine): return "%i.%i" % (numver / 10000, numver % 10000) else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) + + def in_transaction(self, conn: Connection) -> bool: + return conn.status != self.module.extensions.STATUS_READY # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + return conn.set_session(autocommit=autocommit) # type: ignore diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 8a0f8c89d1..5db0f0b520 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -17,6 +17,7 @@ import threading import typing from synapse.storage.engines import BaseDatabaseEngine +from synapse.storage.types import Connection if typing.TYPE_CHECKING: import sqlite3 # noqa: F401 @@ -86,6 +87,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): db_conn.create_function("rank", 1, _rank) db_conn.execute("PRAGMA foreign_keys = ON;") + db_conn.commit() def is_deadlock(self, error): return False @@ -105,6 +107,14 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): """ return "%i.%i.%i" % self.module.sqlite_version_info + def in_transaction(self, conn: Connection) -> bool: + return conn.in_transaction # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + # Twisted doesn't let us set attributes on the connections, so we can't + # set the connection to autocommit mode. + pass + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index dbaeef91dd..72939f3984 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -18,7 +18,7 @@ import itertools import logging from collections import deque, namedtuple -from typing import Iterable, List, Optional, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter, Histogram @@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import StateMap +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -185,18 +185,21 @@ class EventsPersistenceStorage: # store for now. self.main_store = stores.main self.state_store = stores.state + + assert stores.persist_events self.persist_events_store = stores.persist_events self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() async def persist_events( self, - events_and_contexts: List[Tuple[EventBase, EventContext]], + events_and_contexts: Iterable[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> int: + ) -> RoomStreamToken: """ Write events to the database Args: @@ -208,7 +211,7 @@ class EventsPersistenceStorage: Returns: the stream ordering of the latest persisted event """ - partitioned = {} + partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) @@ -226,11 +229,11 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return self.main_store.get_current_events_token() + return self.main_store.get_room_max_token() async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[int, int]: + ) -> Tuple[PersistedEventPosition, RoomStreamToken]: """ Returns: The stream ordering of `event`, and the stream ordering of the @@ -244,8 +247,10 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) - max_persisted_id = self.main_store.get_current_events_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) + event_stream_id = event.internal_metadata.stream_ordering + + pos = PersistedEventPosition(self._instance_name, event_stream_id) + return pos, self.main_store.get_room_max_token() def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): @@ -305,7 +310,9 @@ class EventsPersistenceStorage: # Work out the new "current state" for each room. # We do this by working out what the new extremities are and then # calculating the state from that. - events_by_room = {} + events_by_room = ( + {} + ) # type: Dict[str, List[Tuple[EventBase, EventContext]]] for event, context in chunk: events_by_room.setdefault(event.room_id, []).append( (event, context) @@ -436,7 +443,7 @@ class EventsPersistenceStorage: self, room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], - latest_event_ids: List[str], + latest_event_ids: Collection[str], ): """Calculates the new forward extremities for a room given events to persist. @@ -470,7 +477,7 @@ class EventsPersistenceStorage: # Remove any events which are prev_events of any existing events. existing_prevs = await self.persist_events_store._get_events_which_are_prevs( result - ) + ) # type: Collection[str] result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index ee60e2a718..4957e77f4c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -19,12 +19,15 @@ import logging import os import re from collections import Counter -from typing import TextIO +from typing import Optional, TextIO import attr +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.types import Connection, Cursor +from synapse.types import Collection logger = logging.getLogger(__name__) @@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = ( ) -def prepare_database(db_conn, database_engine, config, databases=["main", "state"]): +def prepare_database( + db_conn: Connection, + database_engine: BaseDatabaseEngine, + config: Optional[HomeServerConfig], + databases: Collection[str] = ["main", "state"], +): """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state Args: db_conn: database_engine: - config (synapse.config.homeserver.HomeServerConfig|None): + config : application config, or None if we are connecting to an existing database which we expect to be configured already - databases (list[str]): The name of the databases that will be used + databases: The name of the databases that will be used with this physical database. Defaults to all databases. """ try: cur = db_conn.cursor() + # sqlite does not automatically start transactions for DDL / SELECT statements, + # so we start one before running anything. This ensures that any upgrades + # are either applied completely, or not at all. + # + # (psycopg2 automatically starts a transaction as soon as we run any statements + # at all, so this is redundant but harmless there.) + cur.execute("BEGIN TRANSACTION") + logger.info("%r: Checking existing schema version", databases) version_info = _get_or_create_schema_state(cur, database_engine) @@ -622,7 +638,7 @@ def _get_or_create_schema_state(txn, database_engine): return None -@attr.s() +@attr.s(slots=True) class _DirectoryListing: """Helper class to store schema file name and the absolute path to it. diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index d30e3f11e7..cec96ad6a7 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -22,7 +22,7 @@ from synapse.api.errors import SynapseError logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True) class PaginationChunk: """Returned by relation pagination APIs. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8c4a83a840..f152f63321 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -25,7 +25,7 @@ RoomsForUser = namedtuple( ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") + "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 8f68d968f0..08a69f2f96 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,7 +20,7 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -349,7 +349,7 @@ class StateGroupStorage: async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: @@ -532,7 +532,7 @@ class StateGroupStorage: def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Awaitable[Dict[int, StateMap[str]]]: + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b7eb4f8ac9..eccd2d5b7b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -12,17 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.types import Cursor from synapse.storage.util.sequence import PostgresSequenceGenerator logger = logging.getLogger(__name__) @@ -86,7 +88,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +103,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +115,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +123,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +142,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +151,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -184,12 +186,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +204,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,14 +226,16 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # Set of local IDs that we've processed that are larger than the current + # position, due to there being smaller unpersisted IDs. + self._finished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been # persisted. This is done by a) looking at the min across all instances # and b) noting that if we have seen a run of persisted positions @@ -236,37 +248,113 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # We check that the table and sequence haven't diverged. + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) + + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - + ): cur = db_conn.cursor() - cur.execute(sql) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + # Load the current positions of all writers for the stream. + if self._writers: + # We delete any stale entries in the positions table. This is + # important if we add back a writer after a long time; we want to + # consider that a "new" writer, rather than using the old stale + # entry here. + sql = """ + DELETE FROM stream_positions + WHERE + stream_name = ? + AND instance_name != ALL(?) + """ + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (self._stream_name, self._writers)) + + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? + """ + sql = self._db.engine.convert_param_style(sql) + + cur.execute(sql, (self._stream_name,)) + + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } - cur.close() + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + # We add a GREATEST here to ensure that the result is always + # positive. (This can be a problem for e.g. backfill streams where + # the server has never backfilled). + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (min_stream_id * self._return_factor,)) + + self._persisted_upto_position = min_stream_id - return current_positions + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) + + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -274,59 +362,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. - with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - - self._unfinished_ids.add(next_id) - - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) - return manager() + return _MultiWriterCtxManager(self) - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -344,21 +396,63 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the - current poistion if possible. + current position if possible. """ with self._lock: self._unfinished_ids.discard(next_id) + self._finished_ids.add(next_id) + + new_cur = None + + if self._unfinished_ids: + # If there are unfinished IDs then the new position will be the + # largest finished ID less than the minimum unfinished ID. + + finished = set() + + min_unfinshed = min(self._unfinished_ids) + for s in self._finished_ids: + if s < min_unfinshed: + if new_cur is None or new_cur < s: + new_cur = s + else: + finished.add(s) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids = finished + else: + # There are no unfinished IDs so the new position is simply the + # largest finished one. + new_cur = max(self._finished_ids) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids.clear() - # Figure out if its safe to advance the position by checking there - # aren't any lower allocated IDs that are yet to finish. - if all(c > next_id for c in self._unfinished_ids): + if new_cur: curr = self._current_positions.get(self._instance_name, 0) - self._current_positions[self._instance_name] = max(curr, next_id) + self._current_positions[self._instance_name] = max(curr, new_cur) self._add_persisted_position(next_id) @@ -367,19 +461,28 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. """ + # If we don't have an entry for the given instance name, we assume it's a + # new writer. + # + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get( + instance_name, self._persisted_upto_position + ) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. + + Note that this won't necessarily include all configured writers if some + writers haven't written anything yet. """ with self._lock: @@ -428,7 +531,7 @@ class MultiWriterIdGenerator: # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). - min_curr = min(self._current_positions.values()) + min_curr = min(self._current_positions.values(), default=0) self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are @@ -449,3 +552,97 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + def _update_stream_positions_table_txn(self, txn: Cursor): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) + ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + # It's safe to run this in autocommit mode as fetching values from a + # sequence ignores transaction semantics anyway. + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + db_autocommit=True, + ) + + with self.id_gen._lock: + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + # + # We do this in autocommit mode as a) the upsert works correctly outside + # transactions and b) reduces the amount of time the rows are locked + # for. If we don't do this then we'll often hit serialization errors due + # to the fact we default to REPEATABLE READ isolation levels. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + db_autocommit=True, + ) + + return False diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index ffc1894748..2dd95e2709 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -13,11 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import threading from typing import Callable, List, Optional -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.engines import ( + BaseDatabaseEngine, + IncorrectDatabaseSetup, + PostgresEngine, +) +from synapse.storage.types import Connection, Cursor + +logger = logging.getLogger(__name__) + + +_INCONSISTENT_SEQUENCE_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated +table '%(table)s'. This can happen if Synapse has been downgraded and +then upgraded again, or due to a bad migration. + +To fix this error, shut down Synapse (including any and all workers) +and run the following SQL: + + SELECT setval('%(seq)s', ( + %(max_id_sql)s + )); + +See docs/postgres.md for more information. +""" class SequenceGenerator(metaclass=abc.ABCMeta): @@ -28,6 +51,19 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """Gets the next ID in the sequence""" ... + @abc.abstractmethod + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + """Should be called during start up to test that the current value of + the sequence is greater than or equal to the maximum ID in the table. + + This is to handle various cases where the sequence value can get out + of sync with the table, e.g. if Synapse gets rolled back to a previous + version and the rolled forwards again. + """ + ... + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -45,6 +81,50 @@ class PostgresSequenceGenerator(SequenceGenerator): ) return [i for (i,) in txn] + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + txn = db_conn.cursor() + + # First we get the current max ID from the table. + table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { + "id": id_column, + "table": table, + "agg": "MAX" if positive else "-MIN", + } + + txn.execute(table_sql) + row = txn.fetchone() + if not row: + # Table is empty, so nothing to do. + txn.close() + return + + # Now we fetch the current value from the sequence and compare with the + # above. + max_stream_id = row[0] + txn.execute( + "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} + ) + last_value, is_called = txn.fetchone() + txn.close() + + # If `is_called` is False then `last_value` is actually the value that + # will be generated next, so we decrement to get the true "last value". + if not is_called: + last_value -= 1 + + if max_stream_id > last_value: + logger.warning( + "Postgres sequence %s is behind table %s: %d < %d", + last_value, + max_stream_id, + ) + raise IncorrectDatabaseSetup( + _INCONSISTENT_SEQUENCE_ERROR + % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -81,6 +161,12 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + # There is nothing to do for in memory sequences + pass + def build_sequence_generator( database_engine: BaseDatabaseEngine, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index d97dc4d101..fdda21d165 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -12,11 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import Optional + +import attr from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.storage.databases.main import DataStore from synapse.types import StreamToken logger = logging.getLogger(__name__) @@ -25,38 +29,23 @@ logger = logging.getLogger(__name__) MAX_LIMIT = 1000 -class SourcePaginationConfig: - - """A configuration object which stores pagination parameters for a - specific event source.""" - - def __init__(self, from_key=None, to_key=None, direction="f", limit=None): - self.from_key = from_key - self.to_key = to_key - self.direction = "f" if direction == "f" else "b" - self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None - - def __repr__(self): - return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % ( - self.from_key, - self.to_key, - self.direction, - self.limit, - ) - - +@attr.s(slots=True) class PaginationConfig: - """A configuration object which stores pagination parameters.""" - def __init__(self, from_token=None, to_token=None, direction="f", limit=None): - self.from_token = from_token - self.to_token = to_token - self.direction = "f" if direction == "f" else "b" - self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None + from_token = attr.ib(type=Optional[StreamToken]) + to_token = attr.ib(type=Optional[StreamToken]) + direction = attr.ib(type=str) + limit = attr.ib(type=Optional[int]) @classmethod - def from_request(cls, request, raise_invalid_params=True, default_limit=None): + async def from_request( + cls, + store: "DataStore", + request: SynapseRequest, + raise_invalid_params: bool = True, + default_limit: Optional[int] = None, + ) -> "PaginationConfig": direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) from_tok = parse_string(request, "from") @@ -66,20 +55,23 @@ class PaginationConfig: if from_tok == "END": from_tok = None # For backwards compat. elif from_tok: - from_tok = StreamToken.from_string(from_tok) + from_tok = await StreamToken.from_string(store, from_tok) except Exception: raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: - to_tok = StreamToken.from_string(to_tok) + to_tok = await StreamToken.from_string(store, to_tok) except Exception: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) - if limit and limit < 0: - raise SynapseError(400, "Limit must be 0 or above") + if limit: + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") + + limit = min(int(limit), MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) @@ -87,20 +79,10 @@ class PaginationConfig: logger.exception("Failed to create pagination config") raise SynapseError(400, "Invalid request.") - def __repr__(self): + def __repr__(self) -> str: return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % ( self.from_token, self.to_token, self.direction, self.limit, ) - - def get_source_config(self, source_name): - keyname = "%s_key" % source_name - - return SourcePaginationConfig( - from_key=getattr(self.from_token, keyname), - to_key=getattr(self.to_token, keyname) if self.to_token else None, - direction=self.direction, - limit=self.limit, - ) diff --git a/synapse/types.py b/synapse/types.py index f7de48f148..bd271f9f16 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,17 @@ import re import string import sys from collections import namedtuple -from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, +) import attr from signedjson.key import decode_verify_key_bytes @@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): from typing import Collection @@ -165,7 +178,9 @@ def get_localpart_from_id(string): DS = TypeVar("DS", bound="DomainSpecificString") -class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): +class DomainSpecificString( + namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta +): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -175,8 +190,6 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 'domain' : The domain part of the name """ - __metaclass__ = abc.ABCMeta - SIGIL = abc.abstractproperty() # type: str # type: ignore # Deny iteration because it will bite you if you try to create a singleton @@ -362,86 +375,8 @@ def map_username_to_mxid_localpart(username, case_sensitive=False): return username.decode("ascii") -class StreamToken( - namedtuple( - "Token", - ( - "room_key", - "presence_key", - "typing_key", - "receipt_key", - "account_data_key", - "push_rules_key", - "to_device_key", - "device_list_key", - "groups_key", - ), - ) -): - _SEPARATOR = "_" - START = None # type: StreamToken - - @classmethod - def from_string(cls, string): - try: - keys = string.split(cls._SEPARATOR) - while len(keys) < len(cls._fields): - # i.e. old token from before receipt_key - keys.append("0") - return cls(*keys) - except Exception: - raise SynapseError(400, "Invalid Token") - - def to_string(self): - return self._SEPARATOR.join([str(k) for k in self]) - - @property - def room_stream_id(self): - # TODO(markjh): Awful hack to work around hacks in the presence tests - # which assume that the keys are integers. - if type(self.room_key) is int: - return self.room_key - else: - return int(self.room_key[1:].split("-")[-1]) - - def is_after(self, other): - """Does this token contain events that the other doesn't?""" - return ( - (other.room_stream_id < self.room_stream_id) - or (int(other.presence_key) < int(self.presence_key)) - or (int(other.typing_key) < int(self.typing_key)) - or (int(other.receipt_key) < int(self.receipt_key)) - or (int(other.account_data_key) < int(self.account_data_key)) - or (int(other.push_rules_key) < int(self.push_rules_key)) - or (int(other.to_device_key) < int(self.to_device_key)) - or (int(other.device_list_key) < int(self.device_list_key)) - or (int(other.groups_key) < int(self.groups_key)) - ) - - def copy_and_advance(self, key, new_value): - """Advance the given key in the token to a new value if and only if the - new value is after the old value. - """ - new_token = self.copy_and_replace(key, new_value) - if key == "room_key": - new_id = new_token.room_stream_id - old_id = self.room_stream_id - else: - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) - if old_id < new_id: - return new_token - else: - return self - - def copy_and_replace(self, key, new_value): - return self._replace(**{key: new_value}) - - -StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))) - - -class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): +@attr.s(frozen=True, slots=True) +class RoomStreamToken: """Tokens are positions between events. The token "s1" comes after event 1. s0 s1 @@ -464,10 +399,14 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): followed by the "stream_ordering" id of the event it comes after. """ - __slots__ = [] # type: list + topological = attr.ib( + type=Optional[int], + validator=attr.validators.optional(attr.validators.instance_of(int)), + ) + stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) @classmethod - def parse(cls, string): + async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -479,7 +418,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): raise SynapseError(400, "Invalid token %r" % (string,)) @classmethod - def parse_stream_token(cls, string): + def parse_stream_token(cls, string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -487,13 +426,130 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): pass raise SynapseError(400, "Invalid token %r" % (string,)) - def __str__(self): + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + return RoomStreamToken(None, max_stream) + + def as_tuple(self) -> Tuple[Optional[int], int]: + return (self.topological, self.stream) + + async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) +@attr.s(slots=True, frozen=True) +class StreamToken: + room_key = attr.ib( + type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken) + ) + presence_key = attr.ib(type=int) + typing_key = attr.ib(type=int) + receipt_key = attr.ib(type=int) + account_data_key = attr.ib(type=int) + push_rules_key = attr.ib(type=int) + to_device_key = attr.ib(type=int) + device_list_key = attr.ib(type=int) + groups_key = attr.ib(type=int) + + _SEPARATOR = "_" + START = None # type: StreamToken + + @classmethod + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + try: + keys = string.split(cls._SEPARATOR) + while len(keys) < len(attr.fields(cls)): + # i.e. old token from before receipt_key + keys.append("0") + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) + except Exception: + raise SynapseError(400, "Invalid Token") + + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + str(self.groups_key), + ] + ) + + @property + def room_stream_id(self): + return self.room_key.stream + + def copy_and_advance(self, key, new_value) -> "StreamToken": + """Advance the given key in the token to a new value if and only if the + new value is after the old value. + """ + if key == "room_key": + new_token = self.copy_and_replace( + "room_key", self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + + if old_id < new_id: + return new_token + else: + return self + + def copy_and_replace(self, key, new_value) -> "StreamToken": + return attr.evolve(self, **{key: new_value}) + + +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) + + +@attr.s(slots=True, frozen=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name = attr.ib(type=str) + stream = attr.ib(type=int) + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.stream < self.stream + + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarentees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) ): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index a13f11f8d8..d55b93d763 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import re import attr -from canonicaljson import json from twisted.internet import defer, task @@ -45,7 +45,7 @@ def unwrapFirstError(failure): return failure.value.subFailure -@attr.s +@attr.s(slots=True) class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index bb57e27beb..382f0cf3f0 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -17,13 +17,25 @@ import collections import logging from contextlib import contextmanager -from typing import Dict, Sequence, Set, Union +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + TypeVar, + Union, +) import attr from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import IReactorTime from twisted.python import failure from synapse.logging.context import ( @@ -54,7 +66,7 @@ class ObservableDeferred: __slots__ = ["_deferred", "_observers", "_result"] - def __init__(self, deferred, consumeErrors=False): + def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", set()) @@ -111,25 +123,25 @@ class ObservableDeferred: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self): + def observers(self) -> List[defer.Deferred]: return self._observers - def has_called(self): + def has_called(self) -> bool: return self._result is not None - def has_succeeded(self): + def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True - def get_result(self): + def get_result(self) -> Any: return self._result[1] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self._deferred, name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: setattr(self._deferred, name, value) - def __repr__(self): + def __repr__(self) -> str: return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( id(self), self._result, @@ -137,18 +149,20 @@ class ObservableDeferred: ) -def concurrently_execute(func, args, limit): - """Executes the function with each argument conncurrently while limiting +def concurrently_execute( + func: Callable, args: Iterable[Any], limit: int +) -> defer.Deferred: + """Executes the function with each argument concurrently while limiting the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred or coroutine. - args (Iterable): List of arguments to pass to func, each invocation of func + func: Function to execute, should return a deferred or coroutine. + args: List of arguments to pass to func, each invocation of func gets a single argument. - limit (int): Maximum number of conccurent executions. + limit: Maximum number of conccurent executions. Returns: - deferred: Resolved when all function invocations have finished. + Deferred[list]: Resolved when all function invocations have finished. """ it = iter(args) @@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit): ).addErrback(unwrapFirstError) -def yieldable_gather_results(func, iter, *args, **kwargs): +def yieldable_gather_results( + func: Callable, iter: Iterable, *args: Any, **kwargs: Any +) -> defer.Deferred: """Executes the function with each argument concurrently. Args: - func (func): Function to execute that returns a Deferred - iter (iter): An iterable that yields items that get passed as the first + func: Function to execute that returns a Deferred + iter: An iterable that yields items that get passed as the first argument to the function *args: Arguments to be passed to each call to func + **kwargs: Keyword arguments to be passed to each call to func Returns Deferred[list]: Resolved when all functions have been invoked, or errors if @@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs): ).addErrback(unwrapFirstError) +@attr.s(slots=True) +class _LinearizerEntry: + # The number of things executing. + count = attr.ib(type=int) + # Deferreds for the things blocked from executing. + deferreds = attr.ib(type=collections.OrderedDict) + + class Linearizer: """Limits concurrent access to resources based on a key. Useful to ensure only a few things happen at a time on a given resource. Example: - with (yield limiter.queue("test_key")): + with await limiter.queue("test_key"): # do some work. """ - def __init__(self, name=None, max_count=1, clock=None): + def __init__( + self, + name: Optional[str] = None, + max_count: int = 1, + clock: Optional[Clock] = None, + ): """ Args: - max_count(int): The maximum number of concurrent accesses + max_count: The maximum number of concurrent accesses """ if name is None: - self.name = id(self) + self.name = id(self) # type: Union[str, int] else: self.name = name @@ -216,15 +246,10 @@ class Linearizer: self._clock = clock self.max_count = max_count - # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing, and - # the second element is an OrderedDict, where the keys are deferreds for the - # things blocked from executing. - self.key_to_defer = ( - {} - ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] + # key_to_defer is a map from the key to a _LinearizerEntry. + self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] - def is_queued(self, key) -> bool: + def is_queued(self, key: Hashable) -> bool: """Checks whether there is a process queued up waiting """ entry = self.key_to_defer.get(key) @@ -234,25 +259,27 @@ class Linearizer: # There are waiting deferreds only in the OrderedDict of deferreds is # non-empty. - return bool(entry[1]) + return bool(entry.deferreds) - def queue(self, key): + def queue(self, key: Hashable) -> defer.Deferred: # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not # propagated inside inlineCallbacks until Twisted 18.7) - entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) + entry = self.key_to_defer.setdefault( + key, _LinearizerEntry(0, collections.OrderedDict()) + ) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items # When one of the things currently executing finishes it will callback # this item so that it can continue executing. - if entry[0] >= self.max_count: + if entry.count >= self.max_count: res = self._await_lock(key) else: logger.debug( "Acquired uncontended linearizer lock %r for key %r", self.name, key ) - entry[0] += 1 + entry.count += 1 res = defer.succeed(None) # once we successfully get the lock, we need to return a context manager which @@ -267,15 +294,15 @@ class Linearizer: # We've finished executing so check if there are any things # blocked waiting to execute and start one of them - entry[0] -= 1 + entry.count -= 1 - if entry[1]: - (next_def, _) = entry[1].popitem(last=False) + if entry.deferreds: + (next_def, _) = entry.deferreds.popitem(last=False) # we need to run the next thing in the sentinel context. with PreserveLoggingContext(): next_def.callback(None) - elif entry[0] == 0: + elif entry.count == 0: # We were the last thing for this key: remove it from the # map. del self.key_to_defer[key] @@ -283,7 +310,7 @@ class Linearizer: res.addCallback(_ctx_manager) return res - def _await_lock(self, key): + def _await_lock(self, key: Hashable) -> defer.Deferred: """Helper for queue: adds a deferred to the queue Assumes that we've already checked that we've reached the limit of the number @@ -298,11 +325,11 @@ class Linearizer: logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) - entry[1][new_defer] = 1 + entry.deferreds[new_defer] = 1 def cb(_r): logger.debug("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 + entry.count += 1 # if the code holding the lock completes synchronously, then it # will recursively run the next claimant on the list. That can @@ -331,7 +358,7 @@ class Linearizer: ) # we just have to take ourselves back out of the queue. - del entry[1][new_defer] + del entry.deferreds[new_defer] return e new_defer.addCallbacks(cb, eb) @@ -419,14 +446,12 @@ class ReadWriteLock: return _ctx_manager() -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise defer.TimeoutError(timeout, "Deferred") - return value +R = TypeVar("R") -def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): +def timeout_deferred( + deferred: defer.Deferred, timeout: float, reactor: IReactorTime, +) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new deferred that wraps and times out the given deferred, correctly handling @@ -434,27 +459,21 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): (See https://twistedmatrix.com/trac/ticket/9534) - NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred. - Args: - deferred (Deferred) - timeout (float): Timeout in seconds - reactor (twisted.interfaces.IReactorTime): The twisted reactor to use - on_timeout_cancel (callable): A callable which is called immediately - after the deferred times out, and not if this deferred is - otherwise cancelled before the timeout. + NOTE: the TimeoutError raised by the resultant deferred is + twisted.internet.defer.TimeoutError, which is *different* to the built-in + TimeoutError, as well as various other TimeoutErrors you might have imported. - It takes an arbitrary value, which is the value of the deferred at - that exact point in time (probably a CancelledError Failure), and - the timeout. + Args: + deferred: The Deferred to potentially timeout. + timeout: Timeout in seconds + reactor: The twisted reactor to use - The default callable (if none is provided) will translate a - CancelledError Failure into a defer.TimeoutError. Returns: - Deferred + A new Deferred, which will errback with defer.TimeoutError on timeout. """ - new_d = defer.Deferred() timed_out = [False] @@ -467,18 +486,23 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") + # the cancel() call should have set off a chain of errbacks which + # will have errbacked new_d, but in case it hasn't, errback it now. + if not new_d.called: - new_d.errback(defer.TimeoutError(timeout, "Deferred")) + new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,))) delayed_call = reactor.callLater(timeout, time_it_out) - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) + def convert_cancelled(value: failure.Failure): + # if the orgininal deferred was cancelled, and our timeout has fired, then + # the reason it was cancelled was due to our timeout. Turn the CancelledError + # into a TimeoutError. + if timed_out[0] and value.check(CancelledError): + raise defer.TimeoutError("Timed out after %gs" % (timeout,)) return value - deferred.addBoth(convert_cancelled) + deferred.addErrback(convert_cancelled) def cancel_timeout(result): # stop the pending call to cancel the deferred if it's been fired diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 237f588658..8fc05be278 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -42,7 +42,7 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -@attr.s +@attr.s(slots=True) class CacheMetric: _cache = attr.ib() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index a750261e77..f73e95393c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -16,8 +16,6 @@ import inspect import logging from twisted.internet import defer -from twisted.internet.defer import Deferred, fail, succeed -from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -29,11 +27,6 @@ def user_left_room(distributor, user, room_id): distributor.fire("user_left_room", user=user, room_id=room_id) -# XXX: this is no longer used. We should probably kill it. -def user_joined_room(distributor, user, room_id): - distributor.fire("user_joined_room", user=user, room_id=room_id) - - class Distributor: """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. @@ -81,28 +74,6 @@ class Distributor: run_as_background_process(name, self.signals[name].fire, *args, **kwargs) -def maybeAwaitableDeferred(f, *args, **kw): - """ - Invoke a function that may or may not return a Deferred or an Awaitable. - - This is a modified version of twisted.internet.defer.maybeDeferred. - """ - try: - result = f(*args, **kw) - except Exception: - return fail(failure.Failure(captureVars=Deferred.debug)) - - if isinstance(result, Deferred): - return result - # Handle the additional case of an awaitable being returned. - elif inspect.isawaitable(result): - return defer.ensureDeferred(result) - elif isinstance(result, failure.Failure): - return fail(result) - else: - return succeed(result) - - class Signal: """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -132,22 +103,17 @@ class Signal: Returns a Deferred that will complete when all the observers have completed.""" - def do(observer): - def eb(failure): + async def do(observer): + try: + result = observer(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + return result + except Exception as e: logger.warning( - "%s signal observer %s failed: %r", - self.name, - observer, - failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject(), - ), + "%s signal observer %s failed: %r", self.name, observer, e, ) - return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [run_in_background(do, o) for o in self.observers] return make_deferred_yieldable( diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 0e445e01d7..bf094c9386 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json +import json + from frozendict import frozendict @@ -66,5 +67,5 @@ def _handle_frozendict(obj): # A JSONEncoder which is capable of encoding frozendicts without barfing. # Additionally reduce the whitespace produced by JSON encoding. frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, separators=(",", ":"), + allow_nan=False, separators=(",", ":"), default=_handle_frozendict, ) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 631654f297..da24ba0470 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -94,7 +94,7 @@ class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" def connectionMade(self): - super(SynapseManhole, self).connectionMade() + super().connectionMade() # replace the manhole interpreter with our own impl self.interpreter = SynapseManholeInterpreter(self, self.namespace) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 6e57c1ee72..ffdea0de8d 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -19,7 +19,11 @@ from typing import Any, Callable, Optional, TypeVar, cast from prometheus_client import Counter -from synapse.logging.context import LoggingContext, current_context +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + current_context, +) from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -104,27 +108,27 @@ class Measure: def __init__(self, clock, name): self.clock = clock self.name = name - self._logging_context = None + parent_context = current_context() + self._logging_context = LoggingContext( + "Measure[%s]" % (self.name,), parent_context + ) self.start = None - def __enter__(self): - if self._logging_context: + def __enter__(self) -> "Measure": + if self.start is not None: raise RuntimeError("Measure() objects cannot be re-used") self.start = self.clock.time() - parent_context = current_context() - self._logging_context = LoggingContext( - "Measure[%s]" % (self.name,), parent_context - ) self._logging_context.__enter__() in_flight.register((self.name,), self._update_in_flight) + return self def __exit__(self, exc_type, exc_val, exc_tb): - if not self._logging_context: + if self.start is None: raise RuntimeError("Measure() block exited without being entered") duration = self.clock.time() - self.start - usage = self._logging_context.get_resource_usage() + usage = self.get_resource_usage() in_flight.unregister((self.name,), self._update_in_flight) self._logging_context.__exit__(exc_type, exc_val, exc_tb) @@ -140,6 +144,13 @@ class Measure: except ValueError: logger.warning("Failed to save metrics! Usage: %s", usage) + def get_resource_usage(self) -> ContextResourceUsage: + """Get the resources used within this Measure block + + If the Measure block is still active, returns the resource usage so far. + """ + return self._logging_context.get_resource_usage() + def _update_in_flight(self, metrics): """Gets called when processing in flight metrics """ diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 54c046b6e1..72574d3af2 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import functools import sys from typing import Any, Callable, List diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 79869aaa44..a5cc9d0551 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -45,7 +45,7 @@ class NotRetryingDestination(Exception): """ msg = "Not retrying server %s." % (destination,) - super(NotRetryingDestination, self).__init__(msg) + super().__init__(msg) self.retry_last_ts = retry_last_ts self.retry_interval = retry_interval |