diff options
Diffstat (limited to 'synapse')
62 files changed, 2496 insertions, 805 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 86f145649c..7ce6540bdd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,13 @@ from twisted.internet import defer import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership -from synapse.api.errors import AuthError, Codes, ResourceLimitError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientTokenError, + MissingClientTokenError, + ResourceLimitError, +) from synapse.config.server import is_threepid_reserved from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache @@ -63,7 +69,6 @@ class Auth(object): self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) register_cache("cache", "token_cache", self.token_cache) @@ -189,18 +194,17 @@ class Auth(object): Returns: defer.Deferred: resolves to a ``synapse.types.Requester`` object Raises: - AuthError if no user by that token exists or the token is invalid. + InvalidClientCredentialsError if no user by that token exists or the token + is invalid. + AuthError if access is denied for the user in the access token """ - # Can optionally look elsewhere in the request (e.g. headers) try: ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( b"User-Agent", default=[b""] )[0].decode("ascii", "surrogateescape") - access_token = self.get_access_token_from_request( - request, self.TOKEN_NOT_FOUND_HTTP_STATUS - ) + access_token = self.get_access_token_from_request(request) user_id, app_service = yield self._get_appservice_user_id(request) if user_id: @@ -264,18 +268,12 @@ class Auth(object): ) ) except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Missing access token.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError() @defer.inlineCallbacks def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( - self.get_access_token_from_request( - request, self.TOKEN_NOT_FOUND_HTTP_STATUS - ) + self.get_access_token_from_request(request) ) if app_service is None: defer.returnValue((None, None)) @@ -313,13 +311,25 @@ class Auth(object): `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: - AuthError if no user by that token exists or the token is invalid. + InvalidClientCredentialsError if no user by that token exists or the token + is invalid. """ if rights == "access": # first look in the database r = yield self._look_up_user_by_access_token(token) if r: + valid_until_ms = r["valid_until_ms"] + if ( + valid_until_ms is not None + and valid_until_ms < self.clock.time_msec() + ): + # there was a valid access token, but it has expired. + # soft-logout the user. + raise InvalidClientTokenError( + msg="Access token has expired", soft_logout=True + ) + defer.returnValue(r) # otherwise it needs to be a valid macaroon @@ -331,11 +341,7 @@ class Auth(object): if not guest: # non-guest access tokens must be in the database logger.warning("Unrecognised access token - not in store.") - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError() # Guest access tokens are not stored in the database (there can # only be one access token per guest, anyway). @@ -350,16 +356,10 @@ class Auth(object): # guest tokens. stored_user = yield self.store.get_user_by_id(user_id) if not stored_user: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Unknown user_id %s" % user_id, - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("Unknown user_id %s" % user_id) if not stored_user["is_guest"]: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Guest access token used for regular user", - errcode=Codes.UNKNOWN_TOKEN, + raise InvalidClientTokenError( + "Guest access token used for regular user" ) ret = { "user": user, @@ -386,11 +386,7 @@ class Auth(object): ValueError, ) as e: logger.warning("Invalid macaroon in auth: %s %s", type(e), e) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("Invalid macaroon passed.") def _parse_and_validate_macaroon(self, token, rights="access"): """Takes a macaroon and tries to parse and validate it. This is cached @@ -430,11 +426,7 @@ class Auth(object): macaroon, rights, self.hs.config.expire_access_token, user_id=user_id ) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("Invalid macaroon passed.") if not has_expiry and rights == "access": self.token_cache[token] = (user_id, guest) @@ -453,17 +445,14 @@ class Auth(object): (str) user id Raises: - AuthError if there is no user_id caveat in the macaroon + InvalidClientCredentialsError if there is no user_id caveat in the + macaroon """ user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): return caveat.caveat_id[len(user_prefix) :] - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("No user caveat in macaroon") def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): """ @@ -527,26 +516,18 @@ class Auth(object): "token_id": ret.get("token_id", None), "is_guest": False, "device_id": ret.get("device_id"), + "valid_until_ms": ret.get("valid_until_ms"), } defer.returnValue(user_info) def get_appservice_by_req(self, request): - try: - token = self.get_access_token_from_request( - request, self.TOKEN_NOT_FOUND_HTTP_STATUS - ) - service = self.store.get_app_service_by_token(token) - if not service: - logger.warn("Unrecognised appservice access token.") - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN, - ) - request.authenticated_entity = service.sender - return defer.succeed(service) - except KeyError: - raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.") + token = self.get_access_token_from_request(request) + service = self.store.get_app_service_by_token(token) + if not service: + logger.warn("Unrecognised appservice access token.") + raise InvalidClientTokenError() + request.authenticated_entity = service.sender + return defer.succeed(service) def is_server_admin(self, user): """ Check if the given user is a local server admin. @@ -625,21 +606,6 @@ class Auth(object): defer.returnValue(auth_ids) - def check_redaction(self, room_version, event, auth_events): - """Check whether the event sender is allowed to redact the target event. - - Returns: - True if the the sender is allowed to redact the target event if the - target event was created by them. - False if the sender is allowed to redact the target event with no - further checks. - - Raises: - AuthError if the event sender is definitely not allowed to redact - the target event. - """ - return event_auth.check_redaction(room_version, event, auth_events) - @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): """Check if the user is allowed to edit the room's entry in the @@ -692,20 +658,16 @@ class Auth(object): return bool(query_params) or bool(auth_headers) @staticmethod - def get_access_token_from_request(request, token_not_found_http_status=401): + def get_access_token_from_request(request): """Extracts the access_token from the request. Args: request: The http request. - token_not_found_http_status(int): The HTTP status code to set in the - AuthError if the token isn't found. This is used in some of the - legacy APIs to change the status code to 403 from the default of - 401 since some of the old clients depended on auth errors returning - 403. Returns: unicode: The access_token Raises: - AuthError: If there isn't an access_token in the request. + MissingClientTokenError: If there isn't a single access_token in the + request """ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") @@ -714,34 +676,20 @@ class Auth(object): # Try the get the access_token from a "Authorization: Bearer" # header if query_params is not None: - raise AuthError( - token_not_found_http_status, - "Mixing Authorization headers and access_token query parameters.", - errcode=Codes.MISSING_TOKEN, + raise MissingClientTokenError( + "Mixing Authorization headers and access_token query parameters." ) if len(auth_headers) > 1: - raise AuthError( - token_not_found_http_status, - "Too many Authorization headers.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError("Too many Authorization headers.") parts = auth_headers[0].split(b" ") if parts[0] == b"Bearer" and len(parts) == 2: return parts[1].decode("ascii") else: - raise AuthError( - token_not_found_http_status, - "Invalid Authorization header.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError("Invalid Authorization header.") else: # Try to get the access_token from the query params. if not query_params: - raise AuthError( - token_not_found_http_status, - "Missing access token.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError() return query_params[0].decode("ascii") diff --git a/synapse/api/errors.py b/synapse/api/errors.py index f01f4285db..be15921bc6 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,22 @@ class ConsentNotGivenError(SynapseError): return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) +class UserDeactivatedError(SynapseError): + """The error returned to the client when the user attempted to access an + authenticated endpoint, but the account has been deactivated. + """ + + def __init__(self, msg): + """Constructs a UserDeactivatedError + + Args: + msg (str): The human-readable error message + """ + super(UserDeactivatedError, self).__init__( + code=http_client.FORBIDDEN, msg=msg, errcode=Codes.UNKNOWN + ) + + class RegistrationError(SynapseError): """An error raised when a registration event fails.""" @@ -211,7 +227,9 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): - """An error raised when there was a problem authorising an event.""" + """An error raised when there was a problem authorising an event, and at various + other poorly-defined times. + """ def __init__(self, *args, **kwargs): if "errcode" not in kwargs: @@ -219,6 +237,41 @@ class AuthError(SynapseError): super(AuthError, self).__init__(*args, **kwargs) +class InvalidClientCredentialsError(SynapseError): + """An error raised when there was a problem with the authorisation credentials + in a client request. + + https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens: + + When credentials are required but missing or invalid, the HTTP call will + return with a status of 401 and the error code, M_MISSING_TOKEN or + M_UNKNOWN_TOKEN respectively. + """ + + def __init__(self, msg, errcode): + super().__init__(code=401, msg=msg, errcode=errcode) + + +class MissingClientTokenError(InvalidClientCredentialsError): + """Raised when we couldn't find the access token in a request""" + + def __init__(self, msg="Missing access token"): + super().__init__(msg=msg, errcode="M_MISSING_TOKEN") + + +class InvalidClientTokenError(InvalidClientCredentialsError): + """Raised when we didn't understand the access token in a request""" + + def __init__(self, msg="Unrecognised access token", soft_logout=False): + super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") + self._soft_logout = soft_logout + + def error_dict(self): + d = super().error_dict() + d["soft_logout"] = self._soft_logout + return d + + class ResourceLimitError(SynapseError): """ Any error raised when there is a problem with resource usage. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 1ebb7ae539..540dbd9236 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -48,7 +48,7 @@ def register_sighup(func): _sighup_callbacks.append(func) -def start_worker_reactor(appname, config): +def start_worker_reactor(appname, config, run_command=reactor.run): """ Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting @@ -57,6 +57,7 @@ def start_worker_reactor(appname, config): Args: appname (str): application name which will be sent to syslog config (synapse.config.Config): config object + run_command (Callable[]): callable that actually runs the reactor """ logger = logging.getLogger(config.worker_app) @@ -69,11 +70,19 @@ def start_worker_reactor(appname, config): daemonize=config.worker_daemonize, print_pidfile=config.print_pidfile, logger=logger, + run_command=run_command, ) def start_reactor( - appname, soft_file_limit, gc_thresholds, pid_file, daemonize, print_pidfile, logger + appname, + soft_file_limit, + gc_thresholds, + pid_file, + daemonize, + print_pidfile, + logger, + run_command=reactor.run, ): """ Run the reactor in the main process @@ -88,6 +97,7 @@ def start_reactor( daemonize (bool): true to run the reactor in a background process print_pidfile (bool): whether to print the pid file, if daemonize is True logger (logging.Logger): logger instance to pass to Daemonize + run_command (Callable[]): callable that actually runs the reactor """ install_dns_limiter(reactor) @@ -97,7 +107,7 @@ def start_reactor( change_resource_limit(soft_file_limit) if gc_thresholds: gc.set_threshold(*gc_thresholds) - reactor.run() + run_command() # make sure that we run the reactor with the sentinel log context, # otherwise other PreserveLoggingContext instances will get confused @@ -139,8 +149,7 @@ def listen_metrics(bind_addresses, port): """ Start Prometheus metrics server. """ - from synapse.metrics import RegistryProxy - from prometheus_client import start_http_server + from synapse.metrics import RegistryProxy, start_http_server for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) @@ -243,6 +252,9 @@ def start(hs, listeners=None): # Load the certificate from disk. refresh_certificate(hs) + # Start the tracer + synapse.logging.opentracing.init_tracer(hs.config) + # It is now safe to start your Synapse. hs.start_listening(listeners) hs.get_datastore().start_profiling() diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py new file mode 100644 index 0000000000..1fd52a5526 --- /dev/null +++ b/synapse/app/admin_cmd.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import sys +import tempfile + +from canonicaljson import json + +from twisted.internet import defer, task + +import synapse +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.handlers.admin import ExfiltrationWriter +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.filtering import SlavedFilteringStore +from synapse.replication.slave.storage.groups import SlavedGroupServerStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.util.logcontext import LoggingContext +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger("synapse.app.admin_cmd") + + +class AdminCmdSlavedStore( + SlavedReceiptsStore, + SlavedAccountDataStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + SlavedFilteringStore, + SlavedPresenceStore, + SlavedGroupServerStore, + SlavedDeviceInboxStore, + SlavedDeviceStore, + SlavedPushRuleStore, + SlavedEventStore, + SlavedClientIpStore, + RoomStore, + BaseSlavedStore, +): + pass + + +class AdminCmdServer(HomeServer): + DATASTORE_CLASS = AdminCmdSlavedStore + + def _listen_http(self, listener_config): + pass + + def start_listening(self, listeners): + pass + + def build_tcp_replication(self): + return AdminCmdReplicationHandler(self) + + +class AdminCmdReplicationHandler(ReplicationClientHandler): + @defer.inlineCallbacks + def on_rdata(self, stream_name, token, rows): + pass + + def get_streams_to_replicate(self): + return {} + + +@defer.inlineCallbacks +def export_data_command(hs, args): + """Export data for a user. + + Args: + hs (HomeServer) + args (argparse.Namespace) + """ + + user_id = args.user_id + directory = args.output_directory + + res = yield hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) + ) + print(res) + + +class FileExfiltrationWriter(ExfiltrationWriter): + """An ExfiltrationWriter that writes the users data to a directory. + Returns the directory location on completion. + + Note: This writes to disk on the main reactor thread. + + Args: + user_id (str): The user whose data is being exfiltrated. + directory (str|None): The directory to write the data to, if None then + will write to a temporary directory. + """ + + def __init__(self, user_id, directory=None): + self.user_id = user_id + + if directory: + self.base_directory = directory + else: + self.base_directory = tempfile.mkdtemp( + prefix="synapse-exfiltrate__%s__" % (user_id,) + ) + + os.makedirs(self.base_directory, exist_ok=True) + if list(os.listdir(self.base_directory)): + raise Exception("Directory must be empty") + + def write_events(self, room_id, events): + room_directory = os.path.join(self.base_directory, "rooms", room_id) + os.makedirs(room_directory, exist_ok=True) + events_file = os.path.join(room_directory, "events") + + with open(events_file, "a") as f: + for event in events: + print(json.dumps(event.get_pdu_json()), file=f) + + def write_state(self, room_id, event_id, state): + room_directory = os.path.join(self.base_directory, "rooms", room_id) + state_directory = os.path.join(room_directory, "state") + os.makedirs(state_directory, exist_ok=True) + + event_file = os.path.join(state_directory, event_id) + + with open(event_file, "a") as f: + for event in state.values(): + print(json.dumps(event.get_pdu_json()), file=f) + + def write_invite(self, room_id, event, state): + self.write_events(room_id, [event]) + + # We write the invite state somewhere else as they aren't full events + # and are only a subset of the state at the event. + room_directory = os.path.join(self.base_directory, "rooms", room_id) + os.makedirs(room_directory, exist_ok=True) + + invite_state = os.path.join(room_directory, "invite_state") + + with open(invite_state, "a") as f: + for event in state.values(): + print(json.dumps(event), file=f) + + def finished(self): + return self.base_directory + + +def start(config_options): + parser = argparse.ArgumentParser(description="Synapse Admin Command") + HomeServerConfig.add_arguments_to_parser(parser) + + subparser = parser.add_subparsers( + title="Admin Commands", + required=True, + dest="command", + metavar="<admin_command>", + help="The admin command to perform.", + ) + export_data_parser = subparser.add_parser( + "export-data", help="Export all data for a user" + ) + export_data_parser.add_argument("user_id", help="User to extra data from") + export_data_parser.add_argument( + "--output-directory", + action="store", + metavar="DIRECTORY", + required=False, + help="The directory to store the exported data in. Must be empty. Defaults" + " to creating a temp directory.", + ) + export_data_parser.set_defaults(func=export_data_command) + + try: + config, args = HomeServerConfig.load_config_with_parser(parser, config_options) + except ConfigError as e: + sys.stderr.write("\n" + str(e) + "\n") + sys.exit(1) + + if config.worker_app is not None: + assert config.worker_app == "synapse.app.admin_cmd" + + # Update the config with some basic overrides so that don't have to specify + # a full worker config. + config.worker_app = "synapse.app.admin_cmd" + + if ( + not config.worker_daemonize + and not config.worker_log_file + and not config.worker_log_config + ): + # Since we're meant to be run as a "command" let's not redirect stdio + # unless we've actually set log config. + config.no_redirect_stdio = True + + # Explicitly disable background processes + config.update_user_directory = False + config.start_pushers = False + config.send_federation = False + + setup_logging(config, use_worker_options=True) + + synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + ss = AdminCmdServer( + config.server_name, + db_config=config.database_config, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + + # We use task.react as the basic run command as it correctly handles tearing + # down the reactor when the deferreds resolve and setting the return value. + # We also make sure that `_base.start` gets run before we actually run the + # command. + + @defer.inlineCallbacks + def run(_reactor): + with LoggingContext("command"): + yield _base.start(ss, []) + yield args.func(ss, args) + + _base.start_worker_reactor( + "synapse-admin-cmd", config, run_command=lambda: task.react(run) + ) + + +if __name__ == "__main__": + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index be44249ed6..e01f3e5f3b 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -27,8 +27,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index ff11beca82..29bddc4823 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -28,8 +28,7 @@ from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index cacad25eac..042cfd04af 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -28,8 +28,7 @@ from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 11e80dbae0..76a97f8f32 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -29,8 +29,7 @@ from synapse.config.logger import setup_logging from synapse.federation.transport.server import TransportLayerServer from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 97da7bdcbf..fec49d5092 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -28,9 +28,8 @@ from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.events import SlavedEventStore diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 417a10bbd2..1f1f1df78e 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -30,8 +30,7 @@ from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 639b1429c0..0c075cb3f1 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -55,9 +55,8 @@ from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.module_api import ModuleApi from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index f23b9b6eda..d70780e9d5 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -28,8 +28,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 4f929edf86..070de7d0b0 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -27,8 +27,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.events import SlavedEventStore diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index de4797fddc..315c030694 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -32,8 +32,7 @@ from synapse.handlers.presence import PresenceHandler, get_interested_parties from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1177ddd72e..03ef21bd01 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -29,8 +29,7 @@ from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 965478d8d5..6ce5cd07fb 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -137,12 +137,42 @@ class Config(object): return file_stream.read() def invoke_all(self, name, *args, **kargs): + """Invoke all instance methods with the given name and arguments in the + class's MRO. + + Args: + name (str): Name of function to invoke + *args + **kwargs + + Returns: + list: The list of the return values from each method called + """ results = [] for cls in type(self).mro(): if name in cls.__dict__: results.append(getattr(cls, name)(self, *args, **kargs)) return results + @classmethod + def invoke_all_static(cls, name, *args, **kargs): + """Invoke all static methods with the given name and arguments in the + class's MRO. + + Args: + name (str): Name of function to invoke + *args + **kwargs + + Returns: + list: The list of the return values from each method called + """ + results = [] + for c in cls.mro(): + if name in c.__dict__: + results.append(getattr(c, name)(*args, **kargs)) + return results + def generate_config( self, config_dir_path, @@ -202,6 +232,23 @@ class Config(object): Returns: Config object. """ config_parser = argparse.ArgumentParser(description=description) + cls.add_arguments_to_parser(config_parser) + obj, _ = cls.load_config_with_parser(config_parser, argv) + + return obj + + @classmethod + def add_arguments_to_parser(cls, config_parser): + """Adds all the config flags to an ArgumentParser. + + Doesn't support config-file-generation: used by the worker apps. + + Used for workers where we want to add extra flags/subcommands. + + Args: + config_parser (ArgumentParser): App description + """ + config_parser.add_argument( "-c", "--config-path", @@ -219,16 +266,34 @@ class Config(object): " Defaults to the directory containing the last config file", ) - obj = cls() + cls.invoke_all_static("add_arguments", config_parser) - obj.invoke_all("add_arguments", config_parser) + @classmethod + def load_config_with_parser(cls, parser, argv): + """Parse the commandline and config files with the given parser + + Doesn't support config-file-generation: used by the worker apps. - config_args = config_parser.parse_args(argv) + Used for workers where we want to add extra flags/subcommands. + + Args: + parser (ArgumentParser) + argv (list[str]) + + Returns: + tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed + config object and the parsed argparse.Namespace object from + `parser.parse_args(..)` + """ + + obj = cls() + + config_args = parser.parse_args(argv) config_files = find_config_files(search_paths=config_args.config_path) if not config_files: - config_parser.error("Must supply a config file.") + parser.error("Must supply a config file.") if config_args.keys_directory: config_dir_path = config_args.keys_directory @@ -244,7 +309,7 @@ class Config(object): obj.invoke_all("read_arguments", config_args) - return obj + return obj, config_args @classmethod def load_or_generate_config(cls, description, argv): @@ -401,7 +466,7 @@ class Config(object): formatter_class=argparse.RawDescriptionHelpFormatter, ) - obj.invoke_all("add_arguments", parser) + obj.invoke_all_static("add_arguments", parser) args = parser.parse_args(remaining_args) config_dict = read_config_files(config_files) diff --git a/synapse/config/database.py b/synapse/config/database.py index bcb2089dd7..746a6cd1f4 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -69,7 +69,8 @@ class DatabaseConfig(Config): if database_path is not None: self.database_config["args"]["database"] = database_path - def add_arguments(self, parser): + @staticmethod + def add_arguments(parser): db_group = parser.add_argument_group("database") db_group.add_argument( "-d", diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index acadef4fd3..72acad4f18 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -40,6 +40,7 @@ from .spam_checker import SpamCheckerConfig from .stats import StatsConfig from .third_party_event_rules import ThirdPartyRulesConfig from .tls import TlsConfig +from .tracer import TracerConfig from .user_directory import UserDirectoryConfig from .voip import VoipConfig from .workers import WorkerConfig @@ -75,5 +76,6 @@ class HomeServerConfig( ServerNoticesConfig, RoomDirectoryConfig, ThirdPartyRulesConfig, + TracerConfig, ): pass diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 0f5554211c..d321d00b80 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import logging.config import os @@ -75,10 +76,8 @@ root: class LoggingConfig(Config): def read_config(self, config, **kwargs): - self.verbosity = config.get("verbose", 0) - self.no_redirect_stdio = config.get("no_redirect_stdio", False) self.log_config = self.abspath(config.get("log_config")) - self.log_file = self.abspath(config.get("log_file")) + self.no_redirect_stdio = config.get("no_redirect_stdio", False) def generate_config_section(self, config_dir_path, server_name, **kwargs): log_config = os.path.join(config_dir_path, server_name + ".log.config") @@ -94,38 +93,13 @@ class LoggingConfig(Config): ) def read_arguments(self, args): - if args.verbose is not None: - self.verbosity = args.verbose if args.no_redirect_stdio is not None: self.no_redirect_stdio = args.no_redirect_stdio - if args.log_config is not None: - self.log_config = args.log_config - if args.log_file is not None: - self.log_file = args.log_file - def add_arguments(cls, parser): + @staticmethod + def add_arguments(parser): logging_group = parser.add_argument_group("logging") logging_group.add_argument( - "-v", - "--verbose", - dest="verbose", - action="count", - help="The verbosity level. Specify multiple times to increase " - "verbosity. (Ignored if --log-config is specified.)", - ) - logging_group.add_argument( - "-f", - "--log-file", - dest="log_file", - help="File to log to. (Ignored if --log-config is specified.)", - ) - logging_group.add_argument( - "--log-config", - dest="log_config", - default=None, - help="Python logging config file", - ) - logging_group.add_argument( "-n", "--no-redirect-stdio", action="store_true", @@ -152,58 +126,29 @@ def setup_logging(config, use_worker_options=False): config (LoggingConfig | synapse.config.workers.WorkerConfig): configuration data - use_worker_options (bool): True to use 'worker_log_config' and - 'worker_log_file' options instead of 'log_config' and 'log_file'. + use_worker_options (bool): True to use the 'worker_log_config' option + instead of 'log_config'. register_sighup (func | None): Function to call to register a sighup handler. """ log_config = config.worker_log_config if use_worker_options else config.log_config - log_file = config.worker_log_file if use_worker_options else config.log_file - - log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" - " - %(message)s" - ) if log_config is None: - # We don't have a logfile, so fall back to the 'verbosity' param from - # the config or cmdline. (Note that we generate a log config for new - # installs, so this will be an unusual case) - level = logging.INFO - level_for_storage = logging.INFO - if config.verbosity: - level = logging.DEBUG - if config.verbosity > 1: - level_for_storage = logging.DEBUG + log_format = ( + "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" + " - %(message)s" + ) logger = logging.getLogger("") - logger.setLevel(level) - - logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage) + logger.setLevel(logging.INFO) + logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO) formatter = logging.Formatter(log_format) - if log_file: - # TODO: Customisable file size / backup count - handler = logging.handlers.RotatingFileHandler( - log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8" - ) - - def sighup(signum, stack): - logger.info("Closing log file due to SIGHUP") - handler.doRollover() - logger.info("Opened new log file due to SIGHUP") - - else: - handler = logging.StreamHandler() - - def sighup(*args): - pass + handler = logging.StreamHandler() handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) - logger.addHandler(handler) else: @@ -217,8 +162,7 @@ def setup_logging(config, use_worker_options=False): logging.info("Reloaded log config from %s due to SIGHUP", log_config) load_log_config() - - appbase.register_sighup(sighup) + appbase.register_sighup(sighup) # make sure that the first thing we log is a thing we can grep backwards # for diff --git a/synapse/config/registration.py b/synapse/config/registration.py index b895c4e9f4..c3de7a4e32 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -84,6 +84,11 @@ class RegistrationConfig(Config): "disable_msisdn_registration", False ) + session_lifetime = config.get("session_lifetime") + if session_lifetime is not None: + session_lifetime = self.parse_duration(session_lifetime) + self.session_lifetime = session_lifetime + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -141,6 +146,17 @@ class RegistrationConfig(Config): # renew_at: 1w # renew_email_subject: "Renew your %%(app)s account" + # Time that a user's session remains valid for, after they log in. + # + # Note that this is not currently compatible with guest logins. + # + # Note also that this is calculated at login time: changes are not applied + # retrospectively to users who have already logged in. + # + # By default, this is infinite. + # + #session_lifetime: 24h + # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: @@ -221,7 +237,8 @@ class RegistrationConfig(Config): % locals() ) - def add_arguments(self, parser): + @staticmethod + def add_arguments(parser): reg_group = parser.add_argument_group("registration") reg_group.add_argument( "--enable-registration", diff --git a/synapse/config/server.py b/synapse/config/server.py index 2a74dea2ea..00170f1393 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -136,7 +136,7 @@ class ServerConfig(Config): # Whether to enable experimental MSC1849 (aka relations) support self.experimental_msc1849_support_enabled = config.get( - "experimental_msc1849_support_enabled", False + "experimental_msc1849_support_enabled", True ) # Options to control access by tracking MAU @@ -639,7 +639,8 @@ class ServerConfig(Config): if args.print_pidfile is not None: self.print_pidfile = args.print_pidfile - def add_arguments(self, parser): + @staticmethod + def add_arguments(parser): server_group = parser.add_argument_group("server") server_group.add_argument( "-D", diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py new file mode 100644 index 0000000000..a2ce9ab3f6 --- /dev/null +++ b/synapse/config/tracer.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C.d +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config, ConfigError + + +class TracerConfig(Config): + def read_config(self, config, **kwargs): + opentracing_config = config.get("opentracing") + if opentracing_config is None: + opentracing_config = {} + + self.opentracer_enabled = opentracing_config.get("enabled", False) + if not self.opentracer_enabled: + return + + # The tracer is enabled so sanitize the config + + self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", []) + if not isinstance(self.opentracer_whitelist, list): + raise ConfigError("Tracer homeserver_whitelist config is malformed") + + def generate_config_section(cls, **kwargs): + return """\ + ## Opentracing ## + + # These settings enable opentracing, which implements distributed tracing. + # This allows you to observe the causal chains of events across servers + # including requests, key lookups etc., across any server running + # synapse or any other other services which supports opentracing + # (specifically those implemented with Jaeger). + # + opentracing: + # tracing is disabled by default. Uncomment the following line to enable it. + # + #enabled: true + + # The list of homeservers we wish to send and receive span contexts and span baggage. + # + # Though it's mostly safe to send and receive span contexts to and from + # untrusted users since span contexts are usually opaque ids it can lead to + # two problems, namely: + # - If the span context is marked as sampled by the sending homeserver the receiver will + # sample it. Therefore two homeservers with wildly disparaging sampling policies + # could incur higher sampling counts than intended. + # - Span baggage can be arbitrary data. For safety this has been disabled in synapse + # but that doesn't prevent another server sending you baggage which will be logged + # to opentracing logs. + # + # This a list of regexes which are matched against the server_name of the + # homeserver. + # + # By defult, it is empty, so no servers are matched. + # + #homeserver_whitelist: + # - ".*" + """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 3b75471d85..246d72cd61 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -31,8 +31,6 @@ class WorkerConfig(Config): self.worker_listeners = config.get("worker_listeners", []) self.worker_daemonize = config.get("worker_daemonize") self.worker_pid_file = config.get("worker_pid_file") - self.worker_log_file = config.get("worker_log_file") - self.worker_log_config = config.get("worker_log_config") # The host used to connect to the main synapse self.worker_replication_host = config.get("worker_replication_host", None) @@ -78,9 +76,5 @@ class WorkerConfig(Config): if args.daemonize is not None: self.worker_daemonize = args.daemonize - if args.log_config is not None: - self.worker_log_config = args.log_config - if args.log_file is not None: - self.worker_log_file = args.log_file if args.manhole is not None: self.worker_manhole = args.worker_manhole diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index d3de70e671..88ed6d764f 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -104,6 +104,17 @@ class _EventInternalMetadata(object): """ return getattr(self, "proactively_send", True) + def is_redacted(self): + """Whether the event has been redacted. + + This is used for efficiently checking whether an event has been + marked as redacted without needing to make another database call. + + Returns: + bool + """ + return getattr(self, "redacted", False) + def _event_dict_property(key): # We want to be able to use hasattr with the event dict properties. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 987de5cab7..9487a886f5 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -52,10 +52,15 @@ def prune_event(event): from . import event_type_from_format_version - return event_type_from_format_version(event.format_version)( + pruned_event = event_type_from_format_version(event.format_version)( pruned_event_dict, event.internal_metadata.get_dict() ) + # Mark the event as redacted + pruned_event.internal_metadata.redacted = True + + return pruned_event + def prune_event_dict(event_dict): """Redacts the event_dict in the same way as `prune_event`, except it @@ -360,9 +365,12 @@ class EventClientSerializer(object): event_id = event.event_id serialized_event = serialize_event(event, time_now, **kwargs) - # If MSC1849 is enabled then we need to look if thre are any relations - # we need to bundle in with the event - if self.experimental_msc1849_support_enabled and bundle_aggregations: + # If MSC1849 is enabled then we need to look if there are any relations + # we need to bundle in with the event. + # Do not bundle relations if the event has been redacted + if not event.internal_metadata.is_redacted() and ( + self.experimental_msc1849_support_enabled and bundle_aggregations + ): annotations = yield self.store.get_aggregation_groups_for_event(event_id) references = yield self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2efdcff7ef..663264dec4 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +19,8 @@ import functools import logging import re -from twisted.internet import defer - import synapse +import synapse.logging.opentracing as opentracing from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.api.urls import ( @@ -102,8 +102,7 @@ class Authenticator(object): self.federation_domain_whitelist = hs.config.federation_domain_whitelist # A method just so we can pass 'self' as the authenticator to the Servlets - @defer.inlineCallbacks - def authenticate_request(self, request, content): + async def authenticate_request(self, request, content): now = self._clock.time_msec() json_request = { "method": request.method.decode("ascii"), @@ -141,7 +140,7 @@ class Authenticator(object): 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) - yield self.keyring.verify_json_for_server( + await self.keyring.verify_json_for_server( origin, json_request, now, "Incoming request" ) @@ -150,17 +149,16 @@ class Authenticator(object): # If we get a valid signed request from the other side, its probably # alive - retry_timings = yield self.store.get_destination_retry_timings(origin) + retry_timings = await self.store.get_destination_retry_timings(origin) if retry_timings and retry_timings["retry_last_ts"]: run_in_background(self._reset_retry_timings, origin) - defer.returnValue(origin) + return origin - @defer.inlineCallbacks - def _reset_retry_timings(self, origin): + async def _reset_retry_timings(self, origin): try: logger.info("Marking origin %r as up", origin) - yield self.store.set_destination_retry_timings(origin, 0, 0) + await self.store.set_destination_retry_timings(origin, 0, 0) except Exception: logger.exception("Error resetting retry timings on %s", origin) @@ -214,7 +212,8 @@ class BaseFederationServlet(object): match against the request path (excluding the /federation/v1 prefix). The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match - the appropriate HTTP method. These methods have the signature: + the appropriate HTTP method. These methods must be *asynchronous* and have the + signature: on_<METHOD>(self, origin, content, query, **kwargs) @@ -234,7 +233,7 @@ class BaseFederationServlet(object): components as specified in the path match regexp. Returns: - Deferred[(int, object)|None]: either (response code, response object) to + Optional[Tuple[int, object]]: either (response code, response object) to return a JSON response, or None if the request has already been handled. Raises: @@ -257,10 +256,9 @@ class BaseFederationServlet(object): authenticator = self.authenticator ratelimiter = self.ratelimiter - @defer.inlineCallbacks @functools.wraps(func) - def new_func(request, *args, **kwargs): - """ A callback which can be passed to HttpServer.RegisterPaths + async def new_func(request, *args, **kwargs): + """A callback which can be passed to HttpServer.RegisterPaths Args: request (twisted.web.http.Request): @@ -269,8 +267,8 @@ class BaseFederationServlet(object): components as specified in the path match regexp. Returns: - Deferred[(int, object)|None]: (response code, response object) as returned - by the callback method. None if the request has already been handled. + Tuple[int, object]|None: (response code, response object) as returned by + the callback method. None if the request has already been handled. """ content = None if request.method in [b"PUT", b"POST"]: @@ -278,7 +276,7 @@ class BaseFederationServlet(object): content = parse_json_object_from_request(request) try: - origin = yield authenticator.authenticate_request(request, content) + origin = await authenticator.authenticate_request(request, content) except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: @@ -288,16 +286,31 @@ class BaseFederationServlet(object): logger.warn("authenticate_request failed: %s", e) raise - if origin: - with ratelimiter.ratelimit(origin) as d: - yield d - response = yield func( + # Start an opentracing span + with opentracing.start_active_span_from_context( + request.requestHeaders, + "incoming-federation-request", + tags={ + "request_id": request.get_request_id(), + opentracing.tags.SPAN_KIND: opentracing.tags.SPAN_KIND_RPC_SERVER, + opentracing.tags.HTTP_METHOD: request.get_method(), + opentracing.tags.HTTP_URL: request.get_redacted_uri(), + opentracing.tags.PEER_HOST_IPV6: request.getClientIP(), + "authenticated_entity": origin, + }, + ): + if origin: + with ratelimiter.ratelimit(origin) as d: + await d + response = await func( + origin, content, request.args, *args, **kwargs + ) + else: + response = await func( origin, content, request.args, *args, **kwargs ) - else: - response = yield func(origin, content, request.args, *args, **kwargs) - defer.returnValue(response) + return response # Extra logic that functools.wraps() doesn't finish new_func.__self__ = func.__self__ @@ -325,8 +338,7 @@ class FederationSendServlet(BaseFederationServlet): self.server_name = server_name # This is when someone is trying to send us a bunch of data. - @defer.inlineCallbacks - def on_PUT(self, origin, content, query, transaction_id): + async def on_PUT(self, origin, content, query, transaction_id): """ Called on PUT /send/<transaction_id>/ Args: @@ -335,7 +347,7 @@ class FederationSendServlet(BaseFederationServlet): request. This is *not* None. Returns: - Deferred: Results in a tuple of `(code, response)`, where + Tuple of `(code, response)`, where `response` is a python dict to be converted into JSON that is used as the response body. """ @@ -364,34 +376,33 @@ class FederationSendServlet(BaseFederationServlet): except Exception as e: logger.exception(e) - defer.returnValue((400, {"error": "Invalid transaction"})) - return + return 400, {"error": "Invalid transaction"} try: - code, response = yield self.handler.on_incoming_transaction( + code, response = await self.handler.on_incoming_transaction( origin, transaction_data ) except Exception: logger.exception("on_incoming_transaction failed") raise - defer.returnValue((code, response)) + return code, response class FederationEventServlet(BaseFederationServlet): PATH = "/event/(?P<event_id>[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. - def on_GET(self, origin, content, query, event_id): - return self.handler.on_pdu_request(origin, event_id) + async def on_GET(self, origin, content, query, event_id): + return await self.handler.on_pdu_request(origin, event_id) class FederationStateServlet(BaseFederationServlet): PATH = "/state/(?P<context>[^/]*)/?" # This is when someone asks for all data for a given context. - def on_GET(self, origin, content, query, context): - return self.handler.on_context_state_request( + async def on_GET(self, origin, content, query, context): + return await self.handler.on_context_state_request( origin, context, parse_string_from_args(query, "event_id", None, required=True), @@ -401,8 +412,8 @@ class FederationStateServlet(BaseFederationServlet): class FederationStateIdsServlet(BaseFederationServlet): PATH = "/state_ids/(?P<room_id>[^/]*)/?" - def on_GET(self, origin, content, query, room_id): - return self.handler.on_state_ids_request( + async def on_GET(self, origin, content, query, room_id): + return await self.handler.on_state_ids_request( origin, room_id, parse_string_from_args(query, "event_id", None, required=True), @@ -412,22 +423,22 @@ class FederationStateIdsServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P<context>[^/]*)/?" - def on_GET(self, origin, content, query, context): + async def on_GET(self, origin, content, query, context): versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: - return defer.succeed((400, {"error": "Did not include limit param"})) + return 400, {"error": "Did not include limit param"} - return self.handler.on_backfill_request(origin, context, versions, limit) + return await self.handler.on_backfill_request(origin, context, versions, limit) class FederationQueryServlet(BaseFederationServlet): PATH = "/query/(?P<query_type>[^/]*)" # This is when we receive a server-server Query - def on_GET(self, origin, content, query, query_type): - return self.handler.on_query_request( + async def on_GET(self, origin, content, query, query_type): + return await self.handler.on_query_request( query_type, {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, ) @@ -436,8 +447,7 @@ class FederationQueryServlet(BaseFederationServlet): class FederationMakeJoinServlet(BaseFederationServlet): PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" - @defer.inlineCallbacks - def on_GET(self, origin, _content, query, context, user_id): + async def on_GET(self, origin, _content, query, context, user_id): """ Args: origin (unicode): The authenticated server_name of the calling server @@ -450,8 +460,7 @@ class FederationMakeJoinServlet(BaseFederationServlet): components as specified in the path match regexp. Returns: - Deferred[(int, object)|None]: either (response code, response object) to - return a JSON response, or None if the request has already been handled. + Tuple[int, object]: (response code, response object) """ versions = query.get(b"ver") if versions is not None: @@ -459,64 +468,60 @@ class FederationMakeJoinServlet(BaseFederationServlet): else: supported_versions = ["1"] - content = yield self.handler.on_make_join_request( + content = await self.handler.on_make_join_request( origin, context, user_id, supported_versions=supported_versions ) - defer.returnValue((200, content)) + return 200, content class FederationMakeLeaveServlet(BaseFederationServlet): PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request(origin, context, user_id) - defer.returnValue((200, content)) + async def on_GET(self, origin, content, query, context, user_id): + content = await self.handler.on_make_leave_request(origin, context, user_id) + return 200, content class FederationSendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - @defer.inlineCallbacks - def on_PUT(self, origin, content, query, room_id, event_id): - content = yield self.handler.on_send_leave_request(origin, content, room_id) - defer.returnValue((200, content)) + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, content class FederationEventAuthServlet(BaseFederationServlet): PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" - def on_GET(self, origin, content, query, context, event_id): - return self.handler.on_event_auth(origin, context, event_id) + async def on_GET(self, origin, content, query, context, event_id): + return await self.handler.on_event_auth(origin, context, event_id) class FederationSendJoinServlet(BaseFederationServlet): PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" - @defer.inlineCallbacks - def on_PUT(self, origin, content, query, context, event_id): + async def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content - content = yield self.handler.on_send_join_request(origin, content, context) - defer.returnValue((200, content)) + content = await self.handler.on_send_join_request(origin, content, context) + return 200, content class FederationV1InviteServlet(BaseFederationServlet): PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" - @defer.inlineCallbacks - def on_PUT(self, origin, content, query, context, event_id): + async def on_PUT(self, origin, content, query, context, event_id): # We don't get a room version, so we have to assume its EITHER v1 or # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing # invites - content = yield self.handler.on_invite_request( + content = await self.handler.on_invite_request( origin, content, room_version=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` # due to a historical bug. - defer.returnValue((200, (200, content))) + return 200, (200, content) class FederationV2InviteServlet(BaseFederationServlet): @@ -524,8 +529,7 @@ class FederationV2InviteServlet(BaseFederationServlet): PREFIX = FEDERATION_V2_PREFIX - @defer.inlineCallbacks - def on_PUT(self, origin, content, query, context, event_id): + async def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content @@ -538,69 +542,65 @@ class FederationV2InviteServlet(BaseFederationServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state - content = yield self.handler.on_invite_request( + content = await self.handler.on_invite_request( origin, event, room_version=room_version ) - defer.returnValue((200, content)) + return 200, content class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" - @defer.inlineCallbacks - def on_PUT(self, origin, content, query, room_id): - content = yield self.handler.on_exchange_third_party_invite_request( + async def on_PUT(self, origin, content, query, room_id): + content = await self.handler.on_exchange_third_party_invite_request( origin, room_id, content ) - defer.returnValue((200, content)) + return 200, content class FederationClientKeysQueryServlet(BaseFederationServlet): PATH = "/user/keys/query" - def on_POST(self, origin, content, query): - return self.handler.on_query_client_keys(origin, content) + async def on_POST(self, origin, content, query): + return await self.handler.on_query_client_keys(origin, content) class FederationUserDevicesQueryServlet(BaseFederationServlet): PATH = "/user/devices/(?P<user_id>[^/]*)" - def on_GET(self, origin, content, query, user_id): - return self.handler.on_query_user_devices(origin, user_id) + async def on_GET(self, origin, content, query, user_id): + return await self.handler.on_query_user_devices(origin, user_id) class FederationClientKeysClaimServlet(BaseFederationServlet): PATH = "/user/keys/claim" - @defer.inlineCallbacks - def on_POST(self, origin, content, query): - response = yield self.handler.on_claim_client_keys(origin, content) - defer.returnValue((200, response)) + async def on_POST(self, origin, content, query): + response = await self.handler.on_claim_client_keys(origin, content) + return 200, response class FederationQueryAuthServlet(BaseFederationServlet): PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, context, event_id): - new_content = yield self.handler.on_query_auth_request( + async def on_POST(self, origin, content, query, context, event_id): + new_content = await self.handler.on_query_auth_request( origin, content, context, event_id ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGetMissingEventsServlet(BaseFederationServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P<room_id>[^/]*)/?" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, room_id): + async def on_POST(self, origin, content, query, room_id): limit = int(content.get("limit", 10)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) - content = yield self.handler.on_get_missing_events( + content = await self.handler.on_get_missing_events( origin, room_id=room_id, earliest_events=earliest_events, @@ -608,7 +608,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): limit=limit, ) - defer.returnValue((200, content)) + return 200, content class On3pidBindServlet(BaseFederationServlet): @@ -616,8 +616,7 @@ class On3pidBindServlet(BaseFederationServlet): REQUIRE_AUTH = False - @defer.inlineCallbacks - def on_POST(self, origin, content, query): + async def on_POST(self, origin, content, query): if "invites" in content: last_exception = None for invite in content["invites"]: @@ -629,7 +628,7 @@ class On3pidBindServlet(BaseFederationServlet): ) logger.info(message) raise SynapseError(400, message) - yield self.handler.exchange_third_party_invite( + await self.handler.exchange_third_party_invite( invite["sender"], invite["mxid"], invite["room_id"], @@ -639,7 +638,7 @@ class On3pidBindServlet(BaseFederationServlet): last_exception = e if last_exception: raise last_exception - defer.returnValue((200, {})) + return 200, {} class OpenIdUserInfo(BaseFederationServlet): @@ -663,29 +662,26 @@ class OpenIdUserInfo(BaseFederationServlet): REQUIRE_AUTH = False - @defer.inlineCallbacks - def on_GET(self, origin, content, query): + async def on_GET(self, origin, content, query): token = query.get(b"access_token", [None])[0] if token is None: - defer.returnValue( - (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}) + return ( + 401, + {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, ) - return - user_id = yield self.handler.on_openid_userinfo(token.decode("ascii")) + user_id = await self.handler.on_openid_userinfo(token.decode("ascii")) if user_id is None: - defer.returnValue( - ( - 401, - { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired", - }, - ) + return ( + 401, + { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired", + }, ) - defer.returnValue((200, {"sub": user_id})) + return 200, {"sub": user_id} class PublicRoomList(BaseFederationServlet): @@ -727,8 +723,7 @@ class PublicRoomList(BaseFederationServlet): ) self.allow_access = allow_access - @defer.inlineCallbacks - def on_GET(self, origin, content, query): + async def on_GET(self, origin, content, query): if not self.allow_access: raise FederationDeniedError(origin) @@ -748,10 +743,10 @@ class PublicRoomList(BaseFederationServlet): else: network_tuple = ThirdPartyInstanceID(None, None) - data = yield self.handler.get_local_public_room_list( + data = await self.handler.get_local_public_room_list( limit, since_token, network_tuple=network_tuple, from_federation=True ) - defer.returnValue((200, data)) + return 200, data class FederationVersionServlet(BaseFederationServlet): @@ -759,12 +754,10 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False - def on_GET(self, origin, content, query): - return defer.succeed( - ( - 200, - {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, - ) + async def on_GET(self, origin, content, query): + return ( + 200, + {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, ) @@ -774,41 +767,38 @@ class FederationGroupsProfileServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/profile" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id): + async def on_GET(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_profile(group_id, requester_user_id) + new_content = await self.handler.get_group_profile(group_id, requester_user_id) - defer.returnValue((200, new_content)) + return 200, new_content - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id): + async def on_POST(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.update_group_profile( + new_content = await self.handler.update_group_profile( group_id, requester_user_id, content ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsSummaryServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/summary" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id): + async def on_GET(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_summary(group_id, requester_user_id) + new_content = await self.handler.get_group_summary(group_id, requester_user_id) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsRoomsServlet(BaseFederationServlet): @@ -817,15 +807,14 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/rooms" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id): + async def on_GET(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id) + new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsAddRoomsServlet(BaseFederationServlet): @@ -834,29 +823,27 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, room_id): + async def on_POST(self, origin, content, query, group_id, room_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.add_room_to_group( + new_content = await self.handler.add_room_to_group( group_id, requester_user_id, room_id, content ) - defer.returnValue((200, new_content)) + return 200, new_content - @defer.inlineCallbacks - def on_DELETE(self, origin, content, query, group_id, room_id): + async def on_DELETE(self, origin, content, query, group_id, room_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.remove_room_from_group( + new_content = await self.handler.remove_room_from_group( group_id, requester_user_id, room_id ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): @@ -868,17 +855,16 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): "/config/(?P<config_key>[^/]*)" ) - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, room_id, config_key): + async def on_POST(self, origin, content, query, group_id, room_id, config_key): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - result = yield self.groups_handler.update_room_in_group( + result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) - defer.returnValue((200, result)) + return 200, result class FederationGroupsUsersServlet(BaseFederationServlet): @@ -887,15 +873,14 @@ class FederationGroupsUsersServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/users" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id): + async def on_GET(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_users_in_group(group_id, requester_user_id) + new_content = await self.handler.get_users_in_group(group_id, requester_user_id) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsInvitedUsersServlet(BaseFederationServlet): @@ -904,17 +889,16 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/invited_users" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id): + async def on_GET(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_invited_users_in_group( + new_content = await self.handler.get_invited_users_in_group( group_id, requester_user_id ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsInviteServlet(BaseFederationServlet): @@ -923,17 +907,16 @@ class FederationGroupsInviteServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST(self, origin, content, query, group_id, user_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.invite_to_group( + new_content = await self.handler.invite_to_group( group_id, user_id, requester_user_id, content ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsAcceptInviteServlet(BaseFederationServlet): @@ -942,14 +925,13 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.accept_invite(group_id, user_id, content) + new_content = await self.handler.accept_invite(group_id, user_id, content) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsJoinServlet(BaseFederationServlet): @@ -958,14 +940,13 @@ class FederationGroupsJoinServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.join_group(group_id, user_id, content) + new_content = await self.handler.join_group(group_id, user_id, content) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsRemoveUserServlet(BaseFederationServlet): @@ -974,17 +955,16 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST(self, origin, content, query, group_id, user_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.remove_user_from_group( + new_content = await self.handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsLocalInviteServlet(BaseFederationServlet): @@ -993,14 +973,13 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") - new_content = yield self.handler.on_invite(group_id, user_id, content) + new_content = await self.handler.on_invite(group_id, user_id, content) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): @@ -1009,16 +988,15 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.user_removed_from_group( + new_content = await self.handler.user_removed_from_group( group_id, user_id, content ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): @@ -1027,15 +1005,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST(self, origin, content, query, group_id, user_id): # We don't need to check auth here as we check the attestation signatures - new_content = yield self.handler.on_renew_attestation( + new_content = await self.handler.on_renew_attestation( group_id, user_id, content ) - defer.returnValue((200, new_content)) + return 200, new_content class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): @@ -1052,8 +1029,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): "/rooms/(?P<room_id>[^/]*)" ) - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, category_id, room_id): + async def on_POST(self, origin, content, query, group_id, category_id, room_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1061,7 +1037,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): if category_id == "": raise SynapseError(400, "category_id cannot be empty string") - resp = yield self.handler.update_group_summary_room( + resp = await self.handler.update_group_summary_room( group_id, requester_user_id, room_id=room_id, @@ -1069,10 +1045,9 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): content=content, ) - defer.returnValue((200, resp)) + return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + async def on_DELETE(self, origin, content, query, group_id, category_id, room_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1080,11 +1055,11 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): if category_id == "": raise SynapseError(400, "category_id cannot be empty string") - resp = yield self.handler.delete_group_summary_room( + resp = await self.handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) - defer.returnValue((200, resp)) + return 200, resp class FederationGroupsCategoriesServlet(BaseFederationServlet): @@ -1093,15 +1068,14 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/categories/?" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id): + async def on_GET(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_categories(group_id, requester_user_id) + resp = await self.handler.get_group_categories(group_id, requester_user_id) - defer.returnValue((200, resp)) + return 200, resp class FederationGroupsCategoryServlet(BaseFederationServlet): @@ -1110,20 +1084,18 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id, category_id): + async def on_GET(self, origin, content, query, group_id, category_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_category( + resp = await self.handler.get_group_category( group_id, requester_user_id, category_id ) - defer.returnValue((200, resp)) + return 200, resp - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, category_id): + async def on_POST(self, origin, content, query, group_id, category_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1131,14 +1103,13 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): if category_id == "": raise SynapseError(400, "category_id cannot be empty string") - resp = yield self.handler.upsert_group_category( + resp = await self.handler.upsert_group_category( group_id, requester_user_id, category_id, content ) - defer.returnValue((200, resp)) + return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, origin, content, query, group_id, category_id): + async def on_DELETE(self, origin, content, query, group_id, category_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1146,11 +1117,11 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): if category_id == "": raise SynapseError(400, "category_id cannot be empty string") - resp = yield self.handler.delete_group_category( + resp = await self.handler.delete_group_category( group_id, requester_user_id, category_id ) - defer.returnValue((200, resp)) + return 200, resp class FederationGroupsRolesServlet(BaseFederationServlet): @@ -1159,15 +1130,14 @@ class FederationGroupsRolesServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/roles/?" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id): + async def on_GET(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_roles(group_id, requester_user_id) + resp = await self.handler.get_group_roles(group_id, requester_user_id) - defer.returnValue((200, resp)) + return 200, resp class FederationGroupsRoleServlet(BaseFederationServlet): @@ -1176,18 +1146,16 @@ class FederationGroupsRoleServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" - @defer.inlineCallbacks - def on_GET(self, origin, content, query, group_id, role_id): + async def on_GET(self, origin, content, query, group_id, role_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id) + resp = await self.handler.get_group_role(group_id, requester_user_id, role_id) - defer.returnValue((200, resp)) + return 200, resp - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, role_id): + async def on_POST(self, origin, content, query, group_id, role_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1195,14 +1163,13 @@ class FederationGroupsRoleServlet(BaseFederationServlet): if role_id == "": raise SynapseError(400, "role_id cannot be empty string") - resp = yield self.handler.update_group_role( + resp = await self.handler.update_group_role( group_id, requester_user_id, role_id, content ) - defer.returnValue((200, resp)) + return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, origin, content, query, group_id, role_id): + async def on_DELETE(self, origin, content, query, group_id, role_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1210,11 +1177,11 @@ class FederationGroupsRoleServlet(BaseFederationServlet): if role_id == "": raise SynapseError(400, "role_id cannot be empty string") - resp = yield self.handler.delete_group_role( + resp = await self.handler.delete_group_role( group_id, requester_user_id, role_id ) - defer.returnValue((200, resp)) + return 200, resp class FederationGroupsSummaryUsersServlet(BaseFederationServlet): @@ -1231,8 +1198,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): "/users/(?P<user_id>[^/]*)" ) - @defer.inlineCallbacks - def on_POST(self, origin, content, query, group_id, role_id, user_id): + async def on_POST(self, origin, content, query, group_id, role_id, user_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1240,7 +1206,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): if role_id == "": raise SynapseError(400, "role_id cannot be empty string") - resp = yield self.handler.update_group_summary_user( + resp = await self.handler.update_group_summary_user( group_id, requester_user_id, user_id=user_id, @@ -1248,10 +1214,9 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): content=content, ) - defer.returnValue((200, resp)) + return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1259,11 +1224,11 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): if role_id == "": raise SynapseError(400, "role_id cannot be empty string") - resp = yield self.handler.delete_group_summary_user( + resp = await self.handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) - defer.returnValue((200, resp)) + return 200, resp class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): @@ -1272,13 +1237,12 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): PATH = "/get_groups_publicised" - @defer.inlineCallbacks - def on_POST(self, origin, content, query): - resp = yield self.handler.bulk_get_publicised_groups( + async def on_POST(self, origin, content, query): + resp = await self.handler.bulk_get_publicised_groups( content["user_ids"], proxy=False ) - defer.returnValue((200, resp)) + return 200, resp class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): @@ -1287,17 +1251,16 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" - @defer.inlineCallbacks - def on_PUT(self, origin, content, query, group_id): + async def on_PUT(self, origin, content, query, group_id): requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.set_group_join_policy( + new_content = await self.handler.set_group_join_policy( group_id, requester_user_id, content ) - defer.returnValue((200, new_content)) + return 200, new_content class RoomComplexityServlet(BaseFederationServlet): @@ -1309,18 +1272,17 @@ class RoomComplexityServlet(BaseFederationServlet): PATH = "/rooms/(?P<room_id>[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX - @defer.inlineCallbacks - def on_GET(self, origin, content, query, room_id): + async def on_GET(self, origin, content, query, room_id): store = self.handler.hs.get_datastore() - is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id) + is_public = await store.is_room_world_readable_or_publicly_joinable(room_id) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) - complexity = yield store.get_room_complexity(room_id) - defer.returnValue((200, complexity)) + complexity = await store.get_room_complexity(room_id) + return 200, complexity FEDERATION_SERVLET_CLASSES = ( diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 941ebfa107..e8a651e231 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -17,6 +17,10 @@ import logging from twisted.internet import defer +from synapse.api.constants import Membership +from synapse.types import RoomStreamToken +from synapse.visibility import filter_events_for_client + from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -89,3 +93,182 @@ class AdminHandler(BaseHandler): ret = yield self.store.search_users(term) defer.returnValue(ret) + + @defer.inlineCallbacks + def export_user_data(self, user_id, writer): + """Write all data we have on the user to the given writer. + + Args: + user_id (str) + writer (ExfiltrationWriter) + + Returns: + defer.Deferred: Resolves when all data for a user has been written. + The returned value is that returned by `writer.finished()`. + """ + # Get all rooms the user is in or has been in + rooms = yield self.store.get_rooms_for_user_where_membership_is( + user_id, + membership_list=( + Membership.JOIN, + Membership.LEAVE, + Membership.BAN, + Membership.INVITE, + ), + ) + + # We only try and fetch events for rooms the user has been in. If + # they've been e.g. invited to a room without joining then we handle + # those seperately. + rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id) + + for index, room in enumerate(rooms): + room_id = room.room_id + + logger.info( + "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) + ) + + forgotten = yield self.store.did_forget(user_id, room_id) + if forgotten: + logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) + continue + + if room_id not in rooms_user_has_been_in: + # If we haven't been in the rooms then the filtering code below + # won't return anything, so we need to handle these cases + # explicitly. + + if room.membership == Membership.INVITE: + event_id = room.event_id + invite = yield self.store.get_event(event_id, allow_none=True) + if invite: + invited_state = invite.unsigned["invite_room_state"] + writer.write_invite(room_id, invite, invited_state) + + continue + + # We only want to bother fetching events up to the last time they + # were joined. We estimate that point by looking at the + # stream_ordering of the last membership if it wasn't a join. + if room.membership == Membership.JOIN: + stream_ordering = yield self.store.get_room_max_stream_ordering() + else: + stream_ordering = room.stream_ordering + + from_key = str(RoomStreamToken(0, 0)) + to_key = str(RoomStreamToken(None, stream_ordering)) + + written_events = set() # Events that we've processed in this room + + # We need to track gaps in the events stream so that we can then + # write out the state at those events. We do this by keeping track + # of events whose prev events we haven't seen. + + # Map from event ID to prev events that haven't been processed, + # dict[str, set[str]]. + event_to_unseen_prevs = {} + + # The reverse mapping to above, i.e. map from unseen event to events + # that have the unseen event in their prev_events, i.e. the unseen + # events "children". dict[str, set[str]] + unseen_to_child_events = {} + + # We fetch events in the room the user could see by fetching *all* + # events that we have and then filtering, this isn't the most + # efficient method perhaps but it does guarantee we get everything. + while True: + events, _ = yield self.store.paginate_room_events( + room_id, from_key, to_key, limit=100, direction="f" + ) + if not events: + break + + from_key = events[-1].internal_metadata.after + + events = yield filter_events_for_client(self.store, user_id, events) + + writer.write_events(room_id, events) + + # Update the extremity tracking dicts + for event in events: + # Check if we have any prev events that haven't been + # processed yet, and add those to the appropriate dicts. + unseen_events = set(event.prev_event_ids()) - written_events + if unseen_events: + event_to_unseen_prevs[event.event_id] = unseen_events + for unseen in unseen_events: + unseen_to_child_events.setdefault(unseen, set()).add( + event.event_id + ) + + # Now check if this event is an unseen prev event, if so + # then we remove this event from the appropriate dicts. + for child_id in unseen_to_child_events.pop(event.event_id, []): + event_to_unseen_prevs[child_id].discard(event.event_id) + + written_events.add(event.event_id) + + logger.info( + "Written %d events in room %s", len(written_events), room_id + ) + + # Extremities are the events who have at least one unseen prev event. + extremities = ( + event_id + for event_id, unseen_prevs in event_to_unseen_prevs.items() + if unseen_prevs + ) + for event_id in extremities: + if not event_to_unseen_prevs[event_id]: + continue + state = yield self.store.get_state_for_event(event_id) + writer.write_state(room_id, event_id, state) + + defer.returnValue(writer.finished()) + + +class ExfiltrationWriter(object): + """Interface used to specify how to write exported data. + """ + + def write_events(self, room_id, events): + """Write a batch of events for a room. + + Args: + room_id (str) + events (list[FrozenEvent]) + """ + pass + + def write_state(self, room_id, event_id, state): + """Write the state at the given event in the room. + + This only gets called for backward extremities rather than for each + event. + + Args: + room_id (str) + event_id (str) + state (dict[tuple[str, str], FrozenEvent]) + """ + pass + + def write_invite(self, room_id, event, state): + """Write an invite for the room, with associated invite state. + + Args: + room_id (str) + event (FrozenEvent) + state (dict[tuple[str, str], dict]): A subset of the state at the + invite, with a subset of the event keys (type, state_key + content and sender) + """ + + def finished(self): + """Called when all data has succesfully been exported and written. + + This functions return value is passed to the caller of + `export_user_data`. + """ + pass diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ef5585aa99..d4d6574975 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import time import unicodedata import attr @@ -34,6 +35,7 @@ from synapse.api.errors import ( LoginError, StoreError, SynapseError, + UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import defer_to_thread @@ -558,7 +560,7 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None): + def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms): """ Creates a new access token for the user with the given user ID. @@ -572,15 +574,27 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) + valid_until_ms (int|None): when the token is valid until. None for + no expiry. Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. """ - logger.info("Logging in user %s on device %s", user_id, device_id) - access_token = yield self.issue_access_token(user_id, device_id) + fmt_expiry = "" + if valid_until_ms is not None: + fmt_expiry = time.strftime( + " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0) + ) + logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) + yield self.auth.check_auth_blocking(user_id) + access_token = self.macaroon_gen.generate_access_token(user_id) + yield self.store.add_access_token_to_user( + user_id, access_token, device_id, valid_until_ms + ) + # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we # really don't want is active access_tokens without a record of the @@ -610,6 +624,7 @@ class AuthHandler(BaseHandler): Raises: LimitExceededError if the ratelimiter's login requests count for this user is too high too proceed. + UserDeactivatedError if a user is found but is deactivated. """ self.ratelimit_login_per_account(user_id) res = yield self._find_user_id_and_pwd_hash(user_id) @@ -825,6 +840,13 @@ class AuthHandler(BaseHandler): if not lookupres: defer.returnValue(None) (user_id, password_hash) = lookupres + + # If the password hash is None, the account has likely been deactivated + if not password_hash: + deactivated = yield self.store.get_user_deactivated_status(user_id) + if deactivated: + raise UserDeactivatedError("This account has been deactivated") + result = yield self.validate_hash(password, password_hash) if not result: logger.warn("Failed password login for user %s", user_id) @@ -832,12 +854,6 @@ class AuthHandler(BaseHandler): defer.returnValue(user_id) @defer.inlineCallbacks - def issue_access_token(self, user_id, device_id=None): - access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, device_id) - defer.returnValue(access_token) - - @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): auth_api = self.hs.get_auth() user_id = None diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 1997df556e..454f99524b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -24,12 +24,7 @@ from signedjson.sign import SignatureVerifyException, verify_signed_json from twisted.internet import defer -from synapse.api.errors import ( - CodeMessageException, - Codes, - FederationDeniedError, - SynapseError, -) +from synapse.api.errors import CodeMessageException, Codes, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import ( UserID, @@ -554,9 +549,6 @@ def _exception_to_failure(e): if isinstance(e, NotRetryingDestination): return {"status": 503, "message": "Not ready for retry"} - if isinstance(e, FederationDeniedError): - return {"status": 403, "message": "Federation Denied"} - # include ConnectionRefused and other errors # # Note that some Exceptions (notably twisted's ResponseFailed etc) don't diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c82b1933f2..546d6169e9 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -118,7 +118,7 @@ class IdentityHandler(BaseHandler): raise SynapseError(400, "No client_secret in creds") try: - data = yield self.http_client.post_urlencoded_get_json( + data = yield self.http_client.post_json_get_json( "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"), {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid}, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index eaeda7a5cb..6d7a987f13 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -23,6 +23,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from synapse import event_auth from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.api.errors import ( AuthError, @@ -784,6 +785,20 @@ class EventCreationHandler(object): event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: + original_event = yield self.store.get_event( + event.redacts, + check_redacted=False, + get_prev_content=False, + allow_rejected=False, + allow_none=True, + check_room_id=event.room_id, + ) + + # we can make some additional checks now if we have the original event. + if original_event: + if original_event.type == EventTypes.Create: + raise AuthError(403, "Redacting create events is not permitted") + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True @@ -791,18 +806,18 @@ class EventCreationHandler(object): auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version = yield self.store.get_room_version(event.room_id) - if self.auth.check_redaction(room_version, event, auth_events=auth_events): - original_event = yield self.store.get_event( - event.redacts, - check_redacted=False, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - ) + + if event_auth.check_redaction(room_version, event, auth_events=auth_events): + # this user doesn't have 'redact' rights, so we need to do some more + # checks on the original event. Let's start by checking the original + # event exists. + if not original_event: + raise NotFoundError("Could not find event %s" % (event.redacts,)) + if event.user_id != original_event.user_id: raise AuthError(403, "You don't have permission to redact events") - # We've already checked. + # all the checks are done. event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d8462b75ec..a2388a7091 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -303,6 +303,10 @@ class BaseProfileHandler(BaseHandler): if not self.hs.config.require_auth_for_profile_requests or not requester: return + # Always allow the user to query their own profile. + if target_user.to_string() == requester.to_string(): + return + try: requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) target_user_rooms = yield self.store.get_rooms_for_user( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a3e553d5f5..bb7cfd71b9 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -84,6 +84,8 @@ class RegistrationHandler(BaseHandler): self.device_handler = hs.get_device_handler() self.pusher_pool = hs.get_pusherpool() + self.session_lifetime = hs.config.session_lifetime + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): if types.contains_invalid_mxid_characters(localpart): @@ -584,7 +586,7 @@ class RegistrationHandler(BaseHandler): address=address, ) else: - return self.store.register( + return self.store.register_user( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -599,6 +601,8 @@ class RegistrationHandler(BaseHandler): def register_device(self, user_id, device_id, initial_display_name, is_guest=False): """Register a device for a user and generate an access token. + The access token will be limited by the homeserver's session_lifetime config. + Args: user_id (str): full canonical @user:id device_id (str|None): The device ID to check, or None to generate @@ -619,20 +623,29 @@ class RegistrationHandler(BaseHandler): is_guest=is_guest, ) defer.returnValue((r["device_id"], r["access_token"])) - else: - device_id = yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + + valid_until_ms = None + if self.session_lifetime is not None: if is_guest: - access_token = self.macaroon_gen.generate_access_token( - user_id, ["guest = true"] - ) - else: - access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id + raise Exception( + "session_lifetime is not currently implemented for guest access" ) + valid_until_ms = self.clock.time_msec() + self.session_lifetime + + device_id = yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + if is_guest: + assert valid_until_ms is None + access_token = self.macaroon_gen.generate_access_token( + user_id, ["guest = true"] + ) + else: + access_token = yield self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms + ) - defer.returnValue((device_id, access_token)) + defer.returnValue((device_id, access_token)) @defer.inlineCallbacks def post_registration_actions( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 679daaa074..e0196ef83e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -29,7 +29,7 @@ from twisted.internet import defer import synapse.server import synapse.types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError from synapse.types import RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -872,9 +872,23 @@ class RoomMemberHandler(object): "sender_avatar_url": inviter_avatar_url, } - data = yield self.simple_http_client.post_urlencoded_get_json( - is_url, invite_config - ) + try: + data = yield self.simple_http_client.post_json_get_json( + is_url, invite_config + ) + except HttpResponseException as e: + # Some identity servers may only support application/x-www-form-urlencoded + # types. This is especially true with old instances of Sydent, see + # https://github.com/matrix-org/sydent/pull/170 + logger.info( + "Failed to POST %s with JSON, falling back to urlencoded form: %s", + is_url, + e, + ) + data = yield self.simple_http_client.post_urlencoded_get_json( + is_url, invite_config + ) + # TODO: Check for success token = data["token"] public_keys = data.get("public_keys", []) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index dee3710f68..e60334547e 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -36,6 +36,7 @@ from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers +import synapse.logging.opentracing as opentracing import synapse.metrics import synapse.util.retryutils from synapse.api.errors import ( @@ -339,9 +340,25 @@ class MatrixFederationHttpClient(object): else: query_bytes = b"" - headers_dict = {b"User-Agent": [self.version_string_bytes]} + # Retreive current span + scope = opentracing.start_active_span( + "outgoing-federation-request", + tags={ + opentracing.tags.SPAN_KIND: opentracing.tags.SPAN_KIND_RPC_CLIENT, + opentracing.tags.PEER_ADDRESS: request.destination, + opentracing.tags.HTTP_METHOD: request.method, + opentracing.tags.HTTP_URL: request.path, + }, + finish_on_close=True, + ) + + # Inject the span into the headers + headers_dict = {} + opentracing.inject_active_span_byte_dict(headers_dict, request.destination) - with limiter: + headers_dict[b"User-Agent"] = [self.version_string_bytes] + + with limiter, scope: # XXX: Would be much nicer to retry only at the transaction-layer # (once we have reliable transactions in place) if long_retries: @@ -419,6 +436,10 @@ class MatrixFederationHttpClient(object): response.phrase.decode("ascii", errors="replace"), ) + opentracing.set_tag( + opentracing.tags.HTTP_STATUS_CODE, response.code + ) + if 200 <= response.code < 300: pass else: @@ -499,8 +520,7 @@ class MatrixFederationHttpClient(object): _flatten_response_never_received(e), ) raise - - defer.returnValue(response) + defer.returnValue(response) def build_auth_headers( self, destination, method, url_bytes, content=None, destination_is=None diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index cd8415acd5..889038ff25 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -20,6 +20,7 @@ import logging from canonicaljson import json from synapse.api.errors import Codes, SynapseError +from synapse.logging.opentracing import trace_servlet logger = logging.getLogger(__name__) @@ -290,7 +291,11 @@ class RestServlet(object): for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): if hasattr(self, "on_%s" % (method,)): method_handler = getattr(self, "on_%s" % (method,)) - http_server.register_paths(method, patterns, method_handler) + http_server.register_paths( + method, + patterns, + trace_servlet(self.__class__.__name__, method_handler), + ) else: raise NotImplementedError("RestServlet must register something.") diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 30dfa1d6b2..b456c31f70 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -186,6 +186,7 @@ class LoggingContext(object): "alive", "request", "tag", + "scope", ] thread_local = threading.local() @@ -238,6 +239,7 @@ class LoggingContext(object): self.request = None self.tag = "" self.alive = True + self.scope = None self.parent_context = parent_context @@ -322,10 +324,12 @@ class LoggingContext(object): another LoggingContext """ - # 'request' is the only field we currently use in the logger, so that's - # all we need to copy + # we track the current request record.request = self.request + # we also track the current scope: + record.scope = self.scope + def start(self): if get_thread_id() != self.main_thread: logger.warning("Started logcontext %s on different thread", self) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py new file mode 100644 index 0000000000..56d900080b --- /dev/null +++ b/synapse/logging/opentracing.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.import opentracing + + +# NOTE +# This is a small wrapper around opentracing because opentracing is not currently +# packaged downstream (specifically debian). Since opentracing instrumentation is +# fairly invasive it was awkward to make it optional. As a result we opted to encapsulate +# all opentracing state in these methods which effectively noop if opentracing is +# not present. We should strongly consider encouraging the downstream distributers +# to package opentracing and making opentracing a full dependency. In order to facilitate +# this move the methods have work very similarly to opentracing's and it should only +# be a matter of few regexes to move over to opentracing's access patterns proper. + +import contextlib +import logging +import re +from functools import wraps + +from twisted.internet import defer + +from synapse.config import ConfigError + +try: + import opentracing +except ImportError: + opentracing = None +try: + from jaeger_client import Config as JaegerConfig + from synapse.logging.scopecontextmanager import LogContextScopeManager +except ImportError: + JaegerConfig = None + LogContextScopeManager = None + + +logger = logging.getLogger(__name__) + + +class _DumTagNames(object): + """wrapper of opentracings tags. We need to have them if we + want to reference them without opentracing around. Clearly they + should never actually show up in a trace. `set_tags` overwrites + these with the correct ones.""" + + INVALID_TAG = "invalid-tag" + COMPONENT = INVALID_TAG + DATABASE_INSTANCE = INVALID_TAG + DATABASE_STATEMENT = INVALID_TAG + DATABASE_TYPE = INVALID_TAG + DATABASE_USER = INVALID_TAG + ERROR = INVALID_TAG + HTTP_METHOD = INVALID_TAG + HTTP_STATUS_CODE = INVALID_TAG + HTTP_URL = INVALID_TAG + MESSAGE_BUS_DESTINATION = INVALID_TAG + PEER_ADDRESS = INVALID_TAG + PEER_HOSTNAME = INVALID_TAG + PEER_HOST_IPV4 = INVALID_TAG + PEER_HOST_IPV6 = INVALID_TAG + PEER_PORT = INVALID_TAG + PEER_SERVICE = INVALID_TAG + SAMPLING_PRIORITY = INVALID_TAG + SERVICE = INVALID_TAG + SPAN_KIND = INVALID_TAG + SPAN_KIND_CONSUMER = INVALID_TAG + SPAN_KIND_PRODUCER = INVALID_TAG + SPAN_KIND_RPC_CLIENT = INVALID_TAG + SPAN_KIND_RPC_SERVER = INVALID_TAG + + +def only_if_tracing(func): + """Executes the function only if we're tracing. Otherwise return. + Assumes the function wrapped may return None""" + + @wraps(func) + def _only_if_tracing_inner(*args, **kwargs): + if opentracing: + return func(*args, **kwargs) + else: + return + + return _only_if_tracing_inner + + +# A regex which matches the server_names to expose traces for. +# None means 'block everything'. +_homeserver_whitelist = None + +tags = _DumTagNames + + +def init_tracer(config): + """Set the whitelists and initialise the JaegerClient tracer + + Args: + config (HomeserverConfig): The config used by the homeserver + """ + global opentracing + if not config.opentracer_enabled: + # We don't have a tracer + opentracing = None + return + + if not opentracing or not JaegerConfig: + raise ConfigError( + "The server has been configured to use opentracing but opentracing is not " + "installed." + ) + + # Include the worker name + name = config.worker_name if config.worker_name else "master" + + set_homeserver_whitelist(config.opentracer_whitelist) + jaeger_config = JaegerConfig( + config={"sampler": {"type": "const", "param": 1}, "logging": True}, + service_name="{} {}".format(config.server_name, name), + scope_manager=LogContextScopeManager(config), + ) + jaeger_config.initialize_tracer() + + # Set up tags to be opentracing's tags + global tags + tags = opentracing.tags + + +@contextlib.contextmanager +def _noop_context_manager(*args, **kwargs): + """Does absolutely nothing really well. Can be entered and exited arbitrarily. + Good substitute for an opentracing scope.""" + yield + + +# Could use kwargs but I want these to be explicit +def start_active_span( + operation_name, + child_of=None, + references=None, + tags=None, + start_time=None, + ignore_active_span=False, + finish_on_close=True, +): + """Starts an active opentracing span. Note, the scope doesn't become active + until it has been entered, however, the span starts from the time this + message is called. + Args: + See opentracing.tracer + Returns: + scope (Scope) or noop_context_manager + """ + if opentracing is None: + return _noop_context_manager() + else: + # We need to enter the scope here for the logcontext to become active + return opentracing.tracer.start_active_span( + operation_name, + child_of=child_of, + references=references, + tags=tags, + start_time=start_time, + ignore_active_span=ignore_active_span, + finish_on_close=finish_on_close, + ) + + +@only_if_tracing +def close_active_span(): + """Closes the active span. This will close it's logcontext if the context + was made for the span""" + opentracing.tracer.scope_manager.active.__exit__(None, None, None) + + +@only_if_tracing +def set_tag(key, value): + """Set's a tag on the active span""" + opentracing.tracer.active_span.set_tag(key, value) + + +@only_if_tracing +def log_kv(key_values, timestamp=None): + """Log to the active span""" + opentracing.tracer.active_span.log_kv(key_values, timestamp) + + +# Note: we don't have a get baggage items because we're trying to hide all +# scope and span state from synapse. I think this method may also be useless +# as a result +@only_if_tracing +def set_baggage_item(key, value): + """Attach baggage to the active span""" + opentracing.tracer.active_span.set_baggage_item(key, value) + + +@only_if_tracing +def set_operation_name(operation_name): + """Sets the operation name of the active span""" + opentracing.tracer.active_span.set_operation_name(operation_name) + + +@only_if_tracing +def set_homeserver_whitelist(homeserver_whitelist): + """Sets the whitelist + + Args: + homeserver_whitelist (iterable of strings): regex of whitelisted homeservers + """ + global _homeserver_whitelist + if homeserver_whitelist: + # Makes a single regex which accepts all passed in regexes in the list + _homeserver_whitelist = re.compile( + "({})".format(")|(".join(homeserver_whitelist)) + ) + + +@only_if_tracing +def whitelisted_homeserver(destination): + """Checks if a destination matches the whitelist + Args: + destination (String)""" + if _homeserver_whitelist: + return _homeserver_whitelist.match(destination) + return False + + +def start_active_span_from_context( + headers, + operation_name, + references=None, + tags=None, + start_time=None, + ignore_active_span=False, + finish_on_close=True, +): + """ + Extracts a span context from Twisted Headers. + args: + headers (twisted.web.http_headers.Headers) + returns: + span_context (opentracing.span.SpanContext) + """ + # Twisted encodes the values as lists whereas opentracing doesn't. + # So, we take the first item in the list. + # Also, twisted uses byte arrays while opentracing expects strings. + if opentracing is None: + return _noop_context_manager() + + header_dict = {k.decode(): v[0].decode() for k, v in headers.getAllRawHeaders()} + context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict) + + return opentracing.tracer.start_active_span( + operation_name, + child_of=context, + references=references, + tags=tags, + start_time=start_time, + ignore_active_span=ignore_active_span, + finish_on_close=finish_on_close, + ) + + +@only_if_tracing +def inject_active_span_twisted_headers(headers, destination): + """ + Injects a span context into twisted headers inplace + + Args: + headers (twisted.web.http_headers.Headers) + span (opentracing.Span) + + Returns: + Inplace modification of headers + + Note: + The headers set by the tracer are custom to the tracer implementation which + should be unique enough that they don't interfere with any headers set by + synapse or twisted. If we're still using jaeger these headers would be those + here: + https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py + """ + + if not whitelisted_homeserver(destination): + return + + span = opentracing.tracer.active_span + carrier = {} + opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) + + for key, value in carrier.items(): + headers.addRawHeaders(key, value) + + +@only_if_tracing +def inject_active_span_byte_dict(headers, destination): + """ + Injects a span context into a dict where the headers are encoded as byte + strings + + Args: + headers (dict) + span (opentracing.Span) + + Returns: + Inplace modification of headers + + Note: + The headers set by the tracer are custom to the tracer implementation which + should be unique enough that they don't interfere with any headers set by + synapse or twisted. If we're still using jaeger these headers would be those + here: + https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py + """ + if not whitelisted_homeserver(destination): + return + + span = opentracing.tracer.active_span + + carrier = {} + opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) + + for key, value in carrier.items(): + headers[key.encode()] = [value.encode()] + + +def trace_servlet(servlet_name, func): + """Decorator which traces a serlet. It starts a span with some servlet specific + tags such as the servlet_name and request information""" + + @wraps(func) + @defer.inlineCallbacks + def _trace_servlet_inner(request, *args, **kwargs): + with start_active_span( + "incoming-client-request", + tags={ + "request_id": request.get_request_id(), + tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, + tags.HTTP_METHOD: request.get_method(), + tags.HTTP_URL: request.get_redacted_uri(), + tags.PEER_HOST_IPV6: request.getClientIP(), + "servlet_name": servlet_name, + }, + ): + result = yield defer.maybeDeferred(func, request, *args, **kwargs) + defer.returnValue(result) + + return _trace_servlet_inner diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py new file mode 100644 index 0000000000..8c661302c9 --- /dev/null +++ b/synapse/logging/scopecontextmanager.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.import logging + +import logging + +from opentracing import Scope, ScopeManager + +import twisted + +from synapse.logging.context import LoggingContext, nested_logging_context + +logger = logging.getLogger(__name__) + + +class LogContextScopeManager(ScopeManager): + """ + The LogContextScopeManager tracks the active scope in opentracing + by using the log contexts which are native to synapse. This is so + that the basic opentracing api can be used across twisted defereds. + (I would love to break logcontexts and this into an OS package. but + let's wait for twisted's contexts to be released.) + """ + + def __init__(self, config): + pass + + @property + def active(self): + """ + Returns the currently active Scope which can be used to access the + currently active Scope.span. + If there is a non-null Scope, its wrapped Span + becomes an implicit parent of any newly-created Span at + Tracer.start_active_span() time. + + Return: + (Scope) : the Scope that is active, or None if not + available. + """ + ctx = LoggingContext.current_context() + if ctx is LoggingContext.sentinel: + return None + else: + return ctx.scope + + def activate(self, span, finish_on_close): + """ + Makes a Span active. + Args + span (Span): the span that should become active. + finish_on_close (Boolean): whether Span should be automatically + finished when Scope.close() is called. + + Returns: + Scope to control the end of the active period for + *span*. It is a programming error to neglect to call + Scope.close() on the returned instance. + """ + + enter_logcontext = False + ctx = LoggingContext.current_context() + + if ctx is LoggingContext.sentinel: + # We don't want this scope to affect. + logger.error("Tried to activate scope outside of loggingcontext") + return Scope(None, span) + elif ctx.scope is not None: + # We want the logging scope to look exactly the same so we give it + # a blank suffix + ctx = nested_logging_context("") + enter_logcontext = True + + scope = _LogContextScope(self, span, ctx, enter_logcontext, finish_on_close) + ctx.scope = scope + return scope + + +class _LogContextScope(Scope): + """ + A custom opentracing scope. The only significant difference is that it will + close the log context it's related to if the logcontext was created specifically + for this scope. + """ + + def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close): + """ + Args: + manager (LogContextScopeManager): + the manager that is responsible for this scope. + span (Span): + the opentracing span which this scope represents the local + lifetime for. + logcontext (LogContext): + the logcontext to which this scope is attached. + enter_logcontext (Boolean): + if True the logcontext will be entered and exited when the scope + is entered and exited respectively + finish_on_close (Boolean): + if True finish the span when the scope is closed + """ + super(_LogContextScope, self).__init__(manager, span) + self.logcontext = logcontext + self._finish_on_close = finish_on_close + self._enter_logcontext = enter_logcontext + + def __enter__(self): + if self._enter_logcontext: + self.logcontext.__enter__() + + def __exit__(self, type, value, traceback): + if type == twisted.internet.defer._DefGen_Return: + super(_LogContextScope, self).__exit__(None, None, None) + else: + super(_LogContextScope, self).__exit__(type, value, traceback) + if self._enter_logcontext: + self.logcontext.__exit__(type, value, traceback) + else: # the logcontext existed before the creation of the scope + self.logcontext.scope = None + + def close(self): + if self.manager.active is not self: + logger.error("Tried to close a none active scope!") + return + + if self._finish_on_close: + self.span.finish() diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index eaf0aaa86e..488280b4a6 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -29,8 +29,16 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricF from twisted.internet import reactor +from synapse.metrics._exposition import ( + MetricsResource, + generate_latest, + start_http_server, +) + logger = logging.getLogger(__name__) +METRICS_PREFIX = "/_synapse/metrics" + running_on_pypy = platform.python_implementation() == "PyPy" all_metrics = [] all_collectors = [] @@ -470,3 +478,12 @@ try: gc.disable() except AttributeError: pass + +__all__ = [ + "MetricsResource", + "generate_latest", + "start_http_server", + "LaterGauge", + "InFlightGauge", + "BucketCollector", +] diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py new file mode 100644 index 0000000000..1933ecd3e3 --- /dev/null +++ b/synapse/metrics/_exposition.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- +# Copyright 2015-2019 Prometheus Python Client Developers +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is based off `prometheus_client/exposition.py` from version 0.7.1. + +Due to the renaming of metrics in prometheus_client 0.4.0, this customised +vendoring of the code will emit both the old versions that Synapse dashboards +expect, and the newer "best practice" version of the up-to-date official client. +""" + +import math +import threading +from collections import namedtuple +from http.server import BaseHTTPRequestHandler, HTTPServer +from socketserver import ThreadingMixIn +from urllib.parse import parse_qs, urlparse + +from prometheus_client import REGISTRY + +from twisted.web.resource import Resource + +try: + from prometheus_client.samples import Sample +except ImportError: + Sample = namedtuple("Sample", ["name", "labels", "value", "timestamp", "exemplar"]) + + +CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") + + +INF = float("inf") +MINUS_INF = float("-inf") + + +def floatToGoString(d): + d = float(d) + if d == INF: + return "+Inf" + elif d == MINUS_INF: + return "-Inf" + elif math.isnan(d): + return "NaN" + else: + s = repr(d) + dot = s.find(".") + # Go switches to exponents sooner than Python. + # We only need to care about positive values for le/quantile. + if d > 0 and dot > 6: + mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.") + return "{0}e+0{1}".format(mantissa, dot - 1) + return s + + +def sample_line(line, name): + if line.labels: + labelstr = "{{{0}}}".format( + ",".join( + [ + '{0}="{1}"'.format( + k, + v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), + ) + for k, v in sorted(line.labels.items()) + ] + ) + ) + else: + labelstr = "" + timestamp = "" + if line.timestamp is not None: + # Convert to milliseconds. + timestamp = " {0:d}".format(int(float(line.timestamp) * 1000)) + return "{0}{1} {2}{3}\n".format( + name, labelstr, floatToGoString(line.value), timestamp + ) + + +def nameify_sample(sample): + """ + If we get a prometheus_client<0.4.0 sample as a tuple, transform it into a + namedtuple which has the names we expect. + """ + if not isinstance(sample, Sample): + sample = Sample(*sample, None, None) + + return sample + + +def generate_latest(registry, emit_help=False): + output = [] + + for metric in registry.collect(): + + if metric.name.startswith("__unused"): + continue + + if not metric.samples: + # No samples, don't bother. + continue + + mname = metric.name + mnewname = metric.name + mtype = metric.type + + # OpenMetrics -> Prometheus + if mtype == "counter": + mnewname = mnewname + "_total" + elif mtype == "info": + mtype = "gauge" + mnewname = mnewname + "_info" + elif mtype == "stateset": + mtype = "gauge" + elif mtype == "gaugehistogram": + mtype = "histogram" + elif mtype == "unknown": + mtype = "untyped" + + # Output in the old format for compatibility. + if emit_help: + output.append( + "# HELP {0} {1}\n".format( + mname, + metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), + ) + ) + output.append("# TYPE {0} {1}\n".format(mname, mtype)) + for sample in map(nameify_sample, metric.samples): + # Get rid of the OpenMetrics specific samples + for suffix in ["_created", "_gsum", "_gcount"]: + if sample.name.endswith(suffix): + break + else: + newname = sample.name.replace(mnewname, mname) + if ":" in newname and newname.endswith("_total"): + newname = newname[: -len("_total")] + output.append(sample_line(sample, newname)) + + # Get rid of the weird colon things while we're at it + if mtype == "counter": + mnewname = mnewname.replace(":total", "") + mnewname = mnewname.replace(":", "_") + + if mname == mnewname: + continue + + # Also output in the new format, if it's different. + if emit_help: + output.append( + "# HELP {0} {1}\n".format( + mnewname, + metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), + ) + ) + output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) + for sample in map(nameify_sample, metric.samples): + # Get rid of the OpenMetrics specific samples + for suffix in ["_created", "_gsum", "_gcount"]: + if sample.name.endswith(suffix): + break + else: + output.append( + sample_line( + sample, sample.name.replace(":total", "").replace(":", "_") + ) + ) + + return "".join(output).encode("utf-8") + + +class MetricsHandler(BaseHTTPRequestHandler): + """HTTP handler that gives metrics from ``REGISTRY``.""" + + registry = REGISTRY + + def do_GET(self): + registry = self.registry + params = parse_qs(urlparse(self.path).query) + + if "help" in params: + emit_help = True + else: + emit_help = False + + try: + output = generate_latest(registry, emit_help=emit_help) + except Exception: + self.send_error(500, "error generating metric output") + raise + self.send_response(200) + self.send_header("Content-Type", CONTENT_TYPE_LATEST) + self.end_headers() + self.wfile.write(output) + + def log_message(self, format, *args): + """Log nothing.""" + + @classmethod + def factory(cls, registry): + """Returns a dynamic MetricsHandler class tied + to the passed registry. + """ + # This implementation relies on MetricsHandler.registry + # (defined above and defaulted to REGISTRY). + + # As we have unicode_literals, we need to create a str() + # object for type(). + cls_name = str(cls.__name__) + MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry}) + return MyMetricsHandler + + +class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): + """Thread per request HTTP server.""" + + # Make worker threads "fire and forget". Beginning with Python 3.7 this + # prevents a memory leak because ``ThreadingMixIn`` starts to gather all + # non-daemon threads in a list in order to join on them at server close. + # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the + # same as Python 3.7's ``ThreadingHTTPServer``. + daemon_threads = True + + +def start_http_server(port, addr="", registry=REGISTRY): + """Starts an HTTP server for prometheus metrics as a daemon thread""" + CustomMetricsHandler = MetricsHandler.factory(registry) + httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) + t = threading.Thread(target=httpd.serve_forever) + t.daemon = True + t.start() + + +class MetricsResource(Resource): + """ + Twisted ``Resource`` that serves prometheus metrics. + """ + + isLeaf = True + + def __init__(self, registry=REGISTRY): + self.registry = registry + + def render_GET(self, request): + request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) + return generate_latest(self.registry) diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py deleted file mode 100644 index 9789359077..0000000000 --- a/synapse/metrics/resource.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from prometheus_client.twisted import MetricsResource - -METRICS_PREFIX = "/_synapse/metrics" - -__all__ = ["MetricsResource", "METRICS_PREFIX"] diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 6324c00ef1..c6465c0386 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -65,9 +65,7 @@ REQUIREMENTS = [ "msgpack>=0.5.2", "phonenumbers>=8.2.0", "six>=1.10", - # prometheus_client 0.4.0 changed the format of counter metrics - # (cf https://github.com/matrix-org/synapse/issues/4001) - "prometheus_client>=0.0.18,<0.4.0", + "prometheus_client>=0.0.18,<0.8.0", # we use attr.s(slots), which arrived in 16.0.0 # Twisted 18.7.0 requires attrs>=17.4.0 "attrs>=17.4.0", @@ -95,6 +93,7 @@ CONDITIONAL_REQUIREMENTS = { "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0", "parameterized"], "sentry": ["sentry-sdk>=0.7.2"], + "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], } diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index dd0d38ea5c..57542c2b4b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,13 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + NotFoundError, + SynapseError, +) from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import RoomAlias @@ -97,7 +103,7 @@ class ClientDirectoryServer(RestServlet): room_alias.to_string(), ) defer.returnValue((200, {})) - except AuthError: + except InvalidClientCredentialsError: # fallback to default user behaviour if they aren't an AS pass diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index cca7e45ddb..7709c2d705 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -24,7 +24,12 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + SynapseError, +) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( @@ -307,7 +312,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): try: yield self.auth.get_user_by_req(request, allow_guest=True) - except AuthError as e: + except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private # federations. diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 8e362782cc..6e52f6d284 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -34,6 +34,7 @@ from synapse.http.servlet import ( from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.relations import ( AggregationPaginationToken, + PaginationChunk, RelationPaginationToken, ) @@ -145,38 +146,55 @@ class RelationPaginationServlet(RestServlet): room_id, requester.user.to_string() ) - # This checks that a) the event exists and b) the user is allowed to - # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + # This gets the original event and checks that a) the event exists and + # b) the user is allowed to view it. + event = yield self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") to_token = parse_string(request, "to") - if from_token: - from_token = RelationPaginationToken.from_string(from_token) - - if to_token: - to_token = RelationPaginationToken.from_string(to_token) - - result = yield self.store.get_relations_for_event( - event_id=parent_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) + if event.internal_metadata.is_redacted(): + # If the event is redacted, return an empty list of relations + pagination_chunk = PaginationChunk(chunk=[]) + else: + # Return the relations + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + pagination_chunk = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) events = yield self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] + [c["event_id"] for c in pagination_chunk.chunk] ) now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events(events, now) + # We set bundle_aggregations to False when retrieving the original + # event because we want the content before relations were applied to + # it. + original_event = yield self._event_serializer.serialize_event( + event, now, bundle_aggregations=False + ) + # Similarly, we don't allow relations to be applied to relations, so we + # return the original relations without any aggregations on top of them + # here. + events = yield self._event_serializer.serialize_events( + events, now, bundle_aggregations=False + ) - return_value = result.to_dict() + return_value = pagination_chunk.to_dict() return_value["chunk"] = events + return_value["original_event"] = original_event defer.returnValue((200, return_value)) @@ -222,7 +240,7 @@ class RelationAggregationPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = yield self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -231,21 +249,26 @@ class RelationAggregationPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") - if from_token: - from_token = AggregationPaginationToken.from_string(from_token) - - if to_token: - to_token = AggregationPaginationToken.from_string(to_token) - - res = yield self.store.get_aggregation_groups_for_event( - event_id=parent_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - defer.returnValue((200, res.to_dict())) + if event.internal_metadata.is_redacted(): + # If the event is redacted, return an empty list of relations + pagination_chunk = PaginationChunk(chunk=[]) + else: + # Return the relations + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + + pagination_chunk = yield self.store.get_aggregation_groups_for_event( + event_id=parent_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + defer.returnValue((200, pagination_chunk.to_dict())) class RelationAggregationGroupPaginationServlet(RestServlet): diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index e8f559acc1..37687ea7f4 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -67,7 +67,7 @@ class StorageProviderWrapper(StorageProvider): backend (StorageProvider) store_local (bool): Whether to store new local files or not. store_synchronous (bool): Whether to wait for file to be successfully - uploaded, or todo the upload in the backgroud. + uploaded, or todo the upload in the background. store_remote (bool): Whether remote media should be uploaded """ diff --git a/synapse/static/index.html b/synapse/static/index.html index d3f1c7dce0..bf46df9097 100644 --- a/synapse/static/index.html +++ b/synapse/static/index.html @@ -48,13 +48,13 @@ </div> <h1>It works! Synapse is running</h1> <p>Your Synapse server is listening on this port and is ready for messages.</p> - <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank">a Matrix client</a>. + <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank" rel="noopener noreferrer">a Matrix client</a>. </p> <p>Welcome to the Matrix universe :)</p> <hr> <p> <small> - <a href="https://matrix.org" target="_blank"> + <a href="https://matrix.org" target="_blank" rel="noopener noreferrer"> matrix.org </a> </small> diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 2f940dbae6..a7c93efa46 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -86,7 +86,21 @@ _CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() - method.""" + method. + + Args: + txn: The database transcation object to wrap. + name (str): The name of this transactions for logging. + database_engine (Sqlite3Engine|PostgresEngine) + after_callbacks(list|None): A list that callbacks will be appended to + that have been added by `call_after` which should be run on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. + exception_callbacks(list|None): A list that callbacks will be appended + to that have been added by `call_on_exception` which should be run + if transaction ends with an error. None indicates that no callbacks + should be allowed to be scheduled to run. + """ __slots__ = [ "txn", @@ -97,7 +111,7 @@ class LoggingTransaction(object): ] def __init__( - self, txn, name, database_engine, after_callbacks, exception_callbacks + self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None ): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index eca77069fd..dcfb67e029 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -79,8 +79,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): db_conn.cursor(), name="_find_stream_orderings_for_times_txn", database_engine=self.database_engine, - after_callbacks=[], - exception_callbacks=[], ) self._find_stream_orderings_for_times_txn(cur) cur.close() diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b486ca50eb..b70457bfc6 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -918,8 +918,6 @@ class EventsStore( min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) - self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -993,6 +991,10 @@ class EventsStore( backfilled=backfilled, ) + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + def _update_current_state_txn(self, txn, state_delta_by_room, stream_id): for room_id, current_state_tuple in iteritems(state_delta_by_room): to_delete, to_insert = current_state_tuple @@ -1062,16 +1064,16 @@ class EventsStore( ), ) - self._simple_insert_many_txn( - txn, - table="current_state_events", - values=[ - { - "event_id": ev_id, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - } + # We include the membership in the current state table, hence we do + # a lookup when we insert. This assumes that all events have already + # been inserted into room_memberships. + txn.executemany( + """INSERT INTO current_state_events + (room_id, type, state_key, event_id, membership) + VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[0], key[1], ev_id, ev_id) for key, ev_id in iteritems(to_insert) ], ) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 09db872511..858fc755a1 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -37,6 +37,7 @@ from synapse.logging.context import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id +from synapse.util import batch_iter from synapse.util.metrics import Measure from ._base import SQLBaseStore @@ -218,9 +219,108 @@ class EventsWorkerStore(SQLBaseStore): if not event_ids: defer.returnValue([]) - event_id_list = event_ids - event_ids = set(event_ids) + # there may be duplicates so we cast the list to a set + event_entry_map = yield self._get_events_from_cache_or_db( + set(event_ids), allow_rejected=allow_rejected + ) + + events = [] + for event_id in event_ids: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if not allow_rejected: + assert not entry.event.rejected_reason, ( + "rejected event returned from _get_events_from_cache_or_db despite " + "allow_rejected=False" + ) + + # We may not have had the original event when we received a redaction, so + # we have to recheck auth now. + + if not allow_rejected and entry.event.type == EventTypes.Redaction: + redacted_event_id = entry.event.redacts + event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + original_event_entry = event_map.get(redacted_event_id) + if not original_event_entry: + # we don't have the redacted event (or it was rejected). + # + # We assume that the redaction isn't authorized for now; if the + # redacted event later turns up, the redaction will be re-checked, + # and if it is found valid, the original will get redacted before it + # is served to the client. + logger.debug( + "Withholding redaction event %s since we don't (yet) have the " + "original %s", + event_id, + redacted_event_id, + ) + continue + + original_event = original_event_entry.event + if original_event.type == EventTypes.Create: + # we never serve redactions of Creates to clients. + logger.info( + "Withholding redaction %s of create event %s", + event_id, + redacted_event_id, + ) + continue + + if entry.event.internal_metadata.need_to_check_redaction(): + original_domain = get_domain_from_id(original_event.sender) + redaction_domain = get_domain_from_id(entry.event.sender) + if original_domain != redaction_domain: + # the senders don't match, so this is forbidden + logger.info( + "Withholding redaction %s whose sender domain %s doesn't " + "match that of redacted event %s %s", + event_id, + redaction_domain, + redacted_event_id, + original_domain, + ) + continue + + # Update the cache to save doing the checks again. + entry.event.internal_metadata.recheck_redaction = False + + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + defer.returnValue(events) + + @defer.inlineCallbacks + def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the cache or the database. + + If events are pulled from the database, they will be cached for future lookups. + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + allow_rejected (bool): Whether to include rejected events + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result + """ event_entry_map = self._get_events_from_cache( event_ids, allow_rejected=allow_rejected ) @@ -243,81 +343,7 @@ class EventsWorkerStore(SQLBaseStore): event_entry_map.update(missing_events) - events = [] - for event_id in event_id_list: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - # Starting in room version v3, some redactions need to be rechecked if we - # didn't have the redacted event at the time, so we recheck on read - # instead. - if not allow_rejected and entry.event.type == EventTypes.Redaction: - if entry.event.internal_metadata.need_to_check_redaction(): - # XXX: we need to avoid calling get_event here. - # - # The problem is that we end up at this point when an event - # which has been redacted is pulled out of the database by - # _enqueue_events, because _enqueue_events needs to check - # the redaction before it can cache the redacted event. So - # obviously, calling get_event to get the redacted event out - # of the database gives us an infinite loop. - # - # For now (quick hack to fix during 0.99 release cycle), we - # just go and fetch the relevant row from the db, but it - # would be nice to think about how we can cache this rather - # than hit the db every time we access a redaction event. - # - # One thought on how to do this: - # 1. split get_events_as_list up so that it is divided into - # (a) get the rawish event from the db/cache, (b) do the - # redaction/rejection filtering - # 2. have _get_event_from_row just call the first half of - # that - - orig_sender = yield self._simple_select_one_onecol( - table="events", - keyvalues={"event_id": entry.event.redacts}, - retcol="sender", - allow_none=True, - ) - - expected_domain = get_domain_from_id(entry.event.sender) - if ( - orig_sender - and get_domain_from_id(orig_sender) == expected_domain - ): - # This redaction event is allowed. Mark as not needing a - # recheck. - entry.event.internal_metadata.recheck_redaction = False - else: - # We don't have the event that is being redacted, so we - # assume that the event isn't authorized for now. (If we - # later receive the event, then we will always redact - # it anyway, since we have this redaction) - continue - - if allow_rejected or not entry.event.rejected_reason: - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - defer.returnValue(events) + return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) @@ -326,8 +352,8 @@ class EventsWorkerStore(SQLBaseStore): """Fetch events from the caches Args: - events (list(str)): list of event_ids to fetch - allow_rejected (bool): Whether to teturn events that were rejected + events (Iterable[str]): list of event_ids to fetch + allow_rejected (bool): Whether to return events that were rejected update_metrics (bool): Whether to update the cache hit ratio metrics Returns: @@ -384,19 +410,16 @@ class EventsWorkerStore(SQLBaseStore): The fetch requests. Each entry consists of a list of event ids to be fetched, and a deferred to be completed once the events have been fetched. - """ with Measure(self._clock, "_fetch_event_list"): try: event_id_lists = list(zip(*event_list))[0] event_ids = [item for sublist in event_id_lists for item in sublist] - rows = self._new_transaction( + row_dict = self._new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, event_ids ) - row_dict = {r["event_id"]: r for r in rows} - # We only want to resolve deferreds from the main thread def fire(lst, res): for ids, d in lst: @@ -454,7 +477,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] + rows[:] = [r for r in rows if r["rejected_reason"] is None] res = yield make_deferred_yieldable( defer.gatherResults( @@ -463,8 +486,8 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_from_row, row["internal_metadata"], row["json"], - row["redacts"], - rejected_reason=row["rejects"], + row["redactions"], + rejected_reason=row["rejected_reason"], format_version=row["format_version"], ) for row in rows @@ -475,49 +498,98 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue({e.event.event_id: e for e in res if e}) - def _fetch_event_rows(self, txn, events): - rows = [] - N = 200 - for i in range(1 + len(events) // N): - evs = events[i * N : (i + 1) * N] - if not evs: - break + def _fetch_event_rows(self, txn, event_ids): + """Fetch event rows from the database + + Events which are not found are omitted from the result. + + The returned per-event dicts contain the following keys: + + * event_id (str) + + * json (str): json-encoded event structure + + * internal_metadata (str): json-encoded internal metadata dict + + * format_version (int|None): The format of the event. Hopefully one + of EventFormatVersions. 'None' means the event predates + EventFormatVersions (so the event is format V1). + + * rejected_reason (str|None): if the event was rejected, the reason + why. + * redactions (List[str]): a list of event-ids which (claim to) redact + this event. + + Args: + txn (twisted.enterprise.adbapi.Connection): + event_ids (Iterable[str]): event IDs to fetch + + Returns: + Dict[str, Dict]: a map from event id to event info. + """ + event_dict = {} + for evs in batch_iter(event_ids, 200): sql = ( "SELECT " - " e.event_id as event_id, " + " e.event_id, " " e.internal_metadata," " e.json," " e.format_version, " - " r.redacts as redacts," - " rej.event_id as rejects " + " rej.reason " " FROM event_json as e" " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" " WHERE e.event_id IN (%s)" ) % (",".join(["?"] * len(evs)),) txn.execute(sql, evs) - rows.extend(self.cursor_to_dict(txn)) - return rows + for row in txn: + event_id = row[0] + event_dict[event_id] = { + "event_id": event_id, + "internal_metadata": row[1], + "json": row[2], + "format_version": row[3], + "rejected_reason": row[4], + "redactions": [], + } + + # check for redactions + redactions_sql = ( + "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)" + ) % (",".join(["?"] * len(evs)),) + + txn.execute(redactions_sql, evs) + + for (redacter, redacted) in txn: + d = event_dict.get(redacted) + if d: + d["redactions"].append(redacter) + + return event_dict @defer.inlineCallbacks def _get_event_from_row( - self, internal_metadata, js, redacted, format_version, rejected_reason=None + self, internal_metadata, js, redactions, format_version, rejected_reason=None ): + """Parse an event row which has been read from the database + + Args: + internal_metadata (str): json-encoded internal_metadata column + js (str): json-encoded event body from event_json + redactions (list[str]): a list of the events which claim to have redacted + this event, from the redactions table + format_version: (str): the 'format_version' column + rejected_reason (str|None): the reason this event was rejected, if any + + Returns: + _EventCacheEntry + """ with Measure(self._clock, "_get_event_from_row"): d = json.loads(js) internal_metadata = json.loads(internal_metadata) - if rejected_reason: - rejected_reason = yield self._simple_select_one_onecol( - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - desc="_get_event_from_row_rejected_reason", - ) - if format_version is None: # This means that we stored the event before we had the concept # of a event format version, so it must be a V1 event. @@ -529,41 +601,7 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason=rejected_reason, ) - redacted_event = None - if redacted: - redacted_event = prune_event(original_ev) - - redaction_id = yield self._simple_select_one_onecol( - table="redactions", - keyvalues={"redacts": redacted_event.event_id}, - retcol="event_id", - desc="_get_event_from_row_redactions", - ) - - redacted_event.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = yield self.get_event( - redaction_id, check_redacted=False, allow_none=True - ) - - if because: - # It's fine to do add the event directly, since get_pdu_json - # will serialise this field correctly - redacted_event.unsigned["redacted_because"] = because - - # Starting in room version v3, some redactions need to be - # rechecked if we didn't have the redacted event at the - # time, so we recheck on read instead. - if because.internal_metadata.need_to_check_redaction(): - expected_domain = get_domain_from_id(original_ev.sender) - if get_domain_from_id(because.sender) == expected_domain: - # This redaction event is allowed. Mark as not needing a - # recheck. - because.internal_metadata.recheck_redaction = False - else: - # Senders don't match, so the event isn't actually redacted - redacted_event = None + redacted_event = yield self._maybe_redact_event_row(original_ev, redactions) cache_entry = _EventCacheEntry( event=original_ev, redacted_event=redacted_event @@ -574,6 +612,60 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue(cache_entry) @defer.inlineCallbacks + def _maybe_redact_event_row(self, original_ev, redactions): + """Given an event object and a list of possible redacting event ids, + determine whether to honour any of those redactions and if so return a redacted + event. + + Args: + original_ev (EventBase): + redactions (iterable[str]): list of event ids of potential redaction events + + Returns: + Deferred[EventBase|None]: if the event should be redacted, a pruned + event object. Otherwise, None. + """ + if original_ev.type == "m.room.create": + # we choose to ignore redactions of m.room.create events. + return None + + redaction_map = yield self._get_events_from_cache_or_db(redactions) + + for redaction_id in redactions: + redaction_entry = redaction_map.get(redaction_id) + if not redaction_entry: + # we don't have the redaction event, or the redaction event was not + # authorized. + continue + + redaction_event = redaction_entry.event + + # Starting in room version v3, some redactions need to be + # rechecked if we didn't have the redacted event at the + # time, so we recheck on read instead. + if redaction_event.internal_metadata.need_to_check_redaction(): + expected_domain = get_domain_from_id(original_ev.sender) + if get_domain_from_id(redaction_event.sender) == expected_domain: + # This redaction event is allowed. Mark as not needing a recheck. + redaction_event.internal_metadata.recheck_redaction = False + else: + # Senders don't match, so the event isn't actually redacted + continue + + # we found a good redaction event. Redact! + redacted_event = prune_event(original_ev) + redacted_event.unsigned["redacted_by"] = redaction_id + + # It's fine to add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = redaction_event + + return redacted_event + + # no valid redaction found for this event + return None + + @defer.inlineCallbacks def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 7c4e1dc7ec..d20eacda59 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 55 +SCHEMA_VERSION = 56 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index aea5b3276b..8b2c2a97ab 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -90,7 +90,8 @@ class RegistrationWorkerStore(SQLBaseStore): token (str): The access token of a user. Returns: defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`. + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. """ return self.runInteraction( "get_user_by_access_token", self._query_for_auth, token @@ -284,7 +285,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id" + " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" @@ -603,7 +604,7 @@ class RegistrationStore( ) self.register_background_update_handler( - "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag + "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) # Create a background job for culling expired 3PID validity tokens @@ -618,14 +619,14 @@ class RegistrationStore( hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) @defer.inlineCallbacks - def _backgroud_update_set_deactivated_flag(self, progress, batch_size): + def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. """ last_user = progress.get("user_id", "") - def _backgroud_update_set_deactivated_flag_txn(txn): + def _background_update_set_deactivated_flag_txn(txn): txn.execute( """ SELECT @@ -670,7 +671,7 @@ class RegistrationStore( return False end = yield self.runInteraction( - "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn + "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: @@ -679,14 +680,16 @@ class RegistrationStore( defer.returnValue(batch_size) @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id=None): + def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. device_id (str): ID of the device to associate with the access - token + token + valid_until_ms (int|None): when the token is valid until. None for + no expiry. Raises: StoreError if there was a problem adding this. """ @@ -694,14 +697,19 @@ class RegistrationStore( yield self._simple_insert( "access_tokens", - {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id}, + { + "id": next_id, + "user_id": user_id, + "token": token, + "device_id": device_id, + "valid_until_ms": valid_until_ms, + }, desc="add_access_token_to_user", ) - def register( + def register_user( self, user_id, - token=None, password_hash=None, was_guest=False, make_guest=False, @@ -714,9 +722,6 @@ class RegistrationStore( Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. If this - is not None, the given access token is associated with the user - id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. @@ -733,10 +738,9 @@ class RegistrationStore( StoreError if the user_id could not be registered. """ return self.runInteraction( - "register", - self._register, + "register_user", + self._register_user, user_id, - token, password_hash, was_guest, make_guest, @@ -746,11 +750,10 @@ class RegistrationStore( user_type, ) - def _register( + def _register_user( self, txn, user_id, - token, password_hash, was_guest, make_guest, @@ -763,8 +766,6 @@ class RegistrationStore( now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next() - try: if was_guest: # Ensure that the guest user actually exists @@ -812,14 +813,6 @@ class RegistrationStore( if self._account_validity.enabled: self.set_expiration_date_for_user_txn(txn, user_id) - if token: - # it's possible for this to get a conflict, but only for a single user - # since tokens are namespaced based on their user ID - txn.execute( - "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)", - (next_id, user_id, token), - ) - if create_profile_with_displayname: # set a default displayname serverside to avoid ugly race # between auto-joins and clients trying to set displaynames diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 1b01934c19..9954bc094f 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -60,7 +60,7 @@ class PaginationChunk(object): class RelationPaginationToken(object): """Pagination token for relation pagination API. - As the results are order by topological ordering, we can use the + As the results are in topological order, we can use the `topological_ordering` and `stream_ordering` fields of the events at the boundaries of the chunk as pagination tokens. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8004aeb909..257bcdb2f8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -24,6 +24,8 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import LoggingTransaction from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id from synapse.util.async_helpers import Linearizer @@ -53,9 +55,51 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) MemberSummary = namedtuple("MemberSummary", ("members", "count")) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" +_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): + def __init__(self, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + + # Is the current_state_events.membership up to date? Or is the + # background update still running? + self._current_state_events_membership_up_to_date = False + + txn = LoggingTransaction( + db_conn.cursor(), + name="_check_safe_current_state_events_membership_updated", + database_engine=self.database_engine, + ) + self._check_safe_current_state_events_membership_updated_txn(txn) + txn.close() + + def _check_safe_current_state_events_membership_updated_txn(self, txn): + """Checks if it is safe to assume the new current_state_events + membership column is up to date + """ + + pending_update = self._simple_select_one_txn( + txn, + table="background_updates", + keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, + retcols=["update_name"], + allow_none=True, + ) + + self._current_state_events_membership_up_to_date = not pending_update + + # If the update is still running, reschedule to run. + if pending_update: + self._clock.call_later( + 15.0, + run_as_background_process, + "_check_safe_current_state_events_membership_updated", + self.runInteraction, + "_check_safe_current_state_events_membership_updated", + self._check_safe_current_state_events_membership_updated_txn, + ) + @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) def get_hosts_in_room(self, room_id, cache_context): """Returns the set of all hosts currently in the room @@ -69,14 +113,23 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): def f(txn): - sql = ( - "SELECT m.user_id FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id " - " AND m.room_id = c.room_id " - " AND m.user_id = c.state_key" - " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" - ) + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT state_key FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + """ + else: + sql = """ + SELECT state_key FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + """ txn.execute(sql, (room_id, Membership.JOIN)) return [to_ascii(r[0]) for r in txn] @@ -98,15 +151,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): # first get counts. # We do this all in one transaction to keep the cache small. # FIXME: get rid of this when we have room_stats - sql = """ - SELECT count(*), m.membership FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ + + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + GROUP BY membership + """ + else: + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ txn.execute(sql, (room_id,)) res = {} @@ -224,7 +288,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): results = [] if membership_list: where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( - " OR ".join(["membership = ?" for _ in membership_list]), + " OR ".join(["m.membership = ?" for _ in membership_list]), ) args = [user_id] @@ -453,8 +517,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): sql = """ SELECT state_key FROM current_state_events AS c - INNER JOIN room_memberships USING (event_id) - WHERE membership = 'join' + INNER JOIN room_memberships AS m USING (event_id) + WHERE m.membership = 'join' AND type = 'm.room.member' AND c.room_id = ? AND state_key LIKE ? @@ -575,6 +639,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): count = yield self.runInteraction("did_forget_membership", f) defer.returnValue(count == 0) + @defer.inlineCallbacks + def get_rooms_user_has_been_in(self, user_id): + """Get all rooms that the user has ever been in. + + Args: + user_id (str) + + Returns: + Deferred[set[str]]: Set of room IDs. + """ + + room_ids = yield self._simple_select_onecol( + table="room_memberships", + keyvalues={"membership": Membership.JOIN, "user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_has_been_in", + ) + + return set(room_ids) + class RoomMemberStore(RoomMemberWorkerStore): def __init__(self, db_conn, hs): @@ -582,6 +666,10 @@ class RoomMemberStore(RoomMemberWorkerStore): self.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) + self.register_background_update_handler( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + self._background_current_state_membership, + ) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. @@ -761,6 +849,52 @@ class RoomMemberStore(RoomMemberWorkerStore): defer.returnValue(result) + @defer.inlineCallbacks + def _background_current_state_membership(self, progress, batch_size): + """Update the new membership column on current_state_events. + """ + + if "rooms" not in progress: + rooms = yield self._simple_select_onecol( + table="current_state_events", + keyvalues={}, + retcol="DISTINCT room_id", + desc="_background_current_state_membership_get_rooms", + ) + progress["rooms"] = rooms + + rooms = progress["rooms"] + + def _background_current_state_membership_txn(txn): + processed = 0 + while rooms and processed < batch_size: + sql = """ + UPDATE current_state_events AS c + SET membership = ( + SELECT membership FROM room_memberships + WHERE event_id = c.event_id + ) + WHERE room_id = ? + """ + txn.execute(sql, (rooms.pop(),)) + processed += txn.rowcount + + self._background_update_progress_txn( + txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, progress + ) + + return processed + + result = yield self.runInteraction( + "_background_current_state_membership_update", + _background_current_state_membership_txn, + ) + + if not rooms: + yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + + defer.returnValue(result) + class _JoinedHostsCache(object): """Cache for joined hosts in a room that is optimised to handle updates diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/schema/delta/55/access_token_expiry.sql new file mode 100644 index 0000000000..4590604bfd --- /dev/null +++ b/synapse/storage/schema/delta/55/access_token_expiry.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- when this access token can be used until, in ms since the epoch. NULL means the token +-- never expires. +ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT; diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/schema/delta/56/current_state_events_membership.sql new file mode 100644 index 0000000000..b2e08cd85d --- /dev/null +++ b/synapse/storage/schema/delta/56/current_state_events_membership.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We add membership to current state so that we don't need to join against +-- room_memberships, which can be surprisingly costly (we do such queries +-- very frequently). +-- This will be null for non-membership events and the content.membership key +-- for membership events. (Will also be null for membership events until the +-- background update job has finished). +ALTER TABLE current_state_events ADD membership TEXT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('current_state_events_membership', '{}'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0bfe1b4550..a35289876d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -510,6 +510,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): event ID. """ + where_clause, where_args = state_filter.make_sql_filter_clause() + + if not where_clause: + # We delegate to the cached version + return self.get_current_state_ids(room_id) + def _get_filtered_current_state_ids_txn(txn): results = {} sql = """ @@ -517,8 +523,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): WHERE room_id = ? """ - where_clause, where_args = state_filter.make_sql_filter_clause() - if where_clause: sql += " AND (%s)" % (where_clause,) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 386a9dbe14..a0465484df 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -833,7 +833,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: Deferred[tuple[list[_EventDictReturn], str]]: Returns the results as a list of _EventDictReturn and a token that points to the end - of the result set. + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between + `from_token` and `to_token`), or `limit` is zero. """ assert int(limit) >= 0 @@ -905,15 +907,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): only those before direction(char): Either 'b' or 'f' to indicate whether we are paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. Zero or less - means no limit. + limit (int): The maximum number of events to return. event_filter (Filter|None): If provided filters the events to those that match the filter. Returns: - tuple[list[dict], str]: Returns the results as a list of dicts and - a token that points to the end of the result set. The dicts have - the keys "event_id", "topological_ordering" and "stream_orderign". + tuple[list[FrozenEvent], str]: Returns the results as a list of + events and a token that points to the end of the result set. If no + events are returned then the end of the stream has been reached + (i.e. there are no events between `from_key` and `to_key`). """ from_key = RoomStreamToken.parse(from_key) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index fd18619178..c585cf6cf7 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -196,6 +196,26 @@ class TransactionStore(SQLBaseStore): def _set_destination_retry_timings( self, txn, destination, retry_last_ts, retry_interval ): + + if self.database_engine.can_native_upsert: + # Upsert retry time interval if retry_interval is zero (i.e. we're + # resetting it) or greater than the existing retry interval. + + sql = """ + INSERT INTO destinations (destination, retry_last_ts, retry_interval) + VALUES (?, ?, ?) + ON CONFLICT (destination) DO UPDATE SET + retry_last_ts = EXCLUDED.retry_last_ts, + retry_interval = EXCLUDED.retry_interval + WHERE + EXCLUDED.retry_interval = 0 + OR destinations.retry_interval < EXCLUDED.retry_interval + """ + + txn.execute(sql, (destination, retry_last_ts, retry_interval)) + + return + self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 83466e25d9..7fd16fe65e 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -618,15 +618,15 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): sql = """ SELECT room_id FROM ( SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships USING (event_id) + INNER JOIN room_memberships AS m USING (event_id) WHERE type = 'm.room.member' - AND membership = 'join' + AND m.membership = 'join' AND state_key = ? ) AS f1 INNER JOIN ( SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships USING (event_id) + INNER JOIN room_memberships AS m USING (event_id) WHERE type = 'm.room.member' - AND membership = 'join' + AND m.membership = 'join' AND state_key = ? ) f2 USING (room_id) """ |