Diffstat (limited to 'synapse')
161 files changed, 9103 insertions, 1732 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index ec83e6adb7..ef8853bd24 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.23.0-rc2" +__version__ = "0.26.0" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 72858cca1f..ac0a3655a5 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -270,7 +270,11 @@ class Auth(object): rights (str): The operation being performed; the access token must allow this. Returns: - dict : dict that includes the user and the ID of their access token. + Deferred[dict]: dict that includes: + `user` (UserID) + `is_guest` (bool) + `token_id` (int|None): access token id. May be None if guest + `device_id` (str|None): device corresponding to access token Raises: AuthError if no user by that token exists or the token is invalid. """ diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d0dfa959dc..79b35b3e7c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,22 @@ class RegistrationError(SynapseError): pass +class InteractiveAuthIncompleteError(Exception): + """An error raised when UI auth is not yet complete + + (This indicates we should return a 401 with 'result' as the body) + + Attributes: + result (dict): the server response to the request, which should be + passed back to the client + """ + def __init__(self, result): + super(InteractiveAuthIncompleteError, self).__init__( + "Interactive auth not yet complete", + ) + self.result = result + + class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): diff --git a/synapse/app/_base.py b/synapse/app/_base.py index cd0e815919..9477737759 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -12,10 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import gc import logging +import sys + +try: + import affinity +except Exception: + affinity = None -import affinity from daemonize import Daemonize from synapse.util import PreserveLoggingContext from synapse.util.rlimit import change_resource_limit @@ -78,6 +84,13 @@ def start_reactor( with PreserveLoggingContext(): logger.info("Running") if cpu_affinity is not None: + if not affinity: + quit_with_error( + "Missing package 'affinity' required for cpu_affinity\n" + "option\n\n" + "Install by running:\n\n" + " pip install affinity\n\n" + ) logger.info("Setting CPU affinity to %s" % cpu_affinity) affinity.set_process_affinity_mask(0, cpu_affinity) change_resource_limit(soft_file_limit) @@ -97,3 +110,13 @@ def start_reactor( daemon.start() else: run() + + +def quit_with_error(error_string): + message_lines = error_string.split("\n") + line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 + sys.stderr.write("*" * line_length + '\n') + for line in message_lines: + sys.stderr.write(" %s\n" % (line.rstrip(),)) + sys.stderr.write("*" * line_length + '\n') + sys.exit(1) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index bee4c47498..abc7ef5725 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy") class KeyUploadServlet(RestServlet): - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet): if body: # They're actually trying to upload something, proxy to main synapse. + # Pass through the auth headers, if any, in case the access token + # is there. 
+ auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + headers = { + "Authorization": auth_headers, + } result = yield self.http_client.post_json_get_json( self.main_uri + request.uri, body, + headers=headers, ) defer.returnValue((200, result)) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 84ad8f04a0..6b8875afb4 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -25,10 +25,13 @@ from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \ LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \ STATIC_PREFIX, WEB_CLIENT_PREFIX from synapse.app import _base +from synapse.app._base import quit_with_error from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.federation.transport.server import TransportLayerServer +from synapse.module_api import ModuleApi +from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite from synapse.metrics import register_memory_metrics @@ -40,7 +43,6 @@ from synapse.rest import ClientRestResource from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.server import HomeServer from synapse.storage import are_all_users_on_domain from synapse.storage.engines import IncorrectDatabaseSetup, create_engine @@ -48,6 +50,7 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole +from synapse.util.module_loader import load_module from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string from twisted.application import service @@ -106,52 +109,18 @@ class SynapseHomeServer(HomeServer): resources = {} for res in listener_config["resources"]: for name in res["names"]: - if name == "client": - client_resource = ClientRestResource(self) - if res["compress"]: - client_resource = gz_wrap(client_resource) - - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - }) - - if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) - - if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) - - if name in ["media", "federation", "client"]: - media_repo = MediaRepositoryResource(self) - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) - - if name in ["keys", "federation"]: - resources.update({ - SERVER_KEY_PREFIX: LocalKey(self), - SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), - }) - - if name == "webclient": - resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + resources.update(self._configure_named_resource( + name, res.get("compress", False), + )) - if name == "metrics" and 
self.get_config().enable_metrics: - resources[METRICS_PREFIX] = MetricsResource(self) + additional_resources = listener_config.get("additional_resources", {}) + logger.debug("Configuring additional resources: %r", + additional_resources) + module_api = ModuleApi(self, self.get_auth_handler()) + for path, resmodule in additional_resources.items(): + handler_cls, config = load_module(resmodule) + handler = handler_cls(config, module_api) + resources[path] = AdditionalResource(self, handler.handle_request) if WEB_CLIENT_PREFIX in resources: root_resource = RootRedirect(WEB_CLIENT_PREFIX) @@ -187,6 +156,72 @@ class SynapseHomeServer(HomeServer): ) logger.info("Synapse now listening on port %d", port) + def _configure_named_resource(self, name, compress=False): + """Build a resource map for a named resource + + Args: + name (str): named resource: one of "client", "federation", etc + compress (bool): whether to enable gzip compression for this + resource + + Returns: + dict[str, Resource]: map from path to HTTP resource + """ + resources = {} + if name == "client": + client_resource = ClientRestResource(self) + if compress: + client_resource = gz_wrap(client_resource) + + resources.update({ + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + }) + + if name == "federation": + resources.update({ + FEDERATION_PREFIX: TransportLayerServer(self), + }) + + if name in ["static", "client"]: + resources.update({ + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ), + }) + + if name in ["media", "federation", "client"]: + if self.get_config().enable_media_repo: + media_repo = self.get_media_repository_resource() + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + elif name == "media": + raise ConfigError( + "'media' resource conflicts with enable_media_repo=False", + ) + + if name in ["keys", "federation"]: + resources.update({ + SERVER_KEY_PREFIX: LocalKey(self), + SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), + }) + + if name == "webclient": + resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + + if name == "metrics" and self.get_config().enable_metrics: + resources[METRICS_PREFIX] = MetricsResource(self) + + return resources + def start_listening(self): config = self.get_config() @@ -249,16 +284,6 @@ class SynapseHomeServer(HomeServer): return db_conn -def quit_with_error(error_string): - message_lines = error_string.split("\n") - line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 - sys.stderr.write("*" * line_length + '\n') - for line in message_lines: - sys.stderr.write(" %s\n" % (line.rstrip(),)) - sys.stderr.write("*" * line_length + '\n') - sys.exit(1) - - def setup(config_options): """ Args: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 36c18bdbcb..c4e5f0965d 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -35,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository 
import MediaRepositoryResource from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.media_repository import MediaRepositoryStore @@ -89,7 +88,7 @@ class MediaRepositoryServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) elif name == "media": - media_repo = MediaRepositoryResource(self) + media_repo = self.get_media_repository_resource() resources.update({ MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, @@ -151,6 +150,13 @@ def start(config_options): assert config.worker_app == "synapse.app.media_repository" + if config.enable_media_repo: + _base.quit_with_error( + "enable_media_repo must be disabled in the main synapse process\n" + "before the media repo can be run in a separate worker.\n" + "Please add ``enable_media_repo: false`` to the main config\n" + ) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 80e4ba5336..323fddee21 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -40,6 +40,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1 import events from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet @@ -69,6 +70,7 @@ class SynchrotronSlavedStore( SlavedRegistrationStore, SlavedFilteringStore, SlavedPresenceStore, + SlavedGroupServerStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedClientIpStore, @@ -338,11 +340,10 @@ class SyncReplicationHandler(ReplicationClientHandler): self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() + # NB this is a SynchrotronPresence, not a normal PresenceHandler self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() - self.presence_handler.sync_callback = self.send_user_sync - def on_rdata(self, stream_name, token, rows): super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) @@ -403,6 +404,10 @@ class SyncReplicationHandler(ReplicationClientHandler): ) elif stream_name == "presence": yield self.presence_handler.process_replication_rows(token, rows) + elif stream_name == "receipts": + self.notifier.on_new_event( + "groups_key", token, users=[row.user_id for row in rows], + ) def start(config_options): diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b989007314..d5a7a5ce2f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.api.constants import EventTypes from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.types import GroupID, get_domain_from_id from twisted.internet import defer @@ -81,12 +82,13 @@ class ApplicationService(object): # values. 
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] - def __init__(self, token, url=None, namespaces=None, hs_token=None, + def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, sender=None, id=None, protocols=None, rate_limited=True): self.token = token self.url = url self.hs_token = hs_token self.sender = sender + self.server_name = hostname self.namespaces = self._check_namespaces(namespaces) self.id = id @@ -125,6 +127,24 @@ class ApplicationService(object): raise ValueError( "Expected bool for 'exclusive' in ns '%s'" % ns ) + group_id = regex_obj.get("group_id") + if group_id: + if not isinstance(group_id, str): + raise ValueError( + "Expected string for 'group_id' in ns '%s'" % ns + ) + try: + GroupID.from_string(group_id) + except Exception: + raise ValueError( + "Expected valid group ID for 'group_id' in ns '%s'" % ns + ) + + if get_domain_from_id(group_id) != self.server_name: + raise ValueError( + "Expected 'group_id' to be this host in ns '%s'" % ns + ) + regex = regex_obj.get("regex") if isinstance(regex, basestring): regex_obj["regex"] = re.compile(regex) # Pre-compile regex @@ -251,6 +271,21 @@ class ApplicationService(object): if regex_obj["exclusive"] ] + def get_groups_for_user(self, user_id): + """Get the groups that this user is associated with by this AS + + Args: + user_id (str): The ID of the user. + + Returns: + iterable[str]: an iterable that yields group_id strings. + """ + return ( + regex_obj["group_id"] + for regex_obj in self.namespaces[ApplicationService.NS_USERS] + if "group_id" in regex_obj and regex_obj["regex"].match(user_id) + ) + def is_rate_limited(self): return self.rate_limited diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6893610e71..40c433d7ae 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.caches.response_cache import ResponseCache from synapse.types import ThirdPartyInstanceID @@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(None) key = (service.id, protocol) - return self.protocol_meta_cache.get(key) or ( - self.protocol_meta_cache.set(key, _get()) - ) + result = self.protocol_meta_cache.get(key) + if not result: + result = self.protocol_meta_cache.set( + key, preserve_fn(_get)() + ) + return make_deferred_yieldable(result) @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 68a9de17b8..6da315473d 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -123,7 +123,7 @@ class _ServiceQueuer(object): with Measure(self.clock, "servicequeuer.send"): try: yield self.txn_ctrl.send(service, events) - except: + except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1ab5593c6e..fa105bce72 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -82,21 +82,37 @@ class Config(object): return os.path.abspath(file_path) if file_path else file_path @classmethod + def path_exists(cls, file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there 
is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + + @classmethod def check_file(cls, file_path, config_name): if file_path is None: raise ConfigError( "Missing config for %s." - " You must specify a path for the config file. You can " - "do this with the -c or --config-path option. " - "Adding --generate-config along with --server-name " - "<server name> will generate a config file at the given path." % (config_name,) ) - if not os.path.exists(file_path): + try: + os.stat(file_path) + except OSError as e: raise ConfigError( - "File %s config for %s doesn't exist." - " Try running again with --generate-config" - % (file_path, config_name,) + "Error accessing file '%s' (config for %s): %s" + % (file_path, config_name, e.strerror) ) return cls.abspath(file_path) @@ -248,7 +264,7 @@ class Config(object): " -c CONFIG-FILE\"" ) (config_path,) = config_files - if not os.path.exists(config_path): + if not cls.path_exists(config_path): if config_args.keys_directory: config_dir_path = config_args.keys_directory else: @@ -261,7 +277,7 @@ class Config(object): "Must specify a server_name to a generate config for." " Pass -H server.name." ) - if not os.path.exists(config_dir_path): + if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 82c50b8240..aba0aec6e8 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -154,6 +154,7 @@ def _load_appservice(hostname, as_info, config_filename): ) return ApplicationService( token=as_info["as_token"], + hostname=hostname, url=as_info["url"], namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 938f6f25f8..8109e5f95e 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -41,7 +41,7 @@ class CasConfig(Config): #cas_config: # enabled: true # server_url: "https://cas-server.com" - # service_url: "https://homesever.domain.com:8448" + # service_url: "https://homeserver.domain.com:8448" # #required_attributes: # # name: value """ diff --git a/synapse/config/groups.py b/synapse/config/groups.py new file mode 100644 index 0000000000..997fa2881f --- /dev/null +++ b/synapse/config/groups.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ._base import Config + + +class GroupsConfig(Config): + def read_config(self, config): + self.enable_group_creation = config.get("enable_group_creation", False) + self.group_creation_prefix = config.get("group_creation_prefix", "") + + def default_config(self, **kwargs): + return """\ + # Whether to allow non server admins to create groups on this server + enable_group_creation: false + + # If enabled, non server admins can only create groups with local parts + # starting with this prefix + # group_creation_prefix: "unofficial/" + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 3f9d9d5f8b..bf19cfee29 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -35,6 +35,8 @@ from .emailconfig import EmailConfig from .workers import WorkerConfig from .push import PushConfig from .spam_checker import SpamCheckerConfig +from .groups import GroupsConfig +from .user_directory import UserDirectoryConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -43,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig, - SpamCheckerConfig,): + SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,): pass diff --git a/synapse/config/key.py b/synapse/config/key.py index 6ee643793e..4b8fc063d0 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -118,10 +118,9 @@ class KeyConfig(Config): signing_keys = self.read_file(signing_key_path, "signing_key") try: return read_signing_keys(signing_keys.splitlines(True)) - except Exception: + except Exception as e: raise ConfigError( - "Error reading signing_key." - " Try running again with --generate-config" + "Error reading signing_key: %s" % (str(e)) ) def read_old_signing_keys(self, old_signing_keys): @@ -141,7 +140,8 @@ class KeyConfig(Config): def generate_files(self, config): signing_key_path = config["signing_key_path"] - if not os.path.exists(signing_key_path): + + if not self.path_exists(signing_key_path): with open(signing_key_path, "w") as signing_key_file: key_id = "a_" + random_string(4) write_signing_keys( diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 2dbeafa9dd..a1d6e4d4f7 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False): "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" " - %(message)s" ) - if log_config is None: + if log_config is None: level = logging.INFO level_for_storage = logging.INFO if config.verbosity: @@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False): logger.info("Opened new log file due to SIGHUP") else: handler = logging.StreamHandler() + + def sighup(signum, stack): + pass + handler.setFormatter(formatter) handler.addFilter(LoggingContextFilter(request="")) diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 90824cab7f..e9828fac17 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -13,41 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ._base import Config, ConfigError +from ._base import Config from synapse.util.module_loader import load_module +LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' + class PasswordAuthProviderConfig(Config): def read_config(self, config): self.password_providers = [] - - provider_config = None + providers = [] # We want to be backwards compatible with the old `ldap_config` # param. ldap_config = config.get("ldap_config", {}) - self.ldap_enabled = ldap_config.get("enabled", False) - if self.ldap_enabled: - from ldap_auth_provider import LdapAuthProvider - parsed_config = LdapAuthProvider.parse_config(ldap_config) - self.password_providers.append((LdapAuthProvider, parsed_config)) + if ldap_config.get("enabled", False): + providers.append[{ + 'module': LDAP_PROVIDER, + 'config': ldap_config, + }] - providers = config.get("password_providers", []) + providers.extend(config.get("password_providers", [])) for provider in providers: + mod_name = provider['module'] + # This is for backwards compat when the ldap auth provider resided # in this package. - if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": - from ldap_auth_provider import LdapAuthProvider - provider_class = LdapAuthProvider - try: - provider_config = provider_class.parse_config(provider["config"]) - except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) - else: - (provider_class, provider_config) = load_module(provider) + if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": + mod_name = LDAP_PROVIDER + + (provider_class, provider_config) = load_module({ + "module": mod_name, + "config": provider['config'], + }) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/push.py b/synapse/config/push.py index 9c68318b40..b7e0d46afa 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,28 +19,43 @@ from ._base import Config class PushConfig(Config): def read_config(self, config): - self.push_redact_content = False + push_config = config.get("push", {}) + self.push_include_content = push_config.get("include_content", True) + # There was a a 'redact_content' setting but mistakenly read from the + # 'email'section'. Check for the flag in the 'push' section, and log, + # but do not honour it to avoid nasty surprises when people upgrade. + if push_config.get("redact_content") is not None: + print( + "The push.redact_content content option has never worked. " + "Please set push.include_content if you want this behaviour" + ) + + # Now check for the one in the 'email' section and honour it, + # with a warning. push_config = config.get("email", {}) - self.push_redact_content = push_config.get("redact_content", False) + redact_content = push_config.get("redact_content") + if redact_content is not None: + print( + "The 'email.redact_content' option is deprecated: " + "please set push.include_content instead" + ) + self.push_include_content = not redact_content def default_config(self, config_dir_path, server_name, **kwargs): return """ - # Control how push messages are sent to google/apple to notifications. 
- # Normally every message said in a room with one or more people using - # mobile devices will be posted to a push server hosted by matrix.org - # which is registered with google and apple in order to allow push - # notifications to be sent to these mobile devices. - # - # Setting redact_content to true will make the push messages contain no - # message content which will provide increased privacy. This is a - # temporary solution pending improvements to Android and iPhone apps - # to get content from the app rather than the notification. - # + # Clients requesting push notifications can either have the body of + # the message sent in the notification poke along with other details + # like the sender, or just the event ID and room ID (`event_id_only`). + # If clients choose the former, this option controls whether the + # notification request includes the content of the event (other details + # like the sender are still included). For `event_id_only` push, it + # has no effect. + # For modern android devices the notification content will still appear # because it is loaded by the app. iPhone, however will send a # notification saying only that a message arrived and who it came from. # #push: - # redact_content: false + # include_content: true """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f7e03c4cde..ef917fc9f2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -41,6 +41,8 @@ class RegistrationConfig(Config): self.allow_guest_access and config.get("invite_3pid_guest", False) ) + self.auto_join_rooms = config.get("auto_join_rooms", []) + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -70,6 +72,11 @@ class RegistrationConfig(Config): - matrix.org - vector.im - riot.im + + # Users who register on this homeserver will automatically be joined + # to these rooms + #auto_join_rooms: + # - "#example:example.com" """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2c6f57168e..6baa474931 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -70,7 +70,19 @@ class ContentRepositoryConfig(Config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_spider_size = self.parse_size(config["max_spider_size"]) + self.media_store_path = self.ensure_directory(config["media_store_path"]) + + self.backup_media_store_path = config.get("backup_media_store_path") + if self.backup_media_store_path: + self.backup_media_store_path = self.ensure_directory( + self.backup_media_store_path + ) + + self.synchronous_backup_media_store = config.get( + "synchronous_backup_media_store", False + ) + self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] self.thumbnail_requirements = parse_thumbnail_requirements( @@ -115,6 +127,14 @@ class ContentRepositoryConfig(Config): # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" + # A secondary directory where uploaded images and attachments are + # stored as a backup. + # backup_media_store_path: "%(media_store)s" + + # Whether to wait for successful write to backup media store before + # returning successfully. + # synchronous_backup_media_store: false + # Directory where in-progress uploads are stored. 
uploads_path: "%(uploads_path)s" diff --git a/synapse/config/server.py b/synapse/config/server.py index c9a1715f1f..edb90a1348 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -41,6 +41,12 @@ class ServerConfig(Config): # false only if we are updating the user directory in a worker self.update_user_directory = config.get("update_user_directory", True) + # whether to enable the media repository endpoints. This should be set + # to false if the media repository is running as a separate endpoint; + # doing so ensures that we will not run cache cleanup jobs on the + # master, potentially causing inconsistency. + self.enable_media_repo = config.get("enable_media_repo", True) + self.filter_timeline_limit = config.get("filter_timeline_limit", -1) # Whether we should block invites sent to users on this server @@ -247,6 +253,13 @@ class ServerConfig(Config): - names: [federation] # Federation APIs compress: false + # optional list of additional endpoints which can be loaded via + # dynamic modules + # additional_resources: + # "/_matrix/my/custom/endpoint": + # module: my_module.CustomRequestHandler + # config: {} + # Unsecure HTTP listener, # For when matrix traffic passes through loadbalancer that unwraps TLS. - port: %(unsecure_port)s @@ -303,7 +316,7 @@ def read_gc_thresholds(thresholds): return ( int(thresholds[0]), int(thresholds[1]), int(thresholds[2]), ) - except: + except Exception: raise ConfigError( "Value of `gc_threshold` must be a list of three integers if set" ) diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e081840a83..4748f71c2f 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -109,6 +109,12 @@ class TlsConfig(Config): # key. It may be necessary to publish the fingerprints of a new # certificate and wait until the "valid_until_ts" of the previous key # responses have passed before deploying it. 
+ # + # You can calculate a fingerprint from a given TLS listener via: + # openssl s_client -connect $host:$port < /dev/null 2> /dev/null | + # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '=' + # or by checking matrix.org/federationtester/api/report?server_name=$host + # tls_fingerprints: [] # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] """ % locals() @@ -126,7 +132,7 @@ class TlsConfig(Config): tls_private_key_path = config["tls_private_key_path"] tls_dh_params_path = config["tls_dh_params_path"] - if not os.path.exists(tls_private_key_path): + if not self.path_exists(tls_private_key_path): with open(tls_private_key_path, "w") as private_key_file: tls_private_key = crypto.PKey() tls_private_key.generate_key(crypto.TYPE_RSA, 2048) @@ -141,7 +147,7 @@ class TlsConfig(Config): crypto.FILETYPE_PEM, private_key_pem ) - if not os.path.exists(tls_certificate_path): + if not self.path_exists(tls_certificate_path): with open(tls_certificate_path, "w") as certificate_file: cert = crypto.X509() subject = cert.get_subject() @@ -159,7 +165,7 @@ class TlsConfig(Config): certificate_file.write(cert_pem) - if not os.path.exists(tls_dh_params_path): + if not self.path_exists(tls_dh_params_path): if GENERATE_DH_PARAMS: subprocess.check_call([ "openssl", "dhparam", diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py new file mode 100644 index 0000000000..38e8947843 --- /dev/null +++ b/synapse/config/user_directory.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class UserDirectoryConfig(Config): + """User Directory Configuration + Configuration for the behaviour of the /user_directory API + """ + + def read_config(self, config): + self.user_directory_search_all_users = False + user_directory_config = config.get("user_directory", None) + if user_directory_config: + self.user_directory_search_all_users = ( + user_directory_config.get("search_all_users", False) + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # User Directory configuration + # + # 'search_all_users' defines whether to search all users visible to your HS + # when searching the user directory, rather than limiting to users visible + # in public rooms. Defaults to false. If you set it True, you'll have to run + # UPDATE user_directory_stream_pos SET stream_id = NULL; + # on your database to tell it to rebuild the user_directory search indexes. 
+ # + #user_directory: + # search_all_users: false + """ diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index aad4752fe7..cff3ca809a 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -34,7 +34,7 @@ class ServerContextFactory(ssl.ContextFactory): try: _ecCurve = _OpenSSLECCurve(_defaultCurveName) _ecCurve.addECKeyToContext(context) - except: + except Exception: logger.exception("Failed to enable elliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate_chain_file(config.tls_certificate_file) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index ec7711ba7d..aaa3efaca3 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -32,18 +32,25 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event, hash_algorithm) logger.debug("Expecting hash: %s", encode_base64(expected_hash)) - if name not in event.hashes: + + # some malformed events lack a 'hashes'. Protect against it being missing + # or a weird type by basically treating it the same as an unhashed event. + hashes = event.get("hashes") + if not isinstance(hashes, dict): + raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED) + + if name not in hashes: raise SynapseError( 400, "Algorithm %s not in hashes %s" % ( - name, list(event.hashes), + name, list(hashes), ), Codes.UNAUTHORIZED, ) - message_hash_base64 = event.hashes[name] + message_hash_base64 = hashes[name] try: message_hash_bytes = decode_base64(message_hash_base64) - except: + except Exception: raise SynapseError( 400, "Invalid base64: %s" % (message_hash_base64,), diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 054bac456d..35f810b07b 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -759,7 +759,7 @@ def _handle_key_deferred(verify_request): )) try: verify_signed_json(json_object, server_name, verify_key) - except: + except Exception: raise SynapseError( 401, "Invalid signature for server %s with key %s:%s" % ( diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 4096c606f1..061ee86b16 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -443,12 +443,12 @@ def _check_power_levels(event, auth_events): for k, v in user_list.items(): try: UserID.from_string(k) - except: + except Exception: raise SynapseError(400, "Not a valid user_id: %s" % (k,)) try: int(v) - except: + except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) key = (event.type, event.state_key, ) @@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events): ("invite", None), ] - old_list = current_state.content.get("users") + old_list = current_state.content.get("users", {}) for user in set(old_list.keys() + user_list.keys()): levels_to_check.append( (user, "users") ) - old_list = current_state.content.get("events") - new_list = event.content.get("events") + old_list = current_state.content.get("events", {}) + new_list = event.content.get("events", {}) for ev_id in set(old_list.keys() + new_list.keys()): levels_to_check.append( (ev_id, "events") diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 365fd96bd2..13fbba68c0 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -55,7 +55,7 @@ class EventBuilderFactory(object): local_part = str(int(self.clock.time())) + i 
+ random_string(5) - e_id = EventID.create(local_part, self.hostname) + e_id = EventID(local_part, self.hostname) return e_id.to_string() diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index e739f105b2..633e068eb8 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -22,7 +22,7 @@ class SpamChecker(object): config = None try: module, config = hs.config.spam_checker - except: + except Exception: pass if module is not None: @@ -45,3 +45,69 @@ class SpamChecker(object): return False return self.spam_checker.check_event_for_spam(event) + + def user_may_invite(self, inviter_userid, invitee_userid, room_id): + """Checks if a given user may send an invite + + If this method returns false, the invite will be rejected. + + Args: + userid (string): The sender's user ID + + Returns: + bool: True if the user may send an invite, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + + def user_may_create_room(self, userid): + """Checks if a given user may create a room + + If this method returns false, the creation request will be rejected. + + Args: + userid (string): The sender's user ID + + Returns: + bool: True if the user may create a room, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_create_room(userid) + + def user_may_create_room_alias(self, userid, room_alias): + """Checks if a given user may create a room alias + + If this method returns false, the association request will be rejected. + + Args: + userid (string): The sender's user ID + room_alias (string): The alias to be created + + Returns: + bool: True if the user may create a room alias, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_create_room_alias(userid, room_alias) + + def user_may_publish_room(self, userid, room_id): + """Checks if a given user may publish a room to the directory + + If this method returns false, the publish request will be rejected. 
+ + Args: + userid (string): The sender's user ID + room_id (string): The ID of the room that would be published + + Returns: + bool: True if the user may publish the room, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_publish_room(userid, room_id) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7c5e5d957f..b8f02f5391 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -25,7 +25,7 @@ from synapse.api.errors import ( from synapse.util import unwrapFirstError, logcontext from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.events import FrozenEvent, builder import synapse.metrics @@ -420,7 +420,7 @@ class FederationClient(FederationBase): for e_id in batch ] - res = yield preserve_context_over_deferred( + res = yield make_deferred_yieldable( defer.DeferredList(deferreds, consumeErrors=True) ) for success, result in res: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 51e3fdea06..a2327f24b6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -12,14 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from twisted.internet import defer from .federation_base import FederationBase from .units import Transaction, Edu -from synapse.util.async import Linearizer +from synapse.util import async +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logutils import log_function from synapse.util.caches.response_cache import ResponseCache from synapse.events import FrozenEvent @@ -33,6 +32,9 @@ from synapse.crypto.event_signing import compute_event_signature import simplejson as json import logging +# when processing incoming transactions, we try to handle multiple rooms in +# parallel, up to this limit. +TRANSACTION_CONCURRENCY_LIMIT = 10 logger = logging.getLogger(__name__) @@ -52,7 +54,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() - self._server_linearizer = Linearizer("fed_server") + self._server_linearizer = async.Linearizer("fed_server") + self._transaction_linearizer = async.Linearizer("fed_txn_handler") # We cache responses to state queries, as they take a while and often # come in waves. @@ -109,25 +112,41 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_incoming_transaction(self, transaction_data): + # keep this as early as possible to make the calculated origin ts as + # accurate as possible. 
+ request_time = self._clock.time_msec() + transaction = Transaction(**transaction_data) - received_pdus_counter.inc_by(len(transaction.pdus)) + if not transaction.transaction_id: + raise Exception("Transaction missing transaction_id") + if not transaction.origin: + raise Exception("Transaction missing origin") - for p in transaction.pdus: - if "unsigned" in p: - unsigned = p["unsigned"] - if "age" in unsigned: - p["age"] = unsigned["age"] - if "age" in p: - p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) - del p["age"] + logger.debug("[%s] Got transaction", transaction.transaction_id) - pdu_list = [ - self.event_from_pdu_json(p) for p in transaction.pdus - ] + # use a linearizer to ensure that we don't process the same transaction + # multiple times in parallel. + with (yield self._transaction_linearizer.queue( + (transaction.origin, transaction.transaction_id), + )): + result = yield self._handle_incoming_transaction( + transaction, request_time, + ) - logger.debug("[%s] Got transaction", transaction.transaction_id) + defer.returnValue(result) + + @defer.inlineCallbacks + def _handle_incoming_transaction(self, transaction, request_time): + """ Process an incoming transaction and return the HTTP response + + Args: + transaction (Transaction): incoming transaction + request_time (int): timestamp that the HTTP request arrived at + Returns: + Deferred[(int, object)]: http response code and body + """ response = yield self.transaction_actions.have_responded(transaction) if response: @@ -140,42 +159,49 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - results = [] - - for pdu in pdu_list: - # check that it's actually being sent from a valid destination to - # workaround bug #1753 in 0.18.5 and 0.18.6 - if transaction.origin != get_domain_from_id(pdu.event_id): - # We continue to accept join events from any server; this is - # necessary for the federation join dance to work correctly. - # (When we join over federation, the "helper" server is - # responsible for sending out the join event, rather than the - # origin. See bug #1893). - if not ( - pdu.type == 'm.room.member' and - pdu.content and - pdu.content.get("membership", None) == 'join' - ): - logger.info( - "Discarding PDU %s from invalid origin %s", - pdu.event_id, transaction.origin - ) - continue - else: - logger.info( - "Accepting join PDU %s from %s", - pdu.event_id, transaction.origin - ) + received_pdus_counter.inc_by(len(transaction.pdus)) - try: - yield self._handle_received_pdu(transaction.origin, pdu) - results.append({}) - except FederationError as e: - self.send_failure(e, transaction.origin) - results.append({"error": str(e)}) - except Exception as e: - results.append({"error": str(e)}) - logger.exception("Failed to handle PDU") + pdus_by_room = {} + + for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] + if "age" in p: + p["age_ts"] = request_time - int(p["age"]) + del p["age"] + + event = self.event_from_pdu_json(p) + room_id = event.room_id + pdus_by_room.setdefault(room_id, []).append(event) + + pdu_results = {} + + # we can process different rooms in parallel (which is useful if they + # require callouts to other servers to fetch missing events), but + # impose a limit to avoid going too crazy with ram/cpu. 
+ @defer.inlineCallbacks + def process_pdus_for_room(room_id): + logger.debug("Processing PDUs for %s", room_id) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + try: + yield self._handle_received_pdu( + transaction.origin, pdu + ) + pdu_results[event_id] = {} + except FederationError as e: + logger.warn("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + pdu_results[event_id] = {"error": str(e)} + logger.exception("Failed to handle PDU %s", event_id) + + yield async.concurrently_execute( + process_pdus_for_room, pdus_by_room.keys(), + TRANSACTION_CONCURRENCY_LIMIT, + ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): @@ -185,17 +211,16 @@ class FederationServer(FederationBase): edu.content ) - for failure in getattr(transaction, "pdu_failures", []): - logger.info("Got failure %r", failure) - - logger.debug("Returning: %s", str(results)) + pdu_failures = getattr(transaction, "pdu_failures", []) + for failure in pdu_failures: + logger.info("Got failure %r", failure) response = { - "pdus": dict(zip( - (p.event_id for p in pdu_list), results - )), + "pdus": pdu_results, } + logger.debug("Returning: %s", str(response)) + yield self.transaction_actions.set_response( transaction, 200, response @@ -229,12 +254,13 @@ class FederationServer(FederationBase): result = self._state_resp_cache.get((room_id, event_id)) if not result: with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.set( + d = self._state_resp_cache.set( (room_id, event_id), - self._on_context_state_request_compute(room_id, event_id) + preserve_fn(self._on_context_state_request_compute)(room_id, event_id) ) + resp = yield make_deferred_yieldable(d) else: - resp = yield result + resp = yield make_deferred_yieldable(result) defer.returnValue((200, resp)) @@ -520,6 +546,30 @@ class FederationServer(FederationBase): Returns (Deferred): completes with None Raises: FederationError if the signatures / hash do not match """ + # check that it's actually being sent from a valid destination to + # workaround bug #1753 in 0.18.5 and 0.18.6 + if origin != get_domain_from_id(pdu.event_id): + # We continue to accept join events from any server; this is + # necessary for the federation join dance to work correctly. + # (When we join over federation, the "helper" server is + # responsible for sending out the join event, rather than the + # origin. See bug #1893). + if not ( + pdu.type == 'm.room.member' and + pdu.content and + pdu.content.get("membership", None) == 'join' + ): + logger.info( + "Discarding PDU %s from invalid origin %s", + pdu.event_id, origin + ) + return + else: + logger.info( + "Accepting join PDU %s from %s", + pdu.event_id, origin + ) + # Check signature. 
try: pdu = yield self._check_sigs_and_hash(pdu) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 003eaba893..3e7809b04f 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -20,8 +20,8 @@ from .persistence import TransactionActions from .units import Transaction, Edu from synapse.api.errors import HttpResponseException +from synapse.util import logcontext, PreserveLoggingContext from synapse.util.async import run_on_reactor -from synapse.util.logcontext import preserve_context_over_fn, preserve_fn from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.metrics import measure_func from synapse.handlers.presence import format_user_presence_state, get_interested_remotes @@ -146,7 +146,6 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - @defer.inlineCallbacks def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -156,6 +155,13 @@ class TransactionQueue(object): if self._is_processing: return + # fire off a processing loop in the background. It's likely it will + # outlast the current request, so run it in the sentinel logcontext. + with PreserveLoggingContext(): + self._process_event_queue_loop() + + @defer.inlineCallbacks + def _process_event_queue_loop(self): try: self._is_processing = True while True: @@ -231,11 +237,9 @@ class TransactionQueue(object): (pdu, order) ) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) - @preserve_fn # the caller should not yield on this + @logcontext.preserve_fn # the caller should not yield on this @defer.inlineCallbacks def send_presence(self, states): """Send the new presence states to the appropriate destinations. @@ -299,7 +303,7 @@ class TransactionQueue(object): state.user_id: state for state in states }) - preserve_fn(self._attempt_new_transaction)(destination) + self._attempt_new_transaction(destination) def send_edu(self, destination, edu_type, content, key=None): edu = Edu( @@ -321,9 +325,7 @@ class TransactionQueue(object): else: self.pending_edus_by_dest.setdefault(destination, []).append(edu) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": @@ -336,9 +338,7 @@ class TransactionQueue(object): destination, [] ).append(failure) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": @@ -347,15 +347,24 @@ class TransactionQueue(object): if not self.can_send_to(destination): return - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def get_current_token(self): return 0 - @defer.inlineCallbacks def _attempt_new_transaction(self, destination): + """Try to start a new transaction to this destination + + If there is already a transaction in progress to this destination, + returns immediately. Otherwise kicks off the process of sending a + transaction in the background. 
+ + Args: + destination (str): + + Returns: + None + """ # list of (pending_pdu, deferred, order) if destination in self.pending_transactions: # XXX: pending_transactions can get stuck on by a never-ending @@ -368,6 +377,19 @@ class TransactionQueue(object): ) return + logger.debug("TX [%s] Starting transaction loop", destination) + + # Drop the logcontext before starting the transaction. It doesn't + # really make sense to log all the outbound transactions against + # whatever path led us to this point: that's pretty arbitrary really. + # + # (this also means we can fire off _perform_transaction without + # yielding) + with logcontext.PreserveLoggingContext(): + self._transaction_transmission_loop(destination) + + @defer.inlineCallbacks + def _transaction_transmission_loop(self, destination): pending_pdus = [] try: self.pending_transactions[destination] = 1 diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 52b2a717d2..1f3ce238f6 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -471,3 +471,418 @@ class TransportLayerClient(object): ) defer.returnValue(content) + + @log_function + def get_group_profile(self, destination, group_id, requester_user_id): + """Get a group profile + """ + path = PREFIX + "/groups/%s/profile" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_profile(self, destination, group_id, requester_user_id, content): + """Update a remote group profile + + Args: + destination (str) + group_id (str) + requester_user_id (str) + content (dict): The new profile of the group + """ + path = PREFIX + "/groups/%s/profile" % (group_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def get_group_summary(self, destination, group_id, requester_user_id): + """Get a group summary + """ + path = PREFIX + "/groups/%s/summary" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_rooms_in_group(self, destination, group_id, requester_user_id): + """Get all rooms in a group + """ + path = PREFIX + "/groups/%s/rooms" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + def add_room_to_group(self, destination, group_id, requester_user_id, room_id, + content): + """Add a room to a group + """ + path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + def update_room_in_group(self, destination, group_id, requester_user_id, room_id, + config_key, content): + """Update room in group + """ + path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): + """Remove a room from a group + """ + path = PREFIX + "/groups/%s/room/%s" % 
(group_id, room_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_users_in_group(self, destination, group_id, requester_user_id): + """Get users in a group + """ + path = PREFIX + "/groups/%s/users" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_invited_users_in_group(self, destination, group_id, requester_user_id): + """Get users that have been invited to a group + """ + path = PREFIX + "/groups/%s/invited_users" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def accept_group_invite(self, destination, group_id, user_id, content): + """Accept a group invite + """ + path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): + """Invite a user to a group + """ + path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def invite_to_group_notification(self, destination, group_id, user_id, content): + """Sent by group server to inform a user's server that they have been + invited. + """ + + path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def remove_user_from_group(self, destination, group_id, requester_user_id, + user_id, content): + """Remove a user fron a group + """ + path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def remove_user_from_group_notification(self, destination, group_id, user_id, + content): + """Sent by group server to inform a user's server that they have been + kicked from the group. 
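[Editor's aside, not part of the diff: each of these transport methods builds a federation path under PREFIX and issues a JSON request with the acting user passed as a requester_user_id query argument. A rough usage sketch follows; the destination, group and user IDs are invented.]

from twisted.internet import defer

@defer.inlineCallbacks
def example(transport_client):
    # Fetch the member list of a remote group on behalf of a local user.
    content = yield transport_client.get_users_in_group(
        destination="b.example.com",
        group_id="+bar:b.example.com",
        requester_user_id="@foo:a.example.com",
    )
    # Roughly equivalent to a federation request of the form
    #   GET <PREFIX>/groups/+bar:b.example.com/users
    #       ?requester_user_id=@foo:a.example.com
    # sent to b.example.com, with ignore_backoff=True so it is attempted
    # even if that destination is currently marked as unreachable.
    defer.returnValue(content)
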
+ """ + + path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def renew_group_attestation(self, destination, group_id, user_id, content): + """Sent by either a group server or a user's server to periodically update + the attestations + """ + + path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def update_group_summary_room(self, destination, group_id, user_id, room_id, + category_id, content): + """Update a room entry in a group summary + """ + if category_id: + path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % ( + group_id, category_id, room_id, + ) + else: + path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_summary_room(self, destination, group_id, user_id, room_id, + category_id): + """Delete a room entry in a group summary + """ + if category_id: + path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % ( + group_id, category_id, room_id, + ) + else: + path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_categories(self, destination, group_id, requester_user_id): + """Get all categories in a group + """ + path = PREFIX + "/groups/%s/categories" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_category(self, destination, group_id, requester_user_id, category_id): + """Get category info in a group + """ + path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_category(self, destination, group_id, requester_user_id, category_id, + content): + """Update a category in a group + """ + path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_category(self, destination, group_id, requester_user_id, + category_id): + """Delete a category in a group + """ + path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_roles(self, destination, group_id, requester_user_id): + """Get all roles in a group + """ + path = PREFIX + "/groups/%s/roles" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_role(self, destination, group_id, requester_user_id, role_id): + """Get a roles info + """ + 
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_role(self, destination, group_id, requester_user_id, role_id, + content): + """Update a role in a group + """ + path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_role(self, destination, group_id, requester_user_id, role_id): + """Delete a role in a group + """ + path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_summary_user(self, destination, group_id, requester_user_id, + user_id, role_id, content): + """Update a users entry in a group + """ + if role_id: + path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % ( + group_id, role_id, user_id, + ) + else: + path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_summary_user(self, destination, group_id, requester_user_id, + user_id, role_id): + """Delete a users entry in a group + """ + if role_id: + path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % ( + group_id, role_id, user_id, + ) + else: + path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + def bulk_get_publicised_groups(self, destination, user_ids): + """Get the groups a list of users are publicising + """ + + path = PREFIX + "/get_groups_publicised" + + content = {"user_ids": user_ids} + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index a78f01e442..2b02b021ec 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -25,7 +25,7 @@ from synapse.http.servlet import ( from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string from synapse.util.logcontext import preserve_fn -from synapse.types import ThirdPartyInstanceID +from synapse.types import ThirdPartyInstanceID, get_domain_from_id import functools import logging @@ -112,7 +112,7 @@ class Authenticator(object): key = strip_quotes(param_dict["key"]) sig = strip_quotes(param_dict["sig"]) return (origin, key, sig) - except: + except Exception: raise AuthenticationError( 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) @@ -177,7 +177,7 @@ class BaseFederationServlet(object): if self.REQUIRE_AUTH: logger.exception("authenticate_request failed") raise - except: + except Exception: logger.exception("authenticate_request failed") raise @@ -270,7 +270,7 @@ class FederationSendServlet(BaseFederationServlet): code, response = yield self.handler.on_incoming_transaction( transaction_data ) - except: + except Exception: 
logger.exception("on_incoming_transaction failed") raise @@ -609,6 +609,514 @@ class FederationVersionServlet(BaseFederationServlet): })) +class FederationGroupsProfileServlet(BaseFederationServlet): + """Get/set the basic profile of a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/profile$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_group_profile( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.update_group_profile( + group_id, requester_user_id, content + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsSummaryServlet(BaseFederationServlet): + PATH = "/groups/(?P<group_id>[^/]*)/summary$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_group_summary( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRoomsServlet(BaseFederationServlet): + """Get the rooms in a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/rooms$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_rooms_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAddRoomsServlet(BaseFederationServlet): + """Add/remove room from group + """ + PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.remove_room_from_group( + group_id, requester_user_id, room_id, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): + """Update room config in group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" + "/config/(?P<config_key>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, 
content, query, group_id, room_id, config_key): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + result = yield self.groups_handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content, + ) + + defer.returnValue((200, result)) + + +class FederationGroupsUsersServlet(BaseFederationServlet): + """Get the users in a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/users$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_users_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsInvitedUsersServlet(BaseFederationServlet): + """Get the users that have been invited to a group + """ + PATH = "/groups/(?P<group_id>[^/]*)/invited_users$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsInviteServlet(BaseFederationServlet): + """Ask a group server to invite someone to the group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.invite_to_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAcceptInviteServlet(BaseFederationServlet): + """Accept an invitation from the group server + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.accept_invite( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRemoveUserServlet(BaseFederationServlet): + """Leave or kick a user from the group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsLocalInviteServlet(BaseFederationServlet): + """A group server has invited a local user + """ + PATH = 
"/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "group_id doesn't match origin") + + new_content = yield self.handler.on_invite( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): + """A group server has removed a local user + """ + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.user_removed_from_group( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): + """A group or user's server renews their attestation + """ + PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + # We don't need to check auth here as we check the attestation signatures + + new_content = yield self.handler.on_renew_attestation( + group_id, user_id, content + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): + """Add/remove a room from the group summary, with optional category. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/categories/(?P<category_id>[^/]+))?" 
+ "/rooms/(?P<room_id>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.update_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.delete_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsCategoriesServlet(BaseFederationServlet): + """Get all categories for a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/categories/$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_categories( + group_id, requester_user_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsCategoryServlet(BaseFederationServlet): + """Add/remove/get a category in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_category( + group_id, requester_user_id, category_id + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.upsert_group_category( + group_id, requester_user_id, category_id, content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.delete_group_category( + group_id, requester_user_id, category_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsRolesServlet(BaseFederationServlet): + """Get roles in a group + """ + PATH = ( + 
"/groups/(?P<group_id>[^/]*)/roles/$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_roles( + group_id, requester_user_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsRoleServlet(BaseFederationServlet): + """Add/remove/get a role in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_role( + group_id, requester_user_id, role_id + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.update_group_role( + group_id, requester_user_id, role_id, content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.delete_group_role( + group_id, requester_user_id, role_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsSummaryUsersServlet(BaseFederationServlet): + """Add/remove a user from the group summary, with optional role. + + Matches both: + - /groups/:group/summary/users/:user_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/roles/(?P<role_id>[^/]+))?" 
+ "/users/(?P<user_id>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.update_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.delete_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): + """Get roles in a group + """ + PATH = ( + "/get_groups_publicised$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query): + resp = yield self.handler.bulk_get_publicised_groups( + content["user_ids"], proxy=False, + ) + + defer.returnValue((200, resp)) + + FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationPullServlet, @@ -635,10 +1143,42 @@ FEDERATION_SERVLET_CLASSES = ( FederationVersionServlet, ) + ROOM_LIST_CLASSES = ( PublicRoomList, ) +GROUP_SERVER_SERVLET_CLASSES = ( + FederationGroupsProfileServlet, + FederationGroupsSummaryServlet, + FederationGroupsRoomsServlet, + FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, + FederationGroupsInviteServlet, + FederationGroupsAcceptInviteServlet, + FederationGroupsRemoveUserServlet, + FederationGroupsSummaryRoomsServlet, + FederationGroupsCategoriesServlet, + FederationGroupsCategoryServlet, + FederationGroupsRolesServlet, + FederationGroupsRoleServlet, + FederationGroupsSummaryUsersServlet, + FederationGroupsAddRoomsServlet, + FederationGroupsAddRoomsConfigServlet, +) + + +GROUP_LOCAL_SERVLET_CLASSES = ( + FederationGroupsLocalInviteServlet, + FederationGroupsRemoveLocalUserServlet, + FederationGroupsBulkPublicisedServlet, +) + + +GROUP_ATTESTATION_SERVLET_CLASSES = ( + FederationGroupsRenewAttestaionServlet, +) + def register_servlets(hs, resource, authenticator, ratelimiter): for servletclass in FEDERATION_SERVLET_CLASSES: @@ -656,3 +1196,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter): ratelimiter=ratelimiter, server_name=hs.hostname, ).register(resource) + + for servletclass in GROUP_SERVER_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_server_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in GROUP_LOCAL_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_local_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_attestation_renewer(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) diff --git a/synapse/groups/__init__.py 
b/synapse/groups/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/groups/__init__.py diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py new file mode 100644 index 0000000000..1fb709e6c3 --- /dev/null +++ b/synapse/groups/attestations.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attestations ensure that users and groups can't lie about their memberships. + +When a user joins a group the HS and GS swap attestations, which allow them +both to independently prove to third parties their membership.These +attestations have a validity period so need to be periodically renewed. + +If a user leaves (or gets kicked out of) a group, either side can still use +their attestation to "prove" their membership, until the attestation expires. +Therefore attestations shouldn't be relied on to prove membership in important +cases, but can for less important situtations, e.g. showing a users membership +of groups on their profile, showing flairs, etc.abs + +An attestsation is a signed blob of json that looks like: + + { + "user_id": "@foo:a.example.com", + "group_id": "+bar:b.example.com", + "valid_until_ms": 1507994728530, + "signatures":{"matrix.org":{"ed25519:auto":"..."}} + } +""" + +import logging +import random + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn + +from signedjson.sign import sign_json + + +logger = logging.getLogger(__name__) + + +# Default validity duration for new attestations we create +DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000 + +# We add some jitter to the validity duration of attestations so that if we +# add lots of users at once we don't need to renew them all at once. +# The jitter is a multiplier picked randomly between the first and second number +DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) + +# Start trying to update our attestations when they come this close to expiring +UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 + + +class GroupAttestationSigning(object): + """Creates and verifies group attestations. + """ + def __init__(self, hs): + self.keyring = hs.get_keyring() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.signing_key = hs.config.signing_key[0] + + @defer.inlineCallbacks + def verify_attestation(self, attestation, group_id, user_id, server_name=None): + """Verifies that the given attestation matches the given parameters. + + An optional server_name can be supplied to explicitly set which server's + signature is expected. Otherwise assumes that either the group_id or user_id + is local and uses the other's server as the one to check. 
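[Editor's aside, not part of the commit: building and checking an attestation blob of the shape shown in the module docstring, using signedjson directly rather than Synapse's Keyring. The key, server names and IDs below are made up for the example.]

from signedjson.key import generate_signing_key, get_verify_key
from signedjson.sign import SignatureVerifyException, sign_json, verify_signed_json

signing_key = generate_signing_key("auto")   # throwaway key for the demo

attestation = sign_json(
    {
        "user_id": "@foo:a.example.com",
        "group_id": "+bar:b.example.com",
        "valid_until_ms": 1507994728530,
    },
    "a.example.com",        # the name the signature is made under
    signing_key,
)

# The receiving side verifies the signature (Synapse does this via the
# Keyring, which also fetches the remote server's verify key).
try:
    verify_signed_json(attestation, "a.example.com", get_verify_key(signing_key))
except SignatureVerifyException:
    print("attestation signature did not verify")
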
+ """ + + if not server_name: + if get_domain_from_id(group_id) == self.server_name: + server_name = get_domain_from_id(user_id) + elif get_domain_from_id(user_id) == self.server_name: + server_name = get_domain_from_id(group_id) + else: + raise Exception("Expected either group_id or user_id to be local") + + if user_id != attestation["user_id"]: + raise SynapseError(400, "Attestation has incorrect user_id") + + if group_id != attestation["group_id"]: + raise SynapseError(400, "Attestation has incorrect group_id") + valid_until_ms = attestation["valid_until_ms"] + + # TODO: We also want to check that *new* attestations that people give + # us to store are valid for at least a little while. + if valid_until_ms < self.clock.time_msec(): + raise SynapseError(400, "Attestation expired") + + yield self.keyring.verify_json_for_server(server_name, attestation) + + def create_attestation(self, group_id, user_id): + """Create an attestation for the group_id and user_id with default + validity length. + """ + validity_period = DEFAULT_ATTESTATION_LENGTH_MS + validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) + valid_until_ms = int(self.clock.time_msec() + validity_period) + + return sign_json({ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": valid_until_ms, + }, self.server_name, self.signing_key) + + +class GroupAttestionRenewer(object): + """Responsible for sending and receiving attestation updates. + """ + + def __init__(self, hs): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.assestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.is_mine_id = hs.is_mine_id + self.attestations = hs.get_groups_attestation_signing() + + self._renew_attestations_loop = self.clock.looping_call( + self._renew_attestations, 30 * 60 * 1000, + ) + + @defer.inlineCallbacks + def on_renew_attestation(self, group_id, user_id, content): + """When a remote updates an attestation + """ + attestation = content["attestation"] + + if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): + raise SynapseError(400, "Neither user not group are on this server") + + yield self.attestations.verify_attestation( + attestation, + user_id=user_id, + group_id=group_id, + ) + + yield self.store.update_remote_attestion(group_id, user_id, attestation) + + defer.returnValue({}) + + @defer.inlineCallbacks + def _renew_attestations(self): + """Called periodically to check if we need to update any of our attestations + """ + + now = self.clock.time_msec() + + rows = yield self.store.get_attestations_need_renewals( + now + UPDATE_ATTESTATION_TIME_MS + ) + + @defer.inlineCallbacks + def _renew_attestation(group_id, user_id): + if not self.is_mine_id(group_id): + destination = get_domain_from_id(group_id) + elif not self.is_mine_id(user_id): + destination = get_domain_from_id(user_id) + else: + logger.warn( + "Incorrectly trying to do attestations for user: %r in %r", + user_id, group_id, + ) + yield self.store.remove_attestation_renewal(group_id, user_id) + return + + attestation = self.attestations.create_attestation(group_id, user_id) + + yield self.transport_client.renew_group_attestation( + destination, group_id, user_id, + content={"attestation": attestation}, + ) + + yield self.store.update_attestation_renewal( + group_id, user_id, attestation + ) + + for row in rows: + group_id = row["group_id"] + user_id = row["user_id"] + + preserve_fn(_renew_attestation)(group_id, user_id) diff --git 
a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py new file mode 100644 index 0000000000..0b995aed70 --- /dev/null +++ b/synapse/groups/groups_server.py @@ -0,0 +1,864 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.types import GroupID, RoomID, UserID, get_domain_from_id +from twisted.internet import defer + +logger = logging.getLogger(__name__) + + +# TODO: Allow users to "knock" or simpkly join depending on rules +# TODO: Federation admin APIs +# TODO: is_priveged flag to users and is_public to users and rooms +# TODO: Audit log for admins (profile updates, membership changes, users who tried +# to join but were rejected, etc) +# TODO: Flairs + + +class GroupsServerHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + self.attestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.profile_handler = hs.get_profile_handler() + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + @defer.inlineCallbacks + def check_group_is_ours(self, group_id, requester_user_id, + and_exists=False, and_is_admin=None): + """Check that the group is ours, and optionally if it exists. + + If group does exist then return group. + + Args: + group_id (str) + and_exists (bool): whether to also check if group exists + and_is_admin (str): whether to also check if given str is a user_id + that is an admin + """ + if not self.is_mine_id(group_id): + raise SynapseError(400, "Group not on this server") + + group = yield self.store.get_group(group_id) + if and_exists and not group: + raise SynapseError(404, "Unknown group") + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + if group and not is_user_in_group and not group["is_public"]: + raise SynapseError(404, "Unknown group") + + if and_is_admin: + is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + defer.returnValue(group) + + @defer.inlineCallbacks + def get_group_summary(self, group_id, requester_user_id): + """Get the summary for a group as seen by requester_user_id. + + The group summary consists of the profile of the room, and a curated + list of users and rooms. These list *may* be organised by role/category. + The roles/categories are ordered, and so are the users/rooms within them. + + A user/room may appear in multiple roles/categories. 
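[Editor's aside, not part of the diff: to make the summary shape concrete, a response from get_group_summary ends up looking roughly like the dict below. All values are invented; the field names follow the code later in this file.]

example_summary = {
    "profile": {"name": "Example group", "avatar_url": None},
    "users_section": {
        "users": [
            {"user_id": "@alice:example.com", "displayname": "Alice"},
        ],
        "roles": {},
        "total_user_count_estimate": 0,
    },
    "rooms_section": {
        "rooms": [
            {"room_id": "!abc:example.com", "profile": {"name": "Example room"}},
        ],
        "categories": {},
        "total_room_count_estimate": 0,
    },
    # membership info for the requesting user (e.g. whether they are joined)
    "user": {},
}
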
+ """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + profile = yield self.get_group_profile(group_id, requester_user_id) + + users, roles = yield self.store.get_users_for_summary_by_role( + group_id, include_private=is_user_in_group, + ) + + # TODO: Add profiles to users + + rooms, categories = yield self.store.get_rooms_for_summary_by_category( + group_id, include_private=is_user_in_group, + ) + + for room_entry in rooms: + room_id = room_entry["room_id"] + joined_users = yield self.store.get_users_in_room(room_id) + entry = yield self.room_list_handler.generate_room_entry( + room_id, len(joined_users), + with_alias=False, allow_private=True, + ) + entry = dict(entry) # so we don't change whats cached + entry.pop("room_id", None) + + room_entry["profile"] = entry + + rooms.sort(key=lambda e: e.get("order", 0)) + + for entry in users: + user_id = entry["user_id"] + + if not self.is_mine_id(requester_user_id): + attestation = yield self.store.get_remote_attestation(group_id, user_id) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, user_id, + ) + + user_profile = yield self.profile_handler.get_profile_from_cache(user_id) + entry.update(user_profile) + + users.sort(key=lambda e: e.get("order", 0)) + + membership_info = yield self.store.get_users_membership_info_in_group( + group_id, requester_user_id, + ) + + defer.returnValue({ + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + }) + + @defer.inlineCallbacks + def update_group_summary_room(self, group_id, requester_user_id, + room_id, category_id, content): + """Add/update a room to the group summary + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + RoomID.from_string(room_id) # Ensure valid room id + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_room_to_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + order=order, + is_public=is_public, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_summary_room(self, group_id, requester_user_id, + room_id, category_id): + """Remove a room from the summary + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + yield self.store.remove_room_from_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_categories(self, group_id, requester_user_id): + """Get all categories in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + categories = yield self.store.get_group_categories( + group_id=group_id, + ) + defer.returnValue({"categories": categories}) + + @defer.inlineCallbacks + def get_group_category(self, group_id, requester_user_id, category_id): + """Get a specific category in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = yield self.store.get_group_category( 
+ group_id=group_id, + category_id=category_id, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def update_group_category(self, group_id, requester_user_id, category_id, content): + """Add/Update a group category + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + is_public = _parse_visibility_from_contents(content) + profile = content.get("profile") + + yield self.store.upsert_group_category( + group_id=group_id, + category_id=category_id, + is_public=is_public, + profile=profile, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_category(self, group_id, requester_user_id, category_id): + """Delete a group category + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id + ) + + yield self.store.remove_group_category( + group_id=group_id, + category_id=category_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_roles(self, group_id, requester_user_id): + """Get all roles in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + roles = yield self.store.get_group_roles( + group_id=group_id, + ) + defer.returnValue({"roles": roles}) + + @defer.inlineCallbacks + def get_group_role(self, group_id, requester_user_id, role_id): + """Get a specific role in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = yield self.store.get_group_role( + group_id=group_id, + role_id=role_id, + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def update_group_role(self, group_id, requester_user_id, role_id, content): + """Add/update a role in a group + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + is_public = _parse_visibility_from_contents(content) + + profile = content.get("profile") + + yield self.store.upsert_group_role( + group_id=group_id, + role_id=role_id, + is_public=is_public, + profile=profile, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_role(self, group_id, requester_user_id, role_id): + """Remove role from group + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + yield self.store.remove_group_role( + group_id=group_id, + role_id=role_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id, + content): + """Add/update a users entry in the group summary + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_user_to_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + order=order, + is_public=is_public, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id): + """Remove a user from the group summary + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + yield self.store.remove_user_from_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + 
def get_group_profile(self, group_id, requester_user_id): + """Get the group profile as seen by requester_user_id + """ + + yield self.check_group_is_ours(group_id, requester_user_id) + + group_description = yield self.store.get_group(group_id) + + if group_description: + defer.returnValue(group_description) + else: + raise SynapseError(404, "Unknown group") + + @defer.inlineCallbacks + def update_group_profile(self, group_id, requester_user_id, content): + """Update the group profile + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + profile = {} + for keyname in ("name", "avatar_url", "short_description", + "long_description"): + if keyname in content: + value = content[keyname] + if not isinstance(value, basestring): + raise SynapseError(400, "%r value is not a string" % (keyname,)) + profile[keyname] = value + + yield self.store.update_group_profile(group_id, profile) + + @defer.inlineCallbacks + def get_users_in_group(self, group_id, requester_user_id): + """Get the users in group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + user_results = yield self.store.get_users_in_group( + group_id, include_private=is_user_in_group, + ) + + chunk = [] + for user_result in user_results: + g_user_id = user_result["user_id"] + is_public = user_result["is_public"] + is_privileged = user_result["is_admin"] + + entry = {"user_id": g_user_id} + + profile = yield self.profile_handler.get_profile_from_cache(g_user_id) + entry.update(profile) + + entry["is_public"] = bool(is_public) + entry["is_privileged"] = bool(is_privileged) + + if not self.is_mine_id(g_user_id): + attestation = yield self.store.get_remote_attestation(group_id, g_user_id) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, g_user_id, + ) + + chunk.append(entry) + + # TODO: If admin add lists of users whose attestations have timed out + + defer.returnValue({ + "chunk": chunk, + "total_user_count_estimate": len(user_results), + }) + + @defer.inlineCallbacks + def get_invited_users_in_group(self, group_id, requester_user_id): + """Get the users that have been invited to a group as seen by requester_user_id. 
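[Editor's aside, not part of the diff: for reference, the profile fields accepted by update_group_profile above are the four string keys it whitelists. Values below are invented.]

new_profile = {
    "name": "My community",
    "avatar_url": "mxc://example.com/abcdef",
    "short_description": "A one-line blurb",
    "long_description": "A longer, possibly multi-paragraph description.",
}
# Any other keys in the request body are ignored; non-string values for
# these keys are rejected with a 400.
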
+ + The ordering is arbitrary at the moment + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + if not is_user_in_group: + raise SynapseError(403, "User not in group") + + invited_users = yield self.store.get_invited_users_in_group(group_id) + + user_profiles = [] + + for user_id in invited_users: + user_profile = { + "user_id": user_id + } + try: + profile = yield self.profile_handler.get_profile_from_cache(user_id) + user_profile.update(profile) + except Exception as e: + logger.warn("Error getting profile for %s: %s", user_id, e) + user_profiles.append(user_profile) + + defer.returnValue({ + "chunk": user_profiles, + "total_user_count_estimate": len(invited_users), + }) + + @defer.inlineCallbacks + def get_rooms_in_group(self, group_id, requester_user_id): + """Get the rooms in group as seen by requester_user_id + + This returns rooms in order of decreasing number of joined users + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + room_results = yield self.store.get_rooms_in_group( + group_id, include_private=is_user_in_group, + ) + + chunk = [] + for room_result in room_results: + room_id = room_result["room_id"] + + joined_users = yield self.store.get_users_in_room(room_id) + entry = yield self.room_list_handler.generate_room_entry( + room_id, len(joined_users), + with_alias=False, allow_private=True, + ) + + if not entry: + continue + + entry["is_public"] = bool(room_result["is_public"]) + + chunk.append(entry) + + chunk.sort(key=lambda e: -e["num_joined_members"]) + + defer.returnValue({ + "chunk": chunk, + "total_room_count_estimate": len(room_results), + }) + + @defer.inlineCallbacks + def add_room_to_group(self, group_id, requester_user_id, room_id, content): + """Add room to group + """ + RoomID.from_string(room_id) # Ensure valid room id + + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) + + defer.returnValue({}) + + @defer.inlineCallbacks + def update_room_in_group(self, group_id, requester_user_id, room_id, config_key, + content): + """Update room in group + """ + RoomID.from_string(room_id) # Ensure valid room id + + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + if config_key == "m.visibility": + is_public = _parse_visibility_dict(content) + + yield self.store.update_room_in_group_visibility( + group_id, room_id, + is_public=is_public, + ) + else: + raise SynapseError(400, "Uknown config option") + + defer.returnValue({}) + + @defer.inlineCallbacks + def remove_room_from_group(self, group_id, requester_user_id, room_id): + """Remove room from group + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + yield self.store.remove_room_from_group(group_id, room_id) + + defer.returnValue({}) + + @defer.inlineCallbacks + def invite_to_group(self, group_id, user_id, requester_user_id, content): + """Invite user to group + """ + + group = yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + # TODO: Check if user knocked + # TODO: 
Check if user is already invited + + content = { + "profile": { + "name": group["name"], + "avatar_url": group["avatar_url"], + }, + "inviter": requester_user_id, + } + + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + res = yield groups_local.on_invite(group_id, user_id, content) + local_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content.update({ + "attestation": local_attestation, + }) + + res = yield self.transport_client.invite_to_group_notification( + get_domain_from_id(user_id), group_id, user_id, content + ) + + user_profile = res.get("user_profile", {}) + yield self.store.add_remote_profile_cache( + user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + if res["state"] == "join": + if not self.hs.is_mine_id(user_id): + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=user_id, + group_id=group_id, + ) + else: + remote_attestation = None + + yield self.store.add_user_to_group( + group_id, user_id, + is_admin=False, + is_public=False, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + elif res["state"] == "invite": + yield self.store.add_group_invite( + group_id, user_id, + ) + defer.returnValue({ + "state": "invite" + }) + elif res["state"] == "reject": + defer.returnValue({ + "state": "reject" + }) + else: + raise SynapseError(502, "Unknown state returned by HS") + + @defer.inlineCallbacks + def accept_invite(self, group_id, requester_user_id, content): + """User tries to accept an invite to the group. + + This is different from them asking to join, and so should error if no + invite exists (and they're not a member of the group) + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_invited = yield self.store.is_user_invited_to_local_group( + group_id, requester_user_id, + ) + if not is_invited: + raise SynapseError(403, "User not invited to group") + + if not self.hs.is_mine_id(requester_user_id): + local_attestation = self.attestations.create_attestation( + group_id, requester_user_id, + ) + remote_attestation = content["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=requester_user_id, + group_id=group_id, + ) + else: + local_attestation = None + remote_attestation = None + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_user_to_group( + group_id, requester_user_id, + is_admin=False, + is_public=is_public, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + defer.returnValue({ + "state": "join", + "attestation": local_attestation, + }) + + @defer.inlineCallbacks + def knock(self, group_id, requester_user_id, content): + """A user requests becoming a member of the group + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + @defer.inlineCallbacks + def accept_knock(self, group_id, requester_user_id, content): + """Accept a users knock to the room. + + Errors if the user hasn't knocked, rather than inviting them. 
+ """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + @defer.inlineCallbacks + def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + """Remove a user from the group; either a user is leaving or an admin + kicked them. + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_kick = False + if requester_user_id != user_id: + is_admin = yield self.store.is_user_admin_in_group( + group_id, requester_user_id + ) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + is_kick = True + + yield self.store.remove_user_from_group( + group_id, user_id, + ) + + if is_kick: + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + yield groups_local.user_removed_from_group(group_id, user_id, {}) + else: + yield self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + + if not self.hs.is_mine_id(user_id): + yield self.store.maybe_delete_remote_profile_cache(user_id) + + defer.returnValue({}) + + @defer.inlineCallbacks + def create_group(self, group_id, requester_user_id, content): + group = yield self.check_group_is_ours(group_id, requester_user_id) + + logger.info("Attempting to create group with ID: %r", group_id) + + # parsing the id into a GroupID validates it. + group_id_obj = GroupID.from_string(group_id) + + if group: + raise SynapseError(400, "Group already exists") + + is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) + if not is_admin: + if not self.hs.config.enable_group_creation: + raise SynapseError( + 403, "Only a server admin can create groups on this server", + ) + localpart = group_id_obj.localpart + if not localpart.startswith(self.hs.config.group_creation_prefix): + raise SynapseError( + 400, + "Can only create groups with prefix %r on this server" % ( + self.hs.config.group_creation_prefix, + ), + ) + + profile = content.get("profile", {}) + name = profile.get("name") + avatar_url = profile.get("avatar_url") + short_description = profile.get("short_description") + long_description = profile.get("long_description") + user_profile = content.get("user_profile", {}) + + yield self.store.create_group( + group_id, + requester_user_id, + name=name, + avatar_url=avatar_url, + short_description=short_description, + long_description=long_description, + ) + + if not self.hs.is_mine_id(requester_user_id): + remote_attestation = content["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=requester_user_id, + group_id=group_id, + ) + + local_attestation = self.attestations.create_attestation( + group_id, + requester_user_id, + ) + else: + local_attestation = None + remote_attestation = None + + yield self.store.add_user_to_group( + group_id, requester_user_id, + is_admin=True, + is_public=True, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + if not self.hs.is_mine_id(requester_user_id): + yield self.store.add_remote_profile_cache( + requester_user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + defer.returnValue({ + "group_id": group_id, + }) + + +def _parse_visibility_from_contents(content): + """Given a content for a request parse out whether the entity should be + public or not + """ + + visibility = content.get("m.visibility") + if visibility: + return 
_parse_visibility_dict(visibility) + else: + is_public = True + + return is_public + + +def _parse_visibility_dict(visibility): + """Given a dict for the "m.visibility" config return if the entity should + be public or not + """ + vis_type = visibility.get("type") + if not vis_type: + return True + + if vis_type not in ("public", "private"): + raise SynapseError( + 400, "Synapse only supports 'public'/'private' visibility" + ) + return vis_type == "public" diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 5ad408f549..53213cdccf 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -20,7 +20,6 @@ from .room import ( from .room_member import RoomMemberHandler from .message import MessageHandler from .federation import FederationHandler -from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler from .identity import IdentityHandler @@ -52,7 +51,6 @@ class Handlers(object): self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) self.federation_handler = FederationHandler(hs) - self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.identity_handler = IdentityHandler(hs) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 05af54d31b..feca3e4c10 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn import logging @@ -70,11 +70,10 @@ class ApplicationServicesHandler(object): with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: - upper_bound = self.current_max limit = 100 while True: upper_bound, events = yield self.store.get_new_events_for_appservice( - upper_bound, limit + self.current_max, limit ) if not events: @@ -105,9 +104,6 @@ class ApplicationServicesHandler(object): ) yield self.store.set_appservice_last_pos(upper_bound) - - if len(events) < limit: - break finally: self.is_processing = False @@ -163,7 +159,7 @@ class ApplicationServicesHandler(object): def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield preserve_context_over_deferred(defer.DeferredList([ + results = yield make_deferred_yieldable(defer.DeferredList([ preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) for service in services ], consumeErrors=True)) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b00446bec0..573c9db8a1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,13 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType +from synapse.api.errors import ( + AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError, + SynapseError, +) +from synapse.module_api import ModuleApi from synapse.types import UserID -from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor from synapse.util.caches.expiringcache import ExpiringCache @@ -46,7 +49,6 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) self.checkers = { - LoginType.PASSWORD: self._check_password_auth, LoginType.RECAPTCHA: self._check_recaptcha, LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.MSISDN: self._check_msisdn, @@ -63,10 +65,7 @@ class AuthHandler(BaseHandler): reset_expiry_on_get=True, ) - account_handler = _AccountHandler( - hs, check_user_exists=self.check_user_exists - ) - + account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) for module, config in hs.config.password_providers @@ -75,39 +74,120 @@ class AuthHandler(BaseHandler): logger.info("Extra password_providers: %r", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? - self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + self._password_enabled = hs.config.password_enabled + + # we keep this as a list despite the O(N^2) implication so that we can + # keep PASSWORD first and avoid confusing clients which pick the first + # type in the list. (NB that the spec doesn't require us to do so and + # clients which favour types that they don't understand over those that + # they do are technically broken) + login_types = [] + if self._password_enabled: + login_types.append(LoginType.PASSWORD) + for provider in self.password_providers: + if hasattr(provider, "get_supported_login_types"): + for t in provider.get_supported_login_types().keys(): + if t not in login_types: + login_types.append(t) + self._supported_login_types = login_types + + @defer.inlineCallbacks + def validate_user_via_ui_auth(self, requester, request_body, clientip): + """ + Checks that the user is who they claim to be, via a UI auth. + + This is used for things like device deletion and password reset where + the user already has a valid access token, but we want to double-check + that it isn't stolen by re-authenticating them. + + Args: + requester (Requester): The user, as given by the access token + + request_body (dict): The body of the request sent by the client + + clientip (str): The IP address of the client. + + Returns: + defer.Deferred[dict]: the parameters for this request (which may + have been given only in a previous call). 
+ + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + any of the permitted login flows + + AuthError if the client has completed a login flow, and it gives + a different user to `requester` + """ + + # build a list of supported flows + flows = [ + [login_type] for login_type in self._supported_login_types + ] + + result, params, _ = yield self.check_auth( + flows, request_body, clientip, + ) + + # find the completed login type + for login_type in self._supported_login_types: + if login_type not in result: + continue + + user_id = result[login_type] + break + else: + # this can't happen + raise Exception( + "check_auth returned True but no successful login type", + ) + + # check that the UI auth matched the access token + if user_id != requester.user.to_string(): + raise AuthError(403, "Invalid auth") + + defer.returnValue(params) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration - protocol and handles the login flow. + protocol and handles the User-Interactive Auth flow. As a side effect, this function fills in the 'creds' key on the user's session with a map, which maps each auth-type (str) to the relevant identity authenticated by that auth-type (mostly str, but for captcha, bool). + If no auth flows have been completed successfully, raises an + InteractiveAuthIncompleteError. To handle this, you can use + synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + decorator. + Args: flows (list): A list of login flows. Each flow is an ordered list of strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. + Returns: - A tuple of (authed, dict, dict, session_id) where authed is true if - the client has successfully completed an auth flow. If it is true - the first dict contains the authenticated credentials of each stage. + defer.Deferred[dict, dict, str]: a deferred tuple of + (creds, params, session_id). + + 'creds' contains the authenticated credentials of each stage. - If authed is false, the first dictionary is the server response to - the login request and should be passed back to the client. + 'params' contains the parameters for this request (which may + have been given only in a previous call). - In either case, the second dict contains the parameters for this - request (which may have been given only in a previous call). + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call - session_id is the ID of this session, either passed in by the client - or assigned by the call to check_auth + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + all the stages in any of the permitted flows. 
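As a hedged illustration of the flow described in the docstring above (not the actual helper in synapse.rest.client.v2_alpha._base), a servlet-side decorator that turns the InteractiveAuthIncompleteError raised by check_auth into the 401-plus-flows response the client expects might look roughly like this; the decorator name and the tuple-return convention for the wrapped servlet method are assumptions.

    from functools import wraps

    from twisted.internet import defer

    from synapse.api.errors import InteractiveAuthIncompleteError


    def interactive_auth_handler_sketch(f):
        """Report an incomplete UI-auth attempt to the client as a 401 whose
        body lists the remaining flows (illustrative sketch only)."""
        @wraps(f)
        @defer.inlineCallbacks
        def wrapped(*args, **kwargs):
            try:
                res = yield f(*args, **kwargs)
                defer.returnValue(res)
            except InteractiveAuthIncompleteError as e:
                # e.result carries the flows/params/session dict built by
                # _auth_dict_for_flows; the client retries the request with
                # an 'auth' key once it has completed a stage.
                defer.returnValue((401, e.result))
        return wrapped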
""" authdict = None @@ -135,11 +215,8 @@ class AuthHandler(BaseHandler): clientdict = session['clientdict'] if not authdict: - defer.returnValue( - ( - False, self._auth_dict_for_flows(flows, session), - clientdict, session['id'] - ) + raise InteractiveAuthIncompleteError( + self._auth_dict_for_flows(flows, session), ) if 'creds' not in session: @@ -150,14 +227,12 @@ class AuthHandler(BaseHandler): errordict = {} if 'type' in authdict: login_type = authdict['type'] - if login_type not in self.checkers: - raise LoginError(400, "", Codes.UNRECOGNIZED) try: - result = yield self.checkers[login_type](authdict, clientip) + result = yield self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) - except LoginError, e: + except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new # validation token (thus sending a new email) each time it @@ -166,7 +241,7 @@ class AuthHandler(BaseHandler): # # Grandfather in the old behaviour for now to avoid # breaking old riot deployments. - raise e + raise # this step failed. Merge the error dict into the response # so that the client can have another go. @@ -183,12 +258,14 @@ class AuthHandler(BaseHandler): "Auth completed with creds: %r. Client dict has keys: %r", creds, clientdict.keys() ) - defer.returnValue((True, creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() ret.update(errordict) - defer.returnValue((False, ret, clientdict, session['id'])) + raise InteractiveAuthIncompleteError( + ret, + ) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -260,16 +337,37 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - def _check_password_auth(self, authdict, _): - if "user" not in authdict or "password" not in authdict: - raise LoginError(400, "", Codes.MISSING_PARAM) + @defer.inlineCallbacks + def _check_auth_dict(self, authdict, clientip): + """Attempt to validate the auth dict provided by a client - user_id = authdict["user"] - password = authdict["password"] - if not user_id.startswith('@'): - user_id = UserID.create(user_id, self.hs.hostname).to_string() + Args: + authdict (object): auth dict provided by the client + clientip (str): IP address of the client - return self._check_password(user_id, password) + Returns: + Deferred: result of the stage verification. + + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + login_type = authdict['type'] + checker = self.checkers.get(login_type) + if checker is not None: + res = yield checker(authdict, clientip) + defer.returnValue(res) + + # build a v1-login-style dict out of the authdict and fall back to the + # v1 code + user_id = authdict.get("user") + + if user_id is None: + raise SynapseError(400, "", Codes.MISSING_PARAM) + + (canonical_id, callback) = yield self.validate_login(user_id, authdict) + defer.returnValue(canonical_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -398,26 +496,8 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - def validate_password_login(self, user_id, password): - """ - Authenticates the user with their username and password. - - Used only by the v1 login API. 
- - Args: - user_id (str): complete @user:id - password (str): Password - Returns: - defer.Deferred: (str) canonical user id - Raises: - StoreError if there was a problem accessing the database - LoginError if there was an authentication problem. - """ - return self._check_password(user_id, password) - @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None): """ Creates a new access token for the user with the given user ID. @@ -431,13 +511,10 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - initial_display_name (str): display name to associate with the - device if it needs re-registering Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. - LoginError if there was an authentication problem. """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) @@ -447,9 +524,11 @@ class AuthHandler(BaseHandler): # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + try: + yield self.store.get_device(user_id, device_id) + except StoreError: + yield self.store.delete_access_token(access_token) + raise StoreError(400, "Login raced against device deletion") defer.returnValue(access_token) @@ -501,29 +580,115 @@ class AuthHandler(BaseHandler): ) defer.returnValue(result) + def get_supported_login_types(self): + """Get a the login types supported for the /login API + + By default this is just 'm.login.password' (unless password_enabled is + False in the config file), but password auth providers can provide + other login types. + + Returns: + Iterable[str]: login types + """ + return self._supported_login_types + @defer.inlineCallbacks - def _check_password(self, user_id, password): - """Authenticate a user against the LDAP and local databases. + def validate_login(self, username, login_submission): + """Authenticates the user for the /login API - user_id is checked case insensitively against the local database, but - will throw if there are multiple inexact matches. + Also used by the user-interactive auth flow to validate + m.login.password auth types. Args: - user_id (str): complete @user:id + username (str): username supplied by the user + login_submission (dict): the whole of the login submission + (including 'type' and other relevant fields) Returns: - (str) the canonical_user_id + Deferred[str, func]: canonical user id, and optional callback + to be called once the access token and device id are issued Raises: - LoginError if login fails + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. 
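The docstring above, together with the provider loop that follows it, spells out the optional hooks a password auth provider can expose. A minimal provider module might look like the sketch below; only the hook names (check_password, get_supported_login_types, check_auth, on_logged_out) and the (config, account_handler) constructor arguments are taken from this diff, while the class name, the custom login type and the token check are made-up placeholders.

    from twisted.internet import defer


    class ExamplePasswordProvider(object):
        """Sketch of a password auth provider, loaded via the
        password_providers config like the modules iterated over above."""

        def __init__(self, config, account_handler):
            # account_handler is the ModuleApi instance handed in by AuthHandler
            self.account_handler = account_handler

        def get_supported_login_types(self):
            # Maps each extra login type to the fields the client must supply.
            return {"com.example.token_login": ["token"]}

        def check_auth(self, username, login_type, login_dict):
            # Resolve to a canonical user id (or (user_id, callback)) on
            # success, or to a falsy value to let other providers / the local
            # password check have a go.
            if login_dict.get("token") == "s3kr3t":  # placeholder check
                return defer.succeed("@%s:example.com" % (username,))
            return defer.succeed(None)

        def check_password(self, user_id, password):
            # Optional m.login.password hook; a truthy result means the
            # password was accepted for user_id.
            return defer.succeed(False)

        def on_logged_out(self, user_id, device_id, access_token):
            # Called when an access token is invalidated.
            return defer.succeed(None)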
""" + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID( + username, self.hs.hostname + ).to_string() + + login_type = login_submission.get("type") + known_login_type = False + + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not password: + raise SynapseError(400, "Missing parameter: password") + for provider in self.password_providers: - is_valid = yield provider.check_password(user_id, password) - if is_valid: - defer.returnValue(user_id) + if (hasattr(provider, "check_password") + and login_type == LoginType.PASSWORD): + known_login_type = True + is_valid = yield provider.check_password( + qualified_user_id, password, + ) + if is_valid: + defer.returnValue((qualified_user_id, None)) + + if (not hasattr(provider, "get_supported_login_types") + or not hasattr(provider, "check_auth")): + # this password provider doesn't understand custom login types + continue + + supported_login_types = provider.get_supported_login_types() + if login_type not in supported_login_types: + # this password provider doesn't understand this login type + continue + + known_login_type = True + login_fields = supported_login_types[login_type] + + missing_fields = [] + login_dict = {} + for f in login_fields: + if f not in login_submission: + missing_fields.append(f) + else: + login_dict[f] = login_submission[f] + if missing_fields: + raise SynapseError( + 400, "Missing parameters for login type %s: %s" % ( + login_type, + missing_fields, + ), + ) - canonical_user_id = yield self._check_local_password(user_id, password) + result = yield provider.check_auth( + username, login_type, login_dict, + ) + if result: + if isinstance(result, str): + result = (result, None) + defer.returnValue(result) + + if login_type == LoginType.PASSWORD: + known_login_type = True + + canonical_user_id = yield self._check_local_password( + qualified_user_id, password, + ) - if canonical_user_id: - defer.returnValue(canonical_user_id) + if canonical_user_id: + defer.returnValue((canonical_user_id, None)) + + if not known_login_type: + raise SynapseError(400, "Unknown login type %s" % login_type) # unknown username or invalid password. 
We raise a 403 here, but note # that if we're doing user-interactive login, it turns all LoginErrors @@ -573,22 +738,65 @@ class AuthHandler(BaseHandler): raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def set_password(self, user_id, newpassword, requester=None): - password_hash = self.hash(newpassword) + def delete_access_token(self, access_token): + """Invalidate a single access token - except_access_token_id = requester.access_token_id if requester else None + Args: + access_token (str): access token to be deleted - try: - yield self.store.user_set_password_hash(user_id, password_hash) - except StoreError as e: - if e.code == 404: - raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) - raise e - yield self.store.user_delete_access_tokens( - user_id, except_access_token_id + Returns: + Deferred + """ + user_info = yield self.auth.get_user_by_access_token(access_token) + yield self.store.delete_access_token(access_token) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + yield provider.on_logged_out( + user_id=str(user_info["user"]), + device_id=user_info["device_id"], + access_token=access_token, + ) + + # delete pushers associated with this access token + if user_info["token_id"] is not None: + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + str(user_info["user"]), (user_info["token_id"], ) + ) + + @defer.inlineCallbacks + def delete_access_tokens_for_user(self, user_id, except_token_id=None, + device_id=None): + """Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str|None): access_token ID which should *not* be + deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + Returns: + Deferred + """ + tokens_and_devices = yield self.store.user_delete_access_tokens( + user_id, except_token_id=except_token_id, device_id=device_id, ) - yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_id + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + for token, token_id, device_id in tokens_and_devices: + yield provider.on_logged_out( + user_id=user_id, + device_id=device_id, + access_token=token, + ) + + # delete pushers associated with the access tokens + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + user_id, (token_id for _, token_id, _ in tokens_and_devices), ) @defer.inlineCallbacks @@ -696,30 +904,3 @@ class MacaroonGeneartor(object): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - - -class _AccountHandler(object): - """A proxy object that gets passed to password auth providers so they - can register new users etc if necessary. - """ - def __init__(self, hs, check_user_exists): - self.hs = hs - - self._check_user_exists = check_user_exists - - def check_user_exists(self, user_id): - """Check if user exissts. 
- - Returns: - Deferred(bool) - """ - return self._check_user_exists(user_id) - - def register(self, localpart): - """Registers a new user with given localpart - - Returns: - Deferred: a 2-tuple of (user_id, access_token) - """ - reg = self.hs.get_handlers().registration_handler - return reg.register(localpart=localpart) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py new file mode 100644 index 0000000000..b1d3814909 --- /dev/null +++ b/synapse/handlers/deactivate_account.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeactivateAccountHandler(BaseHandler): + """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): + super(DeactivateAccountHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def deactivate_account(self, user_id): + """Deactivate a user's account + + Args: + user_id (str): ID of user to be deactivated + + Returns: + Deferred + """ + # FIXME: Theoretically there is a race here wherein user resets + # password using threepid. + + # first delete any devices belonging to the user, which will also + # delete corresponding access tokens. + yield self._device_handler.delete_all_devices_for_user(user_id) + # then delete any remaining access tokens which weren't associated with + # a device. 
+ yield self._auth_handler.delete_access_tokens_for_user(user_id) + + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index dac4b3f4e0..2152efc692 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self._auth_handler = hs.get_auth_handler() self.federation_sender = hs.get_federation_sender() self.federation = hs.get_replication_layer() @@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler): else: raise - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( @@ -171,12 +171,30 @@ class DeviceHandler(BaseHandler): yield self.notify_device_update(user_id, [device_id]) @defer.inlineCallbacks + def delete_all_devices_for_user(self, user_id, except_device_id=None): + """Delete all of the user's devices + + Args: + user_id (str): + except_device_id (str|None): optional device id which should not + be deleted + + Returns: + defer.Deferred: + """ + device_map = yield self.store.get_devices_by_user(user_id) + device_ids = device_map.keys() + if except_device_id is not None: + device_ids = [d for d in device_ids if d != except_device_id] + yield self.delete_devices(user_id, device_ids) + + @defer.inlineCallbacks def delete_devices(self, user_id, device_ids): """ Delete several devices Args: user_id (str): - device_ids (str): The list of device IDs to delete + device_ids (List[str]): The list of device IDs to delete Returns: defer.Deferred: @@ -194,9 +212,8 @@ class DeviceHandler(BaseHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 943554ce98..a0464ae5c0 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -40,6 +40,8 @@ class DirectoryHandler(BaseHandler): "directory", self.on_directory_query ) + self.spam_checker = hs.get_spam_checker() + @defer.inlineCallbacks def _create_association(self, room_alias, room_id, servers=None, creator=None): # general association creation for both human users and app services @@ -73,6 +75,11 @@ class DirectoryHandler(BaseHandler): # association creation for human users # TODO(erikj): Do user auth. 
+ if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + raise SynapseError( + 403, "This user is not permitted to create this alias", + ) + can_create = yield self.can_modify_alias( room_alias, user_id=user_id @@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler): room_id (str) visibility (str): "public" or "private" """ + if not self.spam_checker.user_may_publish_room( + requester.user.to_string(), room_id + ): + raise AuthError( + 403, + "This user is not permitted to publish rooms to the room list" + ) + if requester.is_guest: raise AuthError(403, "Guests cannot edit the published room list") diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 18f87cad67..ac70730885 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -14,7 +14,6 @@ # limitations under the License. """Contains handlers for federation events.""" -import synapse.util.logcontext from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -26,10 +25,7 @@ from synapse.api.errors import ( ) from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator -from synapse.util import unwrapFirstError -from synapse.util.logcontext import ( - preserve_fn, preserve_context_over_deferred -) +from synapse.util import unwrapFirstError, logcontext from synapse.util.metrics import measure_func from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor, Linearizer @@ -77,6 +73,7 @@ class FederationHandler(BaseHandler): self.action_generator = hs.get_action_generator() self.is_mine_id = hs.is_mine_id self.pusher_pool = hs.get_pusherpool() + self.spam_checker = hs.get_spam_checker() self.replication_layer.set_handler(self) @@ -125,6 +122,28 @@ class FederationHandler(BaseHandler): self.room_queues[pdu.room_id].append((pdu, origin)) return + # If we're no longer in the room just ditch the event entirely. This + # is probably an old server that has come back and thinks we're still + # in the room (or we've been rejoined to the room by a state reset). + # + # If we were never in the room then maybe our database got vaped and + # we should check if we *are* in fact in the room. If we are then we + # can magically rejoin the room. 
+ is_in_room = yield self.auth.check_host_in_room( + pdu.room_id, + self.server_name + ) + if not is_in_room: + was_in_room = yield self.store.was_host_joined( + pdu.room_id, self.server_name, + ) + if was_in_room: + logger.info( + "Ignoring PDU %s for room %s from %s as we've left the room!", + pdu.event_id, pdu.room_id, origin, + ) + return + state = None auth_chain = [] @@ -208,7 +227,7 @@ class FederationHandler(BaseHandler): state, auth_chain = yield self.replication_layer.get_state_for_room( origin, pdu.room_id, pdu.event_id, ) - except: + except Exception: logger.exception("Failed to get state for event: %s", pdu.event_id) yield self._process_received_pdu( @@ -442,7 +461,7 @@ class FederationHandler(BaseHandler): def check_match(id): try: return server_name == get_domain_from_id(id) - except: + except Exception: return False # Parses mapping `event_id -> (type, state_key) -> state event_id` @@ -480,7 +499,7 @@ class FederationHandler(BaseHandler): continue try: domain = get_domain_from_id(ev.state_key) - except: + except Exception: continue if domain != server_name: @@ -591,9 +610,9 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch ) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.replication_layer.get_pdu)( + logcontext.preserve_fn(self.replication_layer.get_pdu)( [dest], event_id, outlier=True, @@ -719,7 +738,7 @@ class FederationHandler(BaseHandler): joined_domains[dom] = min(d, old_d) else: joined_domains[dom] = d - except: + except Exception: pass return sorted(joined_domains.items(), key=lambda d: d[1]) @@ -785,10 +804,14 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) logger.debug("calling resolve_state_groups in _maybe_backfill") - states = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) - for e in event_ids - ])) + states = yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.preserve_fn(self.state_handler.resolve_state_groups)( + room_id, [e] + ) + for e in event_ids + ], consumeErrors=True, + )) states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( @@ -917,7 +940,7 @@ class FederationHandler(BaseHandler): room_creator_user_id="", is_public=False ) - except: + except Exception: # FIXME pass @@ -941,9 +964,7 @@ class FederationHandler(BaseHandler): # lots of requests for missing prev_events which we do actually # have. Hence we fire off the deferred, but don't wait for it. 
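Several handlers in this diff now consult hs.get_spam_checker() before acting: alias creation and room publication in the directory handler, invites in the federation handler above, and event sending via check_event_for_spam further down. A hedged sketch of an object exposing those callbacks follows; only the method names and arguments mirror the calls in this diff, and how such a module is registered and configured is not shown here.

    class ExampleSpamChecker(object):
        """Sketch of the callbacks the handlers above invoke."""

        def check_event_for_spam(self, event):
            # Falsy => allowed. Returning a string surfaces it to the client
            # as the error message; any other truthy value gets the generic
            # "Spam is not permitted here" response.
            body = event.content.get("body", "")
            if "unsolicited advert" in body:
                return "Adverts are not permitted here"
            return False

        def user_may_invite(self, inviter_user_id, invitee_user_id, room_id):
            # Return False to reject the invite with a 403.
            return True

        def user_may_create_room_alias(self, user_id, room_alias):
            return True

        def user_may_publish_room(self, user_id, room_id):
            return True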
- synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( - room_queue - ) + logcontext.preserve_fn(self._handle_queued_pdus)(room_queue) defer.returnValue(True) @@ -1070,6 +1091,9 @@ class FederationHandler(BaseHandler): """ event = pdu + if event.state_key is None: + raise SynapseError(400, "The invite event did not have a state key") + is_blocked = yield self.store.is_room_blocked(event.room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -1077,6 +1101,13 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") + if not self.spam_checker.user_may_invite( + event.sender, event.state_key, event.room_id, + ): + raise SynapseError( + 403, "This user is not permitted to send invites to this server/user" + ) + membership = event.content.get("membership") if event.type != EventTypes.Member or membership != Membership.INVITE: raise SynapseError(400, "The event was not an m.room.member invite event") @@ -1085,9 +1116,6 @@ class FederationHandler(BaseHandler): if sender_domain != origin: raise SynapseError(400, "The invite event was not from the server sending it") - if event.state_key is None: - raise SynapseError(400, "The invite event did not have a state key") - if not self.is_mine_id(event.state_key): raise SynapseError(400, "The invite event must be for this server") @@ -1430,7 +1458,7 @@ class FederationHandler(BaseHandler): if not backfilled: # this intentionally does not yield: we don't care about the result # and don't need to wait for it. - preserve_fn(self.pusher_pool.on_new_notifications)( + logcontext.preserve_fn(self.pusher_pool.on_new_notifications)( event_stream_id, max_stream_id ) @@ -1443,16 +1471,16 @@ class FederationHandler(BaseHandler): a bunch of outliers, but not a chunk of individual events that depend on each other for state calculations. """ - contexts = yield preserve_context_over_deferred(defer.gatherResults( + contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self._prep_event)( + logcontext.preserve_fn(self._prep_event)( origin, ev_info["event"], state=ev_info.get("state"), auth_events=ev_info.get("auth_events"), ) for ev_info in event_infos - ] + ], consumeErrors=True, )) yield self.store.persist_events( @@ -1678,6 +1706,17 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function def do_auth(self, origin, event, context, auth_events): + """ + + Args: + origin (str): + event (synapse.events.FrozenEvent): + context (synapse.events.snapshot.EventContext): + auth_events (dict[(str, str)->str]): + + Returns: + defer.Deferred[None] + """ # Check if we have all the auth events. current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) @@ -1747,7 +1786,7 @@ class FederationHandler(BaseHandler): [e_id for e_id, _ in event.auth_events] ) seen_events = set(have_events.keys()) - except: + except Exception: # FIXME: logger.exception("Failed to get auth chain") @@ -1760,18 +1799,17 @@ class FederationHandler(BaseHandler): # Do auth conflict res. 
logger.info("Different auth: %s", different_auth) - different_events = yield preserve_context_over_deferred(defer.gatherResults( - [ - preserve_fn(self.store.get_event)( + different_events = yield logcontext.make_deferred_yieldable( + defer.gatherResults([ + logcontext.preserve_fn(self.store.get_event)( d, allow_none=True, allow_rejected=False, ) for d in different_auth if d in have_events and not have_events[d] - ], - consumeErrors=True - )).addErrback(unwrapFirstError) + ], consumeErrors=True) + ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) @@ -1790,16 +1828,9 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - if k != event_key - }) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - }) - context.state_group = self.store.get_next_state_group() + self._update_context_for_auth_events( + context, auth_events, event_key, + ) if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1872,23 +1903,16 @@ class FederationHandler(BaseHandler): except AuthError: pass - except: + except Exception: # FIXME: logger.exception("Failed to query auth chain") # 4. Look at rejects and their proofs. # TODO. - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - if k != event_key - }) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - }) - context.state_group = self.store.get_next_state_group() + self._update_context_for_auth_events( + context, auth_events, event_key, + ) try: self.auth.check(event, auth_events=auth_events) @@ -1896,6 +1920,35 @@ class FederationHandler(BaseHandler): logger.warn("Failed auth resolution for %r because %s", event, e) raise e + def _update_context_for_auth_events(self, context, auth_events, + event_key): + """Update the state_ids in an event context after auth event resolution + + Args: + context (synapse.events.snapshot.EventContext): event context + to be updated + + auth_events (dict[(str, str)->str]): Events to update in the event + context. + + event_key ((str, str)): (type, state_key) for the current event. + this will not be included in the current_state in the context. + """ + state_updates = { + k: a.event_id for k, a in auth_events.iteritems() + if k != event_key + } + context.current_state_ids = dict(context.current_state_ids) + context.current_state_ids.update(state_updates) + if context.delta_ids is not None: + context.delta_ids = dict(context.delta_ids) + context.delta_ids.update(state_updates) + context.prev_state_ids = dict(context.prev_state_ids) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.iteritems() + }) + context.state_group = self.store.get_next_state_group() + @defer.inlineCallbacks def construct_auth_difference(self, local_auth, remote_auth): """ Given a local and remote auth chain, find the differences. 
This @@ -1939,7 +1992,7 @@ class FederationHandler(BaseHandler): def get_next(it, opt=None): try: return it.next() - except: + except Exception: return opt current_local = get_next(local_iter) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py new file mode 100644 index 0000000000..7e5d3f148d --- /dev/null +++ b/synapse/handlers/groups_local.py @@ -0,0 +1,429 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id + +import logging + +logger = logging.getLogger(__name__) + + +def _create_rerouter(func_name): + """Returns a function that looks at the group id and calls the function + on federation or the local group server if the group is local + """ + def f(self, group_id, *args, **kwargs): + if self.is_mine_id(group_id): + return getattr(self.groups_server_handler, func_name)( + group_id, *args, **kwargs + ) + else: + destination = get_domain_from_id(group_id) + return getattr(self.transport_client, func_name)( + destination, group_id, *args, **kwargs + ) + return f + + +class GroupsLocalHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.groups_server_handler = hs.get_groups_server_handler() + self.transport_client = hs.get_federation_transport_client() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + self.notifier = hs.get_notifier() + self.attestations = hs.get_groups_attestation_signing() + + self.profile_handler = hs.get_profile_handler() + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + # The following functions merely route the query to the local groups server + # or federation depending on if the group is local or remote + + get_group_profile = _create_rerouter("get_group_profile") + update_group_profile = _create_rerouter("update_group_profile") + get_rooms_in_group = _create_rerouter("get_rooms_in_group") + + get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") + + add_room_to_group = _create_rerouter("add_room_to_group") + update_room_in_group = _create_rerouter("update_room_in_group") + remove_room_from_group = _create_rerouter("remove_room_from_group") + + update_group_summary_room = _create_rerouter("update_group_summary_room") + delete_group_summary_room = _create_rerouter("delete_group_summary_room") + + update_group_category = _create_rerouter("update_group_category") + delete_group_category = _create_rerouter("delete_group_category") + get_group_category = _create_rerouter("get_group_category") + get_group_categories = _create_rerouter("get_group_categories") + + update_group_summary_user = _create_rerouter("update_group_summary_user") + 
delete_group_summary_user = _create_rerouter("delete_group_summary_user") + + update_group_role = _create_rerouter("update_group_role") + delete_group_role = _create_rerouter("delete_group_role") + get_group_role = _create_rerouter("get_group_role") + get_group_roles = _create_rerouter("get_group_roles") + + @defer.inlineCallbacks + def get_group_summary(self, group_id, requester_user_id): + """Get the group summary for a group. + + If the group is remote we check that the users have valid attestations. + """ + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.get_group_summary( + group_id, requester_user_id + ) + else: + res = yield self.transport_client.get_group_summary( + get_domain_from_id(group_id), group_id, requester_user_id, + ) + + group_server_name = get_domain_from_id(group_id) + + # Loop through the users and validate the attestations. + chunk = res["users_section"]["users"] + valid_users = [] + for entry in chunk: + g_user_id = entry["user_id"] + attestation = entry.pop("attestation", {}) + try: + if get_domain_from_id(g_user_id) != group_server_name: + yield self.attestations.verify_attestation( + attestation, + group_id=group_id, + user_id=g_user_id, + server_name=get_domain_from_id(g_user_id), + ) + valid_users.append(entry) + except Exception as e: + logger.info("Failed to verify user is in group: %s", e) + + res["users_section"]["users"] = valid_users + + res["users_section"]["users"].sort(key=lambda e: e.get("order", 0)) + res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0)) + + # Add `is_publicised` flag to indicate whether the user has publicised their + # membership of the group on their profile + result = yield self.store.get_publicised_groups_for_user(requester_user_id) + is_publicised = group_id in result + + res.setdefault("user", {})["is_publicised"] = is_publicised + + defer.returnValue(res) + + @defer.inlineCallbacks + def create_group(self, group_id, user_id, content): + """Create a group + """ + + logger.info("Asking to create group with ID: %r", group_id) + + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.create_group( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + content["user_profile"] = yield self.profile_handler.get_profile(user_id) + + res = yield self.transport_client.create_group( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + is_publicised = content.get("publicise", False) + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=True, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def get_users_in_group(self, group_id, requester_user_id): + """Get users in a group + """ + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.get_users_in_group( + group_id, requester_user_id + ) + defer.returnValue(res) + + group_server_name = get_domain_from_id(group_id) + + res = yield self.transport_client.get_users_in_group( + 
get_domain_from_id(group_id), group_id, requester_user_id, + ) + + chunk = res["chunk"] + valid_entries = [] + for entry in chunk: + g_user_id = entry["user_id"] + attestation = entry.pop("attestation", {}) + try: + if get_domain_from_id(g_user_id) != group_server_name: + yield self.attestations.verify_attestation( + attestation, + group_id=group_id, + user_id=g_user_id, + server_name=get_domain_from_id(g_user_id), + ) + valid_entries.append(entry) + except Exception as e: + logger.info("Failed to verify user is in group: %s", e) + + res["chunk"] = valid_entries + + defer.returnValue(res) + + @defer.inlineCallbacks + def join_group(self, group_id, user_id, content): + """Request to join a group + """ + raise NotImplementedError() # TODO + + @defer.inlineCallbacks + def accept_invite(self, group_id, user_id, content): + """Accept an invite to a group + """ + if self.is_mine_id(group_id): + yield self.groups_server_handler.accept_invite( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + res = yield self.transport_client.accept_group_invite( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + # TODO: Check that the group is public and we're being added publically + is_publicised = content.get("publicise", False) + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=False, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def invite(self, group_id, user_id, requester_user_id, config): + """Invite a user to a group + """ + content = { + "requester_user_id": requester_user_id, + "config": config, + } + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.invite_to_group( + group_id, user_id, requester_user_id, content, + ) + else: + res = yield self.transport_client.invite_to_group( + get_domain_from_id(group_id), group_id, user_id, requester_user_id, + content, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def on_invite(self, group_id, user_id, content): + """One of our users were invited to a group + """ + # TODO: Support auto join and rejection + + if not self.is_mine_id(user_id): + raise SynapseError(400, "User not on this server") + + local_profile = {} + if "profile" in content: + if "name" in content["profile"]: + local_profile["name"] = content["profile"]["name"] + if "avatar_url" in content["profile"]: + local_profile["avatar_url"] = content["profile"]["avatar_url"] + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="invite", + content={"profile": local_profile, "inviter": content["inviter"]}, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + try: + user_profile = yield self.profile_handler.get_profile(user_id) + except Exception as e: + logger.warn("No profile for user %s: %s", user_id, e) + user_profile = {} + + defer.returnValue({"state": "invite", "user_profile": user_profile}) + + @defer.inlineCallbacks + def remove_user_from_group(self, 
group_id, user_id, requester_user_id, content): + """Remove a user from a group + """ + if user_id == requester_user_id: + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="leave", + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + # TODO: Should probably remember that we tried to leave so that we can + # retry if the group server is currently down. + + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + else: + content["requester_user_id"] = requester_user_id + res = yield self.transport_client.remove_user_from_group( + get_domain_from_id(group_id), group_id, requester_user_id, + user_id, content, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def user_removed_from_group(self, group_id, user_id, content): + """One of our users was removed/kicked from a group + """ + # TODO: Check if user in group + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="leave", + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + @defer.inlineCallbacks + def get_joined_groups(self, user_id): + group_ids = yield self.store.get_joined_groups(user_id) + defer.returnValue({"groups": group_ids}) + + @defer.inlineCallbacks + def get_publicised_groups_for_user(self, user_id): + if self.hs.is_mine_id(user_id): + result = yield self.store.get_publicised_groups_for_user(user_id) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + result.extend(app_service.get_groups_for_user(user_id)) + + defer.returnValue({"groups": result}) + else: + result = yield self.transport_client.get_publicised_groups_for_user( + get_domain_from_id(user_id), user_id + ) + # TODO: Verify attestations + defer.returnValue(result) + + @defer.inlineCallbacks + def bulk_get_publicised_groups(self, user_ids, proxy=True): + destinations = {} + local_users = set() + + for user_id in user_ids: + if self.hs.is_mine_id(user_id): + local_users.add(user_id) + else: + destinations.setdefault( + get_domain_from_id(user_id), set() + ).add(user_id) + + if not proxy and destinations: + raise SynapseError(400, "Some user_ids are not local") + + results = {} + failed_results = [] + for destination, dest_user_ids in destinations.iteritems(): + try: + r = yield self.transport_client.bulk_get_publicised_groups( + destination, list(dest_user_ids), + ) + results.update(r["users"]) + except Exception: + failed_results.extend(dest_user_ids) + + for uid in local_users: + results[uid] = yield self.store.get_publicised_groups_for_user( + uid + ) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + results[uid].extend(app_service.get_groups_for_user(uid)) + + defer.returnValue({"users": results}) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 10f5f35a69..c5267b4b84 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -27,7 +27,7 @@ from synapse.types import ( from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from 
synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -163,7 +163,7 @@ class InitialSyncHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield preserve_context_over_deferred( + (messages, token), current_state = yield make_deferred_yieldable( defer.gatherResults( [ preserve_fn(self.store.get_recent_events_for_room)( @@ -214,7 +214,7 @@ class InitialSyncHandler(BaseHandler): }) d["account_data"] = account_data_events - except: + except Exception: logger.exception("Failed to get snapshot") yield concurrently_execute(handle_room, room_list, 10) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f6740544c1..21f1717dd2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +26,7 @@ from synapse.types import ( from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.logcontext import preserve_fn from synapse.util.metrics import measure_func +from synapse.util.frozenutils import unfreeze from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -46,6 +48,7 @@ class MessageHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() + self.profile_handler = hs.get_profile_handler() self.pagination_lock = ReadWriteLock() @@ -211,7 +214,7 @@ class MessageHandler(BaseHandler): if membership in {Membership.JOIN, Membership.INVITE}: # If event doesn't include a display name, add one. - profile = self.hs.get_handlers().profile_handler + profile = self.profile_handler content = builder.content try: @@ -323,9 +326,12 @@ class MessageHandler(BaseHandler): txn_id=txn_id ) - if self.spam_checker.check_event_for_spam(event): + spam_error = self.spam_checker.check_event_for_spam(event) + if spam_error: + if not isinstance(spam_error, basestring): + spam_error = "Spam is not permitted here" raise SynapseError( - 403, "Spam is not permitted here", Codes.FORBIDDEN + 403, spam_error, Codes.FORBIDDEN ) yield self.send_nonmember_event( @@ -420,27 +426,41 @@ class MessageHandler(BaseHandler): ) @defer.inlineCallbacks - def get_joined_members(self, user_id, room_id): + def get_joined_members(self, requester, room_id): """Get all the joined members in the room and their profile information. If the user has left the room return the state events from when they left. Args: - user_id(str): The user requesting state events. + requester(Requester): The user requesting state events. room_id(str): The room ID to get all state events from. Returns: A dict of user_id to profile info """ - membership, membership_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) - - if membership == Membership.JOIN: - users_with_profile = yield self.state.get_current_user_in_room(room_id) - else: - raise NotImplementedError( - "Getting joined members after leaving is not implemented" + user_id = requester.user.to_string() + if not requester.app_service: + # We check AS auth after fetching the room membership, as it + # requires us to pull out all joined members anyway. 
+ membership, _ = yield self._check_in_room_or_world_readable( + room_id, user_id ) + if membership != Membership.JOIN: + raise NotImplementedError( + "Getting joined members after leaving is not implemented" + ) + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + + # If this is an AS, double check that they are allowed to see the members. + # This can either be because the AS user is in the room or becuase there + # is a user in the room that the AS is "interested in" + if requester.app_service and user_id not in users_with_profile: + for uid in users_with_profile: + if requester.app_service.is_interested_in_user(uid): + break + else: + # Loop fell through, AS has no interested users in room + raise AuthError(403, "Appservice not in room") defer.returnValue({ user_id: { @@ -541,9 +561,9 @@ class MessageHandler(BaseHandler): # Ensure that we can round trip before trying to persist in db try: - dump = ujson.dumps(event.content) + dump = ujson.dumps(unfreeze(event.content)) ujson.loads(dump) - except: + except Exception: logger.exception("Failed to encode content: %r", event.content) raise diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c7c0b0a1e2..cb158ba962 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -364,7 +364,7 @@ class PresenceHandler(object): ) preserve_fn(self._update_states)(changes) - except: + except Exception: logger.exception("Exception in _handle_timeouts loop") @defer.inlineCallbacks @@ -1199,7 +1199,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): ) changed = True else: - # We expect to be poked occaisonally by the other side. + # We expect to be poked occasionally by the other side. # This is to protect against forgetful/buggy servers, so that # no one gets stuck online forever. 
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 7abee98dea..9800e24453 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,16 +17,16 @@ import logging from twisted.internet import defer -import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id from ._base import BaseHandler - logger = logging.getLogger(__name__) class ProfileHandler(BaseHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): super(ProfileHandler, self).__init__(hs) @@ -36,6 +36,65 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) + self.user_directory_handler = hs.get_user_directory_handler() + + self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS) + + @defer.inlineCallbacks + def get_profile(self, user_id): + target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue({ + "displayname": displayname, + "avatar_url": avatar_url, + }) + else: + try: + result = yield self.federation.make_query( + destination=target_user.domain, + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + defer.returnValue(result) + except CodeMessageException as e: + if e.code != 404: + logger.exception("Failed to get displayname") + + raise + + @defer.inlineCallbacks + def get_profile_from_cache(self, user_id): + """Get the profile information from our local cache. If the user is + ours then the profile information will always be corect. Otherwise, + it may be out of date/missing. 
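The remote profile cache above is kept fresh by clock.looping_call, which invokes _update_remote_profile_cache every PROFILE_UPDATE_MS and refreshes any cache entry not checked for PROFILE_UPDATE_EVERY_MS. Roughly, the scheduling pattern looks like this (simplified sketch; the real refresh body appears later in this diff):

from twisted.internet import defer


class RemoteProfileRefresher(object):
    # Sketch of the periodic wiring: check once a minute, refresh entries
    # that have not been checked for 24 hours.
    PROFILE_UPDATE_MS = 60 * 1000
    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000

    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()
        # looping_call takes the interval in milliseconds
        self.clock.looping_call(
            self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
        )

    @defer.inlineCallbacks
    def _update_remote_profile_cache(self):
        last_checked = self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
        entries = yield self.store.get_remote_profile_cache_entries_that_expire(
            last_checked=last_checked,
        )
        # ... re-query each entry over federation, as in the handler above.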
+ """ + target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue({ + "displayname": displayname, + "avatar_url": avatar_url, + }) + else: + profile = yield self.store.get_from_remote_profile_cache(user_id) + defer.returnValue(profile or {}) + @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): @@ -60,7 +119,7 @@ class ProfileHandler(BaseHandler): logger.exception("Failed to get displayname") raise - except: + except Exception: logger.exception("Failed to get displayname") else: defer.returnValue(result["displayname"]) @@ -82,7 +141,13 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield self._update_join_states(requester) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -107,7 +172,7 @@ class ProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get avatar_url") raise - except: + except Exception: logger.exception("Failed to get avatar_url") defer.returnValue(result["avatar_url"]) @@ -126,7 +191,13 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - yield self._update_join_states(requester) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): @@ -151,28 +222,24 @@ class ProfileHandler(BaseHandler): defer.returnValue(response) @defer.inlineCallbacks - def _update_join_states(self, requester): - user = requester.user - if not self.hs.is_mine(user): + def _update_join_states(self, requester, target_user): + if not self.hs.is_mine(target_user): return yield self.ratelimit(requester) room_ids = yield self.store.get_rooms_for_user( - user.to_string(), + target_user.to_string(), ) for room_id in room_ids: handler = self.hs.get_handlers().room_member_handler try: - # Assume the user isn't a guest because we don't let guests set - # profile or avatar data. - # XXX why are we recreating `requester` here for each room? - # what was wrong with the `requester` we were passed? - requester = synapse.types.create_requester(user) + # Assume the target_user isn't a guest, + # because we don't let guests set profile or avatar data. yield handler.update_membership( requester, - user, + target_user, room_id, "join", # We treat a profile update like a join. ratelimit=False, # Try to hide that these events aren't atomic. @@ -182,3 +249,44 @@ class ProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e.message) ) + + def _update_remote_profile_cache(self): + """Called periodically to check profiles of remote users we haven't + checked in a while. 
+ """ + entries = yield self.store.get_remote_profile_cache_entries_that_expire( + last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS + ) + + for user_id, displayname, avatar_url in entries: + is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + user_id, + ) + if not is_subscribed: + yield self.store.maybe_delete_remote_profile_cache(user_id) + continue + + try: + profile = yield self.federation.make_query( + destination=get_domain_from_id(user_id), + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + except Exception: + logger.exception("Failed to get avatar_url") + + yield self.store.update_remote_profile_cache( + user_id, displayname, avatar_url + ) + continue + + new_name = profile.get("displayname") + new_avatar = profile.get("avatar_url") + + # We always hit update to update the last_check timestamp + yield self.store.update_remote_profile_cache( + user_id, new_name, new_avatar + ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e1cd3a48e9..0525765272 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util import logcontext from ._base import BaseHandler @@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler): is_new = yield self._handle_new_receipts([receipt]) if is_new: + # fire off a process in the background to send the receipt to + # remote servers self._push_remotes([receipt]) @defer.inlineCallbacks @@ -126,6 +129,7 @@ class ReceiptsHandler(BaseHandler): defer.returnValue(True) + @logcontext.preserve_fn # caller should not yield on this @defer.inlineCallbacks def _push_remotes(self, receipts): """Given a list of receipts, works out which remote servers should be diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ee3a2269a8..4bc6ef51fe 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,7 +15,6 @@ """Contains functions for registering clients.""" import logging -import urllib from twisted.internet import defer @@ -23,6 +22,7 @@ from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) from synapse.http.client import CaptchaServerHttpClient +from synapse import types from synapse.types import UserID from synapse.util.async import run_on_reactor from ._base import BaseHandler @@ -36,6 +36,9 @@ class RegistrationHandler(BaseHandler): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self.profile_handler = hs.get_profile_handler() + self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -45,12 +48,10 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): - yield run_on_reactor() - - if urllib.quote(localpart.encode('utf-8')) != localpart: + if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, - "User ID can only contain characters a-z, 0-9, or '_-./'", + "User ID can only contain characters a-z, 0-9, or '=_-./'", Codes.INVALID_USERNAME ) @@ -80,7 +81,7 @@ class RegistrationHandler(BaseHandler): "A different user ID has already been registered for this session", ) - 
yield self.check_user_id_not_appservice_exclusive(user_id) + self.check_user_id_not_appservice_exclusive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: @@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler): ), admin=admin, ) + + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(localpart) + yield self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + else: # autogen a sequential user ID attempts = 0 @@ -253,11 +261,10 @@ class RegistrationHandler(BaseHandler): """ Registers email_id as SAML2 Based Auth. """ - if urllib.quote(localpart) != localpart: + if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, - "User ID must only contain characters which do not" - " require URL encoding." + "User ID can only contain characters a-z, 0-9, or '=_-./'", ) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -291,7 +298,7 @@ class RegistrationHandler(BaseHandler): try: identity_handler = self.hs.get_handlers().identity_handler threepid = yield identity_handler.threepid_from_creds(c) - except: + except Exception: logger.exception("Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid") @@ -418,13 +425,12 @@ class RegistrationHandler(BaseHandler): create_profile_with_localpart=user.localpart, ) else: - yield self.store.user_delete_access_tokens(user_id=user_id) + yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) - profile_handler = self.hs.get_handlers().profile_handler - yield profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, displayname, by_admin=True, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5698d28088..d1cc87a016 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,6 +60,11 @@ class RoomCreationHandler(BaseHandler): }, } + def __init__(self, hs): + super(RoomCreationHandler, self).__init__(hs) + + self.spam_checker = hs.get_spam_checker() + @defer.inlineCallbacks def create_room(self, requester, config, ratelimit=True): """ Creates a new room. 
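Both registration paths above now share types.contains_invalid_mxid_characters instead of the old urllib.quote round-trip, and the error message now lists the allowed set including '='. The real helper lives in synapse.types; a rough equivalent of the check it performs (allowed characters a-z, 0-9 and '=_-./') would be:

import re

# Approximation of the allowed localpart characters: a-z, 0-9, '=', '_',
# '-', '.' and '/'. The authoritative set is defined in synapse.types.
_LOCALPART_RE = re.compile(r"^[a-z0-9=_\-./]+$")


def contains_invalid_mxid_characters(localpart):
    """Return True if the localpart should be rejected with
    M_INVALID_USERNAME."""
    return not _LOCALPART_RE.match(localpart)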
@@ -75,6 +80,9 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() + if not self.spam_checker.user_may_create_room(user_id): + raise SynapseError(403, "You are not permitted to create rooms") + if ratelimit: yield self.ratelimit(requester) @@ -83,7 +91,7 @@ class RoomCreationHandler(BaseHandler): if wchar in config["room_alias_name"]: raise SynapseError(400, "Invalid characters in room alias") - room_alias = RoomAlias.create( + room_alias = RoomAlias( config["room_alias_name"], self.hs.hostname, ) @@ -100,7 +108,7 @@ class RoomCreationHandler(BaseHandler): for i in invite_list: try: UserID.from_string(i) - except: + except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) invite_3pid_list = config.get("invite_3pid", []) @@ -115,7 +123,7 @@ class RoomCreationHandler(BaseHandler): while attempts < 5: try: random_string = stringutils.random_string(18) - gen_room_id = RoomID.create( + gen_room_id = RoomID( random_string, self.hs.hostname, ) @@ -197,12 +205,12 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) - content = {} - is_direct = config.get("is_direct", None) - if is_direct: - content["is_direct"] = is_direct - for invitee in invite_list: + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + yield room_member_handler.update_membership( requester, UserID.from_string(invitee), diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 516cd9a6ac..bb40075387 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -20,6 +20,7 @@ from ._base import BaseHandler from synapse.api.constants import ( EventTypes, JoinRules, ) +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.async import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache @@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler): if search_filter: # We explicitly don't bother caching searches or requests for # appservice specific lists. + logger.info("Bypassing cache as search request.") return self._get_public_room_list( limit, since_token, search_filter, network_tuple=network_tuple, ) @@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler): key = (limit, since_token, network_tuple) result = self.response_cache.get(key) if not result: + logger.info("No cached result, calculating one.") result = self.response_cache.set( key, - self._get_public_room_list( + preserve_fn(self._get_public_room_list)( limit, since_token, network_tuple=network_tuple ) ) - return result + else: + logger.info("Using cached deferred result.") + return make_deferred_yieldable(result) @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, @@ -149,6 +154,8 @@ class RoomListHandler(BaseHandler): # We want larger rooms to be first, hence negating num_joined_users rooms_to_order_value[room_id] = (-num_joined_users, room_id) + logger.info("Getting ordering for %i rooms since %s", + len(room_ids), stream_token) yield concurrently_execute(get_order_for_room, room_ids, 10) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) @@ -176,34 +183,42 @@ class RoomListHandler(BaseHandler): rooms_to_scan = rooms_to_scan[:since_token.current_limit] rooms_to_scan.reverse() - # Actually generate the entries. 
_append_room_entry_to_chunk will append to - # chunk but will stop if len(chunk) > limit - chunk = [] - if limit and not search_filter: + logger.info("After sorting and filtering, %i rooms remain", + len(rooms_to_scan)) + + # _append_room_entry_to_chunk will append to chunk but will stop if + # len(chunk) > limit + # + # Normally we will generate enough results on the first iteration here, + # but if there is a search filter, _append_room_entry_to_chunk may + # filter some results out, in which case we loop again. + # + # We don't want to scan over the entire range either as that + # would potentially waste a lot of work. + # + # XXX if there is no limit, we may end up DoSing the server with + # calls to get_current_state_ids for every single room on the + # server. Surely we should cap this somehow? + # + if limit: step = limit + 1 - for i in xrange(0, len(rooms_to_scan), step): - # We iterate here because the vast majority of cases we'll stop - # at first iteration, but occaisonally _append_room_entry_to_chunk - # won't append to the chunk and so we need to loop again. - # We don't want to scan over the entire range either as that - # would potentially waste a lot of work. - yield concurrently_execute( - lambda r: self._append_room_entry_to_chunk( - r, rooms_to_num_joined[r], - chunk, limit, search_filter - ), - rooms_to_scan[i:i + step], 10 - ) - if len(chunk) >= limit + 1: - break else: + step = len(rooms_to_scan) + + chunk = [] + for i in xrange(0, len(rooms_to_scan), step): + batch = rooms_to_scan[i:i + step] + logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( lambda r: self._append_room_entry_to_chunk( r, rooms_to_num_joined[r], chunk, limit, search_filter ), - rooms_to_scan, 5 + batch, 5, ) + logger.info("Now %i rooms in result", len(chunk)) + if len(chunk) >= limit + 1: + break chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) @@ -276,13 +291,14 @@ class RoomListHandler(BaseHandler): # We've already got enough, so lets just drop it. 
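The loop above scans rooms in batches of limit + 1, stopping as soon as the chunk is over the limit, with concurrently_execute bounding how many entries are generated at once. The general shape of that batched-scan pattern, as a standalone sketch (names here are illustrative):

from twisted.internet import defer

from synapse.util.async import concurrently_execute


@defer.inlineCallbacks
def scan_in_batches(room_ids, handle_room, limit):
    # Scan at most `limit + 1` rooms per pass, 5 at a time. handle_room is
    # expected to append accepted entries to `results`; some rooms may be
    # filtered out, which is why a second pass can be needed.
    results = []
    step = limit + 1 if limit else max(len(room_ids), 1)

    for i in xrange(0, len(room_ids), step):
        batch = room_ids[i:i + step]
        yield concurrently_execute(
            lambda r: handle_room(r, results), batch, 5,
        )
        if limit and len(results) >= limit + 1:
            break

    defer.returnValue(results)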
return - result = yield self._generate_room_entry(room_id, num_joined_users) + result = yield self.generate_room_entry(room_id, num_joined_users) if result and _matches_room_entry(result, search_filter): chunk.append(result) @cachedInlineCallbacks(num_args=1, cache_context=True) - def _generate_room_entry(self, room_id, num_joined_users, cache_context): + def generate_room_entry(self, room_id, num_joined_users, cache_context, + with_alias=True, allow_private=False): """Returns the entry for a room """ result = { @@ -316,14 +332,15 @@ class RoomListHandler(BaseHandler): join_rules_event = current_state.get((EventTypes.JoinRules, "")) if join_rules_event: join_rule = join_rules_event.content.get("join_rule", None) - if join_rule and join_rule != JoinRules.PUBLIC: + if not allow_private and join_rule and join_rule != JoinRules.PUBLIC: defer.returnValue(None) - aliases = yield self.store.get_aliases_for_room( - room_id, on_invalidate=cache_context.invalidate - ) - if aliases: - result["aliases"] = aliases + if with_alias: + aliases = yield self.store.get_aliases_for_room( + room_id, on_invalidate=cache_context.invalidate + ) + if aliases: + result["aliases"] = aliases name_event = yield current_state.get((EventTypes.Name, "")) if name_event: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 9a498c2d3e..7e6467cd1d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -45,9 +45,12 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) + self.profile_handler = hs.get_profile_handler() + self.member_linearizer = Linearizer(name="member") self.clock = hs.get_clock() + self.spam_checker = hs.get_spam_checker() self.distributor = hs.get_distributor() self.distributor.declare("user_joined_room") @@ -186,6 +189,10 @@ class RoomMemberHandler(BaseHandler): content_specified = bool(content) if content is None: content = {} + else: + # We do a copy here as we potentially change some keys + # later on. 
+ content = dict(content) effective_membership_state = action if action in ["kick", "unban"]: @@ -210,12 +217,26 @@ class RoomMemberHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - if (effective_membership_state == "invite" and - self.hs.config.block_non_admin_invites): + if effective_membership_state == "invite": + block_invite = False is_requester_admin = yield self.auth.is_server_admin( requester.user, ) if not is_requester_admin: + if self.hs.config.block_non_admin_invites: + logger.info( + "Blocking invite: user is not admin and non-admin " + "invites disabled" + ) + block_invite = True + + if not self.spam_checker.user_may_invite( + requester.user.to_string(), target.to_string(), room_id, + ): + logger.info("Blocking invite due to spam checker") + block_invite = True + + if block_invite: raise SynapseError( 403, "Invites have been disabled on this server", ) @@ -267,7 +288,7 @@ class RoomMemberHandler(BaseHandler): content["membership"] = Membership.JOIN - profile = self.hs.get_handlers().profile_handler + profile = self.profile_handler if not content_specified: content["displayname"] = yield profile.get_displayname(target) content["avatar_url"] = yield profile.get_avatar_url(target) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index df75d70fac..9772ed1a0e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -61,7 +61,7 @@ class SearchHandler(BaseHandler): assert batch_group is not None assert batch_group_key is not None assert batch_token is not None - except: + except Exception: raise SynapseError(400, "Invalid batch") try: diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py new file mode 100644 index 0000000000..44414e1dc1 --- /dev/null +++ b/synapse/handlers/set_password.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import Codes, StoreError, SynapseError +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + + +class SetPasswordHandler(BaseHandler): + """Handler which deals with changing user account passwords""" + def __init__(self, hs): + super(SetPasswordHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def set_password(self, user_id, newpassword, requester=None): + password_hash = self._auth_handler.hash(newpassword) + + except_device_id = requester.device_id if requester else None + except_access_token_id = requester.access_token_id if requester else None + + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e + + # we want to log out all of the user's other sessions. First delete + # all his other devices. 
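SetPasswordHandler.set_password takes the optional requester so that the session which initiated the password change survives the cleanup below: its device and access token are excluded when everything else is deleted. A hypothetical caller, assuming the homeserver exposes the handler via a get_set_password_handler() accessor (that accessor name is an assumption, not shown in this diff):

from twisted.internet import defer


@defer.inlineCallbacks
def change_password(hs, requester, new_password):
    # Assumed accessor; the handler class itself is defined above.
    set_password_handler = hs.get_set_password_handler()

    # Passing `requester` keeps the current device and access token alive;
    # passing None (e.g. for an admin reset) logs the user out everywhere.
    yield set_password_handler.set_password(
        requester.user.to_string(), new_password, requester,
    )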
+ yield self._device_handler.delete_all_devices_for_user( + user_id, except_device_id=except_device_id, + ) + + # and now delete any access tokens which weren't associated with + # devices (or were associated with this device). + yield self._auth_handler.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id, + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dd0ec00ae6..b12988f3c9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user @@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ return True +class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ + "join", + "invite", + "leave", +])): + __slots__ = [] + + def __nonzero__(self): + return bool(self.join or self.invite or self.leave) + + class DeviceLists(collections.namedtuple("DeviceLists", [ "changed", # list of user_ids whose devices may have changed "left", # list of user_ids whose devices we no longer track @@ -129,6 +140,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "device_lists", # List of user_ids whose devices have chanegd "device_one_time_keys_count", # Dict of algorithm to count for one time keys # for this device + "groups", ])): __slots__ = [] @@ -144,7 +156,8 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.archived or self.account_data or self.to_device or - self.device_lists + self.device_lists or + self.groups ) @@ -171,11 +184,11 @@ class SyncHandler(object): if not result: result = self.response_cache.set( sync_config.request_key, - self._wait_for_sync_for_user( + preserve_fn(self._wait_for_sync_for_user)( sync_config, since_token, timeout, full_state ) ) - return result + return make_deferred_yieldable(result) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, @@ -595,6 +608,8 @@ class SyncHandler(object): user_id, device_id ) + yield self._generate_sync_entry_for_groups(sync_result_builder) + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -603,10 +618,57 @@ class SyncHandler(object): archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, + groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, next_batch=sync_result_builder.now_token, )) + @measure_func("_generate_sync_entry_for_groups") + @defer.inlineCallbacks + def _generate_sync_entry_for_groups(self, sync_result_builder): + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + + if since_token and since_token.groups_key: + results = yield self.store.get_groups_changes_for_user( + user_id, since_token.groups_key, now_token.groups_key, + ) + else: + results = yield self.store.get_all_groups_for_user( + user_id, now_token.groups_key, + ) + + invited = {} + joined = {} + left = {} + for result in results: + membership = result["membership"] + group_id = result["group_id"] + gtype = 
result["type"] + content = result["content"] + + if membership == "join": + if gtype == "membership": + # TODO: Add profile + content.pop("membership", None) + joined[group_id] = content["content"] + else: + joined.setdefault(group_id, {})[gtype] = content + elif membership == "invite": + if gtype == "membership": + content.pop("membership", None) + invited[group_id] = content["content"] + else: + if gtype == "membership": + left[group_id] = content["content"] + + sync_result_builder.groups = GroupsSyncResult( + join=joined, + invite=invited, + leave=left, + ) + @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks def _generate_sync_entry_for_device_list(self, sync_result_builder, @@ -1368,6 +1430,7 @@ class SyncResultBuilder(object): self.invited = [] self.archived = [] self.device = [] + self.groups = None self.to_device = [] diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 2a49456bfc..714f0195c8 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.storage.roommember import ProfileInfo from synapse.util.metrics import Measure from synapse.util.async import sleep +from synapse.types import get_localpart_from_id logger = logging.getLogger(__name__) -class UserDirectoyHandler(object): +class UserDirectoryHandler(object): """Handles querying of and keeping updated the user_directory. N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY @@ -41,9 +42,10 @@ class UserDirectoyHandler(object): one public room. """ - INITIAL_SLEEP_MS = 50 - INITIAL_SLEEP_COUNT = 100 - INITIAL_BATCH_SIZE = 100 + INITIAL_ROOM_SLEEP_MS = 50 + INITIAL_ROOM_SLEEP_COUNT = 100 + INITIAL_ROOM_BATCH_SIZE = 100 + INITIAL_USER_SLEEP_MS = 10 def __init__(self, hs): self.store = hs.get_datastore() @@ -53,6 +55,7 @@ class UserDirectoyHandler(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory + self.search_all_users = hs.config.user_directory_search_all_users # When start up for the first time we need to populate the user_directory. # This is a set of user_id's we've inserted already @@ -111,6 +114,15 @@ class UserDirectoyHandler(object): self._is_processing = False @defer.inlineCallbacks + def handle_local_profile_change(self, user_id, profile): + """Called to update index of our local user profiles when they change + irrespective of any rooms the user may be in. + """ + yield self.store.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url, None, + ) + + @defer.inlineCallbacks def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: @@ -148,16 +160,30 @@ class UserDirectoyHandler(object): room_ids = yield self.store.get_all_rooms() logger.info("Doing initial update of user directory. %d rooms", len(room_ids)) - num_processed_rooms = 1 + num_processed_rooms = 0 for room_id in room_ids: - logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids)) - yield self._handle_intial_room(room_id) + logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids)) + yield self._handle_initial_room(room_id) num_processed_rooms += 1 - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) 
logger.info("Processed all rooms.") + if self.search_all_users: + num_processed_users = 0 + user_ids = yield self.store.get_all_local_users() + logger.info("Doing initial update of user directory. %d users", len(user_ids)) + for user_id in user_ids: + # We add profiles for all users even if they don't match the + # include pattern, just in case we want to change it in future + logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids)) + yield self._handle_local_user(user_id) + num_processed_users += 1 + yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.) + + logger.info("Processed all users") + self.initially_handled_users = None self.initially_handled_users_in_public = None self.initially_handled_users_share = None @@ -166,7 +192,7 @@ class UserDirectoyHandler(object): yield self.store.update_user_directory_stream_pos(new_pos) @defer.inlineCallbacks - def _handle_intial_room(self, room_id): + def _handle_initial_room(self, room_id): """Called when we initially fill out user_directory one room at a time """ is_in_room = yield self.store.is_host_joined(room_id, self.server_name) @@ -201,8 +227,8 @@ class UserDirectoyHandler(object): to_update = set() count = 0 for user_id in user_ids: - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) if not self.is_mine_id(user_id): count += 1 @@ -216,8 +242,8 @@ class UserDirectoyHandler(object): if user_id == other_user_id: continue - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) count += 1 user_set = (user_id, other_user_id) @@ -237,13 +263,13 @@ class UserDirectoyHandler(object): else: self.initially_handled_users_share_private_room.add(user_set) - if len(to_insert) > self.INITIAL_BATCH_SIZE: + if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.add_users_who_share_room( room_id, not is_public, to_insert, ) to_insert.clear() - if len(to_update) > self.INITIAL_BATCH_SIZE: + if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.update_users_who_share_room( room_id, not is_public, to_update, ) @@ -385,14 +411,28 @@ class UserDirectoyHandler(object): yield self._handle_remove_user(room_id, user_id) @defer.inlineCallbacks + def _handle_local_user(self, user_id): + """Adds a new local roomless user into the user_directory_search table. + Used to populate up the user index when we have an + user_directory_search_all_users specified. 
+ """ + logger.debug("Adding new local user to dir, %r", user_id) + + profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) + + row = yield self.store.get_user_in_directory(user_id) + if not row: + yield self.store.add_profiles_to_user_dir(None, {user_id: profile}) + + @defer.inlineCallbacks def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: - room_id (str): room_id that user joined or started being public that + room_id (str): room_id that user joined or started being public user_id (str) """ - logger.debug("Adding user to dir, %r", user_id) + logger.debug("Adding new user to dir, %r", user_id) row = yield self.store.get_user_in_directory(user_id) if not row: @@ -407,7 +447,7 @@ class UserDirectoyHandler(object): if not row: yield self.store.add_users_to_public_room(room_id, [user_id]) else: - logger.debug("Not adding user to public dir, %r", user_id) + logger.debug("Not adding new user to public dir, %r", user_id) # Now we update users who share rooms with users. We do this by getting # all the current users in the room and seeing which aren't already diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py new file mode 100644 index 0000000000..343e932cb1 --- /dev/null +++ b/synapse/http/additional_resource.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import wrap_request_handler +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + + +class AdditionalResource(Resource): + """Resource wrapper for additional_resources + + If the user has configured additional_resources, we need to wrap the + handler class with a Resource so that we can map it into the resource tree. + + This class is also where we wrap the request handler with logging, metrics, + and exception handling. + """ + def __init__(self, hs, handler): + """Initialise AdditionalResource + + The ``handler`` should return a deferred which completes when it has + done handling the request. It should write a response with + ``request.write()``, and call ``request.finish()``. + + Args: + hs (synapse.server.HomeServer): homeserver + handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): + function to be called to handle the request. 
+ """ + Resource.__init__(self) + self._handler = handler + + # these are required by the request_handler wrapper + self.version_string = hs.version_string + self.clock = hs.get_clock() + + def render(self, request): + self._async_render(request) + return NOT_DONE_YET + + @wrap_request_handler + def _async_render(self, request): + return self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index 9eba046bbf..4abb479ae3 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import ( CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import make_deferred_yieldable from synapse.util import logcontext import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -114,43 +114,73 @@ class SimpleHttpClient(object): raise e @defer.inlineCallbacks - def post_urlencoded_get_json(self, uri, args={}): + def post_urlencoded_get_json(self, uri, args={}, headers=None): + """ + Args: + uri (str): + args (dict[str, str|List[str]]): query params + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ + # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + actual_headers = { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json_get_json(self, uri, post_json): + def post_json_get_json(self, uri, post_json, headers=None): + """ + + Args: + uri (str): + post_json (object): + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ json_str = encode_canonical_json(post_json) logger.debug("HTTP POST %s -> %s", json_str, uri) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/json"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -160,7 +190,7 @@ class SimpleHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def get_json(self, uri, args={}): + def get_json(self, uri, args={}, headers=None): """ Gets some json from the given URI. Args: @@ -169,6 +199,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. 
+ headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -177,13 +209,13 @@ class SimpleHttpClient(object): error message. """ try: - body = yield self.get_raw(uri, args) + body = yield self.get_raw(uri, args, headers=headers) defer.returnValue(json.loads(body)) except CodeMessageException as e: raise self._exceptionFromFailedRequest(e.code, e.msg) @defer.inlineCallbacks - def put_json(self, uri, json_body, args={}): + def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. Args: @@ -193,6 +225,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -205,17 +239,21 @@ class SimpleHttpClient(object): json_str = encode_canonical_json(json_body) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "PUT", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - "Content-Type": ["application/json"] - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -226,7 +264,7 @@ class SimpleHttpClient(object): raise CodeMessageException(response.code, body) @defer.inlineCallbacks - def get_raw(self, uri, args={}): + def get_raw(self, uri, args={}, headers=None): """ Gets raw text from the given URI. Args: @@ -235,6 +273,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body at text. @@ -246,15 +286,19 @@ class SimpleHttpClient(object): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(body) @@ -274,27 +318,33 @@ class SimpleHttpClient(object): # The two should be factored out. @defer.inlineCallbacks - def get_file(self, url, output_stream, max_size=None): + def get_file(self, url, output_stream, max_size=None, headers=None): """GETs a file from a given URL Args: url (str): The URL to GET output_stream (file): File to write the response body to. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: A (int,dict,string,int) tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. 
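Each SimpleHttpClient method above gains its optional headers argument through the same merge pattern: build the default User-Agent/Content-Type headers, then overlay whatever the caller passed, so a caller can forward extra headers such as Authorization. In isolation the pattern is just:

from twisted.web.http_headers import Headers


def build_headers(user_agent, extra_headers=None):
    # Defaults first; caller-supplied headers then extend or override them.
    # Values are lists, as twisted's Headers expects.
    actual_headers = {
        b"User-Agent": [user_agent],
    }
    if extra_headers:
        actual_headers.update(extra_headers)
    return Headers(actual_headers)

# e.g. build_headers(b"Synapse/0.26", {b"Authorization": [b"Bearer ..."]})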
""" + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", url.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - headers = dict(response.headers.getAllRawHeaders()) + resp_headers = dict(response.headers.getAllRawHeaders()) - if 'Content-Length' in headers and headers['Content-Length'] > max_size: + if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size: logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -315,10 +365,9 @@ class SimpleHttpClient(object): # straight back in again try: - length = yield preserve_context_over_fn( - _readBodyToFile, - response, output_stream, max_size - ) + length = yield make_deferred_yieldable(_readBodyToFile( + response, output_stream, max_size, + )) except Exception as e: logger.exception("Failed to download body") raise SynapseError( @@ -327,7 +376,9 @@ class SimpleHttpClient(object): Codes.UNKNOWN, ) - defer.returnValue((length, headers, response.request.absoluteURI, response.code)) + defer.returnValue( + (length, resp_headers, response.request.absoluteURI, response.code), + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. @@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): ) try: - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(body) except PartialDownloadError as e: # twisted dislikes google's response, no content length. diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 241b17f2cb..e2b99ef3bd 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -354,16 +354,30 @@ def _get_hosts_for_srv_record(dns_client, host): return res[0] - def eb(res): - res.trap(DNSNameError) - return [] + def eb(res, record_type): + if res.check(DNSNameError): + return [] + logger.warn("Error looking up %s for %s: %s", + record_type, host, res, res.value) + return res # no logcontexts here, so we can safely fire these off and gatherResults - d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb) - d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb) - results = yield defer.gatherResults([d1, d2], consumeErrors=True) + d1 = dns_client.lookupAddress(host).addCallbacks( + cb, eb, errbackArgs=("A", )) + d2 = dns_client.lookupIPV6Address(host).addCallbacks( + cb, eb, errbackArgs=("AAAA", )) + results = yield defer.DeferredList( + [d1, d2], consumeErrors=True) + + # if all of the lookups failed, raise an exception rather than blowing out + # the cache with an empty result. 
+ if results and all(s == defer.FAILURE for (s, _) in results): + defer.returnValue(results[0][1]) + + for (success, result) in results: + if success == defer.FAILURE: + continue - for result in results: for answer in result: if not answer.payload: continue diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 747a791f83..833496b72d 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object): raise logger.warn( - "{%s} Sending request failed to %s: %s %s: %s - %s", + "{%s} Sending request failed to %s: %s %s: %s", txn_id, destination, method, url_bytes, - type(e).__name__, _flatten_response_never_received(e), ) - log_result = "%s - %s" % ( - type(e).__name__, _flatten_response_never_received(e), - ) + log_result = _flatten_response_never_received(e) if retries_left and not timeout: if long_retries: @@ -347,7 +344,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def post_json(self, destination, path, data={}, long_retries=False, - timeout=None, ignore_backoff=False): + timeout=None, ignore_backoff=False, args={}): """ Sends the specifed json data using POST Args: @@ -383,6 +380,7 @@ class MatrixFederationHttpClient(object): destination, "POST", path, + query_bytes=encode_query_args(args), body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=long_retries, @@ -427,13 +425,6 @@ class MatrixFederationHttpClient(object): """ logger.debug("get_json args: %s", args) - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, basestring): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.urlencode(encoded_args, True) logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) def body_callback(method, url_bytes, headers_dict): @@ -444,7 +435,7 @@ class MatrixFederationHttpClient(object): destination, "GET", path, - query_bytes=query_bytes, + query_bytes=encode_query_args(args), body_callback=body_callback, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, @@ -461,6 +452,52 @@ class MatrixFederationHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks + def delete_json(self, destination, path, long_retries=False, + timeout=None, ignore_backoff=False, args={}): + """Send a DELETE request to the remote expecting some json response + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + long_retries (bool): A boolean that indicates whether we should + retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. + ignore_backoff (bool): true to ignore the historical backoff data and + try the request anyway. + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + """ + + response = yield self._request( + destination, + "DELETE", + path, + query_bytes=encode_query_args(args), + headers_dict={"Content-Type": ["application/json"]}, + long_retries=long_retries, + timeout=timeout, + ignore_backoff=ignore_backoff, + ) + + if 200 <= response.code < 300: + # We need to update the transactions table to say it was sent? 
+ check_content_type_is_json(response.headers) + + with logcontext.PreserveLoggingContext(): + body = yield readBody(response) + + defer.returnValue(json.loads(body)) + + @defer.inlineCallbacks def get_file(self, destination, path, output_stream, args={}, retry_on_dns_fail=True, max_size=None, ignore_backoff=False): @@ -513,7 +550,7 @@ class MatrixFederationHttpClient(object): length = yield _readBodyToFile( response, output_stream, max_size ) - except: + except Exception: logger.exception("Failed to download body") raise @@ -578,12 +615,14 @@ class _JsonProducer(object): def _flatten_response_never_received(e): if hasattr(e, "reasons"): - return ", ".join( + reasons = ", ".join( _flatten_response_never_received(f.value) for f in e.reasons ) + + return "%s:[%s]" % (type(e).__name__, reasons) else: - return "%s: %s" % (type(e).__name__, e.message,) + return repr(e) def check_content_type_is_json(headers): @@ -610,3 +649,15 @@ def check_content_type_is_json(headers): raise RuntimeError( "Content-Type not application/json: was '%s'" % c_type ) + + +def encode_query_args(args): + encoded_args = {} + for k, vs in args.items(): + if isinstance(vs, basestring): + vs = [vs] + encoded_args[k] = [v.encode("UTF-8") for v in vs] + + query_bytes = urllib.urlencode(encoded_args, True) + + return query_bytes diff --git a/synapse/http/server.py b/synapse/http/server.py index 7ef3d526b1..25466cd292 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -28,6 +28,7 @@ from canonicaljson import ( ) from twisted.internet import defer +from twisted.python import failure from twisted.web import server, resource from twisted.web.server import NOT_DONE_YET from twisted.web.util import redirectTo @@ -130,13 +131,18 @@ def wrap_request_handler(request_handler, include_metrics=False): pretty_print=_request_user_agent_is_curl(request), version_string=self.version_string, ) - except: - logger.exception( - "Failed handle request %s.%s on %r: %r", + except Exception: + # failure.Failure() fishes the original Failure out + # of our stack, and thus gives us a sensible stack + # trace. 
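The logging change above relies on twisted.python.failure.Failure: constructed with no arguments inside an except block, it captures the active exception together with its stack, so getTraceback() gives a useful trace even after the exception has been caught. For example (sketch):

import logging

from twisted.python import failure

logger = logging.getLogger(__name__)


def run_handler(handler, request):
    try:
        return handler(request)
    except Exception:
        # Failure() with no args wraps sys.exc_info(), preserving the
        # original stack for logging.
        f = failure.Failure()
        logger.error(
            "Failed to handle request %r: %s",
            request, f.getTraceback().rstrip(),
        )
        raise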
+ f = failure.Failure() + logger.error( + "Failed handle request %s.%s on %r: %r: %s", request_handler.__module__, request_handler.__name__, self, - request + request, + f.getTraceback().rstrip(), ) respond_with_json( request, @@ -145,7 +151,9 @@ def wrap_request_handler(request_handler, include_metrics=False): "error": "Internal server error", "errcode": Codes.UNKNOWN, }, - send_cors=True + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + version_string=self.version_string, ) finally: try: diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 9a4c36ad5d..71420e54db 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -48,7 +48,7 @@ def parse_integer_from_args(args, name, default=None, required=False): if name in args: try: return int(args[name][0]) - except: + except Exception: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message) else: @@ -88,7 +88,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): "true": True, "false": False, }[args[name][0]] - except: + except Exception: message = ( "Boolean query parameter %r must be one of" " ['true', 'false']" @@ -162,12 +162,13 @@ def parse_json_value_from_request(request): """ try: content_bytes = request.content.read() - except: + except Exception: raise SynapseError(400, "Error reading JSON content.") try: content = simplejson.loads(content_bytes) - except simplejson.JSONDecodeError: + except Exception as e: + logger.warn("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content diff --git a/synapse/http/site.py b/synapse/http/site.py index 4b09d7ee66..cd1492b1c3 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -67,7 +67,7 @@ class SynapseRequest(Request): ru_utime, ru_stime = context.get_resource_usage() db_txn_count = context.db_txn_count db_txn_duration = context.db_txn_duration - except: + except Exception: ru_utime, ru_stime = (0, 0) db_txn_count, db_txn_duration = (0, 0) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py new file mode 100644 index 0000000000..097c844d31 --- /dev/null +++ b/synapse/module_api/__init__.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.types import UserID + + +class ModuleApi(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, auth_handler): + self.hs = hs + + self._store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = auth_handler + + def get_user_by_req(self, req, allow_guest=False): + """Check the access_token provided for a request + + Args: + req (twisted.web.server.Request): Incoming HTTP request + allow_guest (bool): True if guest users should be allowed. 
If this + is False, and the access token is for a guest user, an + AuthError will be thrown + Returns: + twisted.internet.defer.Deferred[synapse.types.Requester]: + the requester for this request + Raises: + synapse.api.errors.AuthError: if no user by that token exists, + or the token is invalid. + """ + return self._auth.get_user_by_req(req, allow_guest) + + def get_qualified_user_id(self, username): + """Qualify a user id, if necessary + + Takes a user id provided by the user and adds the @ and :domain to + qualify it, if necessary + + Args: + username (str): provided user id + + Returns: + str: qualified @user:id + """ + if username.startswith('@'): + return username + return UserID(username, self.hs.hostname).to_string() + + def check_user_exists(self, user_id): + """Check if user exists. + + Args: + user_id (str): Complete @user:id + + Returns: + Deferred[str|None]: Canonical (case-corrected) user_id, or None + if the user is not registered. + """ + return self._auth_handler.check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) + + @defer.inlineCallbacks + def invalidate_access_token(self, access_token): + """Invalidate an access token for a user + + Args: + access_token(str): access token + + Returns: + twisted.internet.defer.Deferred - resolves once the access token + has been removed. + + Raises: + synapse.api.errors.AuthError: the access token is invalid + """ + # see if the access token corresponds to a device + user_info = yield self._auth.get_user_by_access_token(access_token) + device_id = user_info.get("device_id") + user_id = user_info["user"].to_string() + if device_id: + # delete the device, which will also delete its access tokens + yield self.hs.get_device_handler().delete_device(user_id, device_id) + else: + # no associated device. Just delete the access token. + yield self._auth_handler.delete_access_token(access_token) + + def run_db_interaction(self, desc, func, *args, **kwargs): + """Run a function with a database connection + + Args: + desc (str): description for the transaction, for metrics etc + func (func): function to be run. 
Passed a database cursor object + as well as *args and **kwargs + *args: positional args to be passed to func + **kwargs: named args to be passed to func + + Returns: + Deferred[object]: result of func + """ + return self._store.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/notifier.py b/synapse/notifier.py index 385208b574..ef042681bc 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -255,9 +255,7 @@ class Notifier(object): ) if self.federation_sender: - preserve_fn(self.federation_sender.notify_new_events)( - room_stream_id - ) + self.federation_sender.notify_new_events(room_stream_id) if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -289,7 +287,7 @@ class Notifier(object): for user_stream in user_streams: try: user_stream.notify(stream_key, new_token, time_now_ms) - except: + except Exception: logger.exception("Failed to notify listener") self.notify_replication() @@ -297,8 +295,7 @@ class Notifier(object): def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" - with PreserveLoggingContext(): - self.notify_replication() + self.notify_replication() @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, @@ -516,8 +513,14 @@ class Notifier(object): self.replication_deferred = ObservableDeferred(defer.Deferred()) deferred.callback(None) - for cb in self.replication_callbacks: - preserve_fn(cb)() + # the callbacks may well outlast the current request, so we run + # them in the sentinel logcontext. + # + # (ideally it would be up to the callbacks to know if they were + # starting off background processes and drop the logcontext + # accordingly, but that requires more changes) + for cb in self.replication_callbacks: + cb() @defer.inlineCallbacks def wait_for_replication(self, callback, timeout): diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 85effdfa46..7a18afe5f9 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [ } ] }, + { + 'rule_id': 'global/override/.m.rule.roomnotif', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': '@room', + '_id': '_roomnotif_content', + }, + { + 'kind': 'sender_notification_permission', + 'key': 'room', + '_id': '_roomnotif_pl', + }, + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': True, + } + ] + } ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index b0d64aa6c4..425a017bdf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,11 +20,13 @@ from twisted.internet import defer from .push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.event_auth import get_user_power_level from synapse.api.constants import EventTypes, Membership from synapse.metrics import get_metrics_for from synapse.util.caches import metrics as cache_metrics from synapse.util.caches.descriptors import cached from synapse.util.async import Linearizer +from synapse.state import POWER_KEY from collections import namedtuple @@ -59,6 +62,7 @@ class BulkPushRuleEvaluator(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.auth = hs.get_auth() self.room_push_rule_cache_metrics = cache_metrics.register_cache( "cache", @@ -109,6 +113,29 @@ class BulkPushRuleEvaluator(object): ) @defer.inlineCallbacks + def _get_power_levels_and_sender_level(self, event, context): + pl_event_id = context.prev_state_ids.get(POWER_KEY) + if pl_event_id: + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case + pl_event = yield self.store.get_event(pl_event_id) + auth_events = {POWER_KEY: pl_event} + else: + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=False, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.itervalues() + } + + sender_level = get_user_power_level(event.sender, auth_events) + + pl_event = auth_events.get(POWER_KEY) + + defer.returnValue((pl_event.content if pl_event else {}, sender_level)) + + @defer.inlineCallbacks def action_for_event_by_user(self, event, context): """Given an event and context, evaluate the push rules and return the results @@ -123,7 +150,13 @@ class BulkPushRuleEvaluator(object): event, context ) - evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) + (power_levels, sender_power_level) = ( + yield self._get_power_levels_and_sender_level(event, context) + ) + + evaluator = PushRuleEvaluatorForEvent( + event, len(room_members), sender_power_level, power_levels, + ) condition_cache = {} diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index a69dda7b09..58df98a793 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -121,7 +121,7 @@ class EmailPusher(object): starting_max_ordering = self.max_stream_ordering try: yield self._unsafe_process() - except: + except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: break diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 62c41cd9db..c16f61452c 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
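The _get_power_levels_and_sender_level helper added to bulk_push_rule_evaluator.py above fetches the room's m.room.power_levels content and the sender's level before constructing the evaluator. A rough sketch of how the sender's level falls out of that content (assuming the usual users / users_default keys; the real lookup is synapse.event_auth.get_user_power_level, which also copes with rooms that have no power-levels event):

def sender_level(sender, pl_content):
    # users maps user_id -> level; users_default covers everyone else (default 0)
    users = pl_content.get("users", {})
    if sender in users:
        return int(users[sender])
    return int(pl_content.get("users_default", 0))

# The evaluator is then built as in the hunk above:
#   PushRuleEvaluatorForEvent(event, len(room_members), sender_power_level, power_levels)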
@@ -131,7 +132,7 @@ class HttpPusher(object): starting_max_ordering = self.max_stream_ordering try: yield self._unsafe_process() - except: + except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: break @@ -295,7 +296,7 @@ class HttpPusher(object): if event.type == 'm.room.member': d['notification']['membership'] = event.content['membership'] d['notification']['user_is_target'] = event.state_key == self.user_id - if not self.hs.config.push_redact_content and 'content' in event: + if self.hs.config.push_include_content and 'content' in event: d['notification']['content'] = event.content # We no longer send aliases separately, instead, we send the human @@ -314,7 +315,7 @@ class HttpPusher(object): defer.returnValue([]) try: resp = yield self.http_client.post_json_get_json(self.url, notification_dict) - except: + except Exception: logger.warn("Failed to push %s ", self.url) defer.returnValue(False) rejected = [] @@ -345,7 +346,7 @@ class HttpPusher(object): } try: resp = yield self.http_client.post_json_get_json(self.url, d) - except: + except Exception: logger.exception("Failed to push %s ", self.url) defer.returnValue(False) rejected = [] diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 172c27c137..3601f2d365 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") def _room_member_count(ev, condition, room_member_count): + return _test_ineq_condition(condition, room_member_count) + + +def _sender_notification_permission(ev, condition, sender_power_level, power_levels): + notif_level_key = condition.get('key') + if notif_level_key is None: + return False + + notif_levels = power_levels.get('notifications', {}) + room_notif_level = notif_levels.get(notif_level_key, 50) + + return sender_power_level >= room_notif_level + + +def _test_ineq_condition(condition, number): if 'is' not in condition: return False m = INEQUALITY_EXPR.match(condition['is']) @@ -41,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count): rhs = int(rhs) if ineq == '' or ineq == '==': - return room_member_count == rhs + return number == rhs elif ineq == '<': - return room_member_count < rhs + return number < rhs elif ineq == '>': - return room_member_count > rhs + return number > rhs elif ineq == '>=': - return room_member_count >= rhs + return number >= rhs elif ineq == '<=': - return room_member_count <= rhs + return number <= rhs else: return False @@ -65,9 +81,11 @@ def tweaks_for_actions(actions): class PushRuleEvaluatorForEvent(object): - def __init__(self, event, room_member_count): + def __init__(self, event, room_member_count, sender_power_level, power_levels): self._event = event self._room_member_count = room_member_count + self._sender_power_level = sender_power_level + self._power_levels = power_levels # Maps strings of e.g. 
'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) @@ -81,6 +99,10 @@ class PushRuleEvaluatorForEvent(object): return _room_member_count( self._event, condition, self._room_member_count ) + elif condition['kind'] == 'sender_notification_permission': + return _sender_notification_permission( + self._event, condition, self._sender_power_level, self._power_levels, + ) else: return True @@ -183,7 +205,7 @@ def _glob_to_re(glob, word_boundary): r, ) if word_boundary: - r = r"\b%s\b" % (r,) + r = _re_word_boundary(r) return re.compile(r, flags=re.IGNORECASE) else: @@ -192,7 +214,7 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) elif word_boundary: r = re.escape(glob) - r = r"\b%s\b" % (r,) + r = _re_word_boundary(r) return re.compile(r, flags=re.IGNORECASE) else: @@ -200,6 +222,18 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) +def _re_word_boundary(r): + """ + Adds word boundary characters to the start and end of an + expression to require that the match occur as a whole word, + but do so respecting the fact that strings starting or ending + with non-word characters will change word boundaries. + """ + # we can't use \b as it chokes on unicode. however \W seems to be okay + # as shorthand for [^0-9A-Za-z_]. + return r"(^|\W)%s(\W|$)" % (r,) + + def _flatten_dict(d, prefix=[], result=None): if result is None: result = {} diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 491f27bded..71576330a9 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) try: from synapse.push.emailpusher import EmailPusher from synapse.push.mailer import Mailer, load_jinja2_templates -except: +except Exception: pass diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 43cb6e9c01..134e89b371 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -17,7 +17,7 @@ from twisted.internet import defer from .pusher import PusherFactory -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.async import run_on_reactor import logging @@ -103,19 +103,25 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id, except_access_token_id=None): - all = yield self.store.get_all_pushers() - logger.info( - "Removing all pushers for user %s except access tokens id %r", - user_id, except_access_token_id - ) - for p in all: - if p['user_name'] == user_id and p['access_token'] != except_access_token_id: + def remove_pushers_by_access_token(self, user_id, access_tokens): + """Remove the pushers for a given user corresponding to a set of + access_tokens. 
+ + Args: + user_id (str): user to remove pushers for + access_tokens (Iterable[int]): access token *ids* to remove pushers + for + """ + tokens = set(access_tokens) + for p in (yield self.store.get_pushers_by_user_id(user_id)): + if p['access_token'] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] ) - yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher( + p['app_id'], p['pushkey'], p['user_name'], + ) @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): @@ -136,8 +142,8 @@ class PusherPool: ) ) - yield preserve_context_over_deferred(defer.gatherResults(deferreds)) - except: + yield make_deferred_yieldable(defer.gatherResults(deferreds)) + except Exception: logger.exception("Exception in pusher on_new_notifications") @defer.inlineCallbacks @@ -161,8 +167,8 @@ class PusherPool: preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) ) - yield preserve_context_over_deferred(defer.gatherResults(deferreds)) - except: + yield make_deferred_yieldable(defer.gatherResults(deferreds)) + except Exception: logger.exception("Exception in pusher on_new_receipts") @defer.inlineCallbacks @@ -188,7 +194,7 @@ class PusherPool: for pusherdict in pushers: try: p = self.pusher_factory.create_pusher(pusherdict) - except: + except Exception: logger.exception("Couldn't start a pusher: caught Exception") continue if p: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 630e92c90e..7052333c19 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -40,7 +40,6 @@ REQUIREMENTS = { "pymacaroons-pynacl": ["pymacaroons"], "msgpack-python>=0.3.0": ["msgpack"], "phonenumbers>=8.2.0": ["phonenumbers"], - "affinity": ["affinity"], } CONDITIONAL_REQUIREMENTS = { "web_client": { @@ -59,6 +58,9 @@ CONDITIONAL_REQUIREMENTS = { "psutil": { "psutil>=2.0.0": ["psutil>=2.0.0"], }, + "affinity": { + "affinity": ["affinity"], + }, } diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index b962641166..61f5590c53 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(hs) + super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id", diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 94ebbffc1b..29d7296b43 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -12,20 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
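For the _re_word_boundary helper added to push_rule_evaluator.py above, a quick illustration of the (^|\W)...(\W|$) behaviour it produces (example strings are made up purely to show the matching):

import re

pattern = re.compile(r"(^|\W)@room(\W|$)", re.IGNORECASE)

print(bool(pattern.search("hello @room, meeting at 3")))        # True: whole word
print(bool(pattern.search("@room")))                             # True: start/end of string count
print(bool(pattern.search("mailto:someone@roomservice.org")))    # False: embedded in a larger token

As the comment in the diff notes, \b is avoided because it misbehaves on unicode, while \W works as shorthand for [^0-9A-Za-z_].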
-from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker +import logging from synapse.api.constants import EventTypes from synapse.storage import DataStore -from synapse.storage.roommember import RoomMemberStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_push_actions import EventPushActionsStore -from synapse.storage.state import StateStore +from synapse.storage.roommember import RoomMemberStore +from synapse.storage.state import StateGroupReadStore from synapse.storage.stream import StreamStore from synapse.util.caches.stream_change_cache import StreamChangeCache - -import logging - +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -39,7 +37,7 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(BaseSlavedStore): +class SlavedEventStore(StateGroupReadStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedEventStore, self).__init__(db_conn, hs) @@ -90,25 +88,9 @@ class SlavedEventStore(BaseSlavedStore): _get_unread_counts_by_pos_txn = ( DataStore._get_unread_counts_by_pos_txn.__func__ ) - _get_state_group_for_events = ( - StateStore.__dict__["_get_state_group_for_events"] - ) - _get_state_group_for_event = ( - StateStore.__dict__["_get_state_group_for_event"] - ) - _get_state_groups_from_groups = ( - StateStore.__dict__["_get_state_groups_from_groups"] - ) - _get_state_groups_from_groups_txn = ( - DataStore._get_state_groups_from_groups_txn.__func__ - ) get_recent_event_ids_for_room = ( StreamStore.__dict__["get_recent_event_ids_for_room"] ) - get_current_state_ids = ( - StateStore.__dict__["get_current_state_ids"] - ) - get_state_group_delta = StateStore.__dict__["get_state_group_delta"] _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] has_room_changed_since = DataStore.has_room_changed_since.__func__ @@ -134,12 +116,6 @@ class SlavedEventStore(BaseSlavedStore): DataStore.get_room_events_stream_for_room.__func__ ) get_events_around = DataStore.get_events_around.__func__ - get_state_for_event = DataStore.get_state_for_event.__func__ - get_state_for_events = DataStore.get_state_for_events.__func__ - get_state_groups = DataStore.get_state_groups.__func__ - get_state_groups_ids = DataStore.get_state_groups_ids.__func__ - get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ - get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ _get_joined_users_from_context = ( @@ -169,10 +145,7 @@ class SlavedEventStore(BaseSlavedStore): _get_rooms_for_user_where_membership_is_txn = ( DataStore._get_rooms_for_user_where_membership_is_txn.__func__ ) - _get_state_for_groups = DataStore._get_state_for_groups.__func__ - _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ _get_events_around_txn = DataStore._get_events_around_txn.__func__ - _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__ get_backfill_events = DataStore.get_backfill_events.__func__ _get_backfill_events = DataStore._get_backfill_events.__func__ diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py new file mode 100644 index 0000000000..0bc4bce5b0 --- /dev/null +++ 
b/synapse/replication/slave/storage/groups.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class SlavedGroupServerStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(db_conn, hs) + + self.hs = hs + + self._group_updates_id_gen = SlavedIdTracker( + db_conn, "local_group_updates", "stream_id", + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + ) + + get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__ + get_group_stream_token = DataStore.get_group_stream_token.__func__ + get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__ + + def stream_positions(self): + result = super(SlavedGroupServerStore, self).stream_positions() + result["groups"] = self._group_updates_id_gen.get_current_token() + return result + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "groups": + self._group_updates_id_gen.advance(token) + for row in rows: + self._group_updates_stream_cache.entity_has_changed( + row.user_id, token + ) + + return super(SlavedGroupServerStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 3ea3ca5a6f..786c3fe864 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -160,7 +160,11 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, stream.upto_token ) - updates, current_token = yield stream.get_updates() + try: + updates, current_token = yield stream.get_updates() + except Exception: + logger.info("Failed to handle stream %s", stream.NAME) + raise logger.debug( "Sending %d updates to %d connections", @@ -212,11 +216,12 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") + @defer.inlineCallbacks def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. 
""" user_sync_counter.inc() - self.presence_handler.update_external_syncs_row( + yield self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms, ) @@ -240,11 +245,12 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") + @defer.inlineCallbacks def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): """The client saw a user request """ user_ip_cache_counter.inc() - self.store.insert_client_ip( + yield self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen, ) diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index fbafe12cc2..4c60bf79f9 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", ( "state_key", # str "event_id", # str, optional )) +GroupsStreamRow = namedtuple("GroupsStreamRow", ( + "group_id", # str + "user_id", # str + "type", # str + "content", # dict +)) class Stream(object): @@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream): super(CurrentStateDeltaStream, self).__init__(hs) +class GroupServerStream(Stream): + NAME = "groups" + ROW_TYPE = GroupsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_group_stream_token + self.update_function = store.get_all_groups_changes + + super(GroupServerStream, self).__init__(hs) + + STREAMS_MAP = { stream.NAME: stream for stream in ( @@ -482,5 +501,6 @@ STREAMS_MAP = { TagAccountDataStream, AccountDataStream, CurrentStateDeltaStream, + GroupServerStream, ) } diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3d809d181b..16f5a73b95 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -52,6 +52,7 @@ from synapse.rest.client.v2_alpha import ( thirdparty, sendtodevice, user_directory, + groups, ) from synapse.http.server import JsonResource @@ -102,3 +103,4 @@ class ClientRestResource(JsonResource): thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource) + groups.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 465b25033d..5022808ea9 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): - self.store = hs.get_datastore() super(DeactivateAccountRestServlet, self).__init__(hs) + self._deactivate_account_handler = hs.get_deactivate_account_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): @@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. 
- yield self.store.user_delete_access_tokens(target_user_id) - yield self.store.user_delete_threepids(target_user_id) - yield self.store.user_set_password_hash(target_user_id, None) - + yield self._deactivate_account_handler.deactivate_account(target_user_id) defer.returnValue((200, {})) @@ -314,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): super(ResetPasswordRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() + self._set_password_handler = hs.get_set_password_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): @@ -335,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): logger.info("new_password: %r", new_password) - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( target_user_id, new_password, requester ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index f15aa5c13f..1c3933380f 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -93,7 +93,7 @@ class ClientDirectoryServer(ClientV1RestServlet): ) except SynapseError as e: raise e - except: + except Exception: logger.exception("Failed to create association") raise except AuthError: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a43410fb37..5669ecb724 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier): class LoginRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login$") - PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" @@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url - self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret @@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet): # fall back to the fallback API if they don't understand one of the # login flow types returned. 
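The admin deactivation endpoint above now delegates to hs.get_deactivate_account_handler() instead of poking the datastore directly. A loose sketch of what that handler presumably does, mirroring the inline logic the diff removes (anything beyond the three store calls visible in the diff is an assumption):

from twisted.internet import defer

class DeactivateAccountHandler(object):
    def __init__(self, hs):
        self.store = hs.get_datastore()

    @defer.inlineCallbacks
    def deactivate_account(self, user_id):
        # the same three steps the old admin servlet performed inline
        yield self.store.user_delete_access_tokens(user_id)
        yield self.store.user_delete_threepids(user_id)
        yield self.store.user_set_password_hash(user_id, None)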
flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - if self.password_enabled: - flows.append({"type": LoginRestServlet.PASS_TYPE}) + + flows.extend(( + {"type": t} for t in self.auth_handler.get_supported_login_types() + )) return (200, {"flows": flows}) @@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet): def on_POST(self, request): login_submission = parse_json_object_from_request(request) try: - if login_submission["type"] == LoginRestServlet.PASS_TYPE: - if not self.password_enabled: - raise SynapseError(400, "Password login has been disabled.") - - result = yield self.do_password_login(login_submission) - defer.returnValue(result) - elif self.saml2_enabled and (login_submission["type"] == - LoginRestServlet.SAML2_TYPE): + if self.saml2_enabled and (login_submission["type"] == + LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: relay_state = "&RelayState=" + urllib.quote( @@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet): result = yield self.do_token_login(login_submission) defer.returnValue(result) else: - raise SynapseError(400, "Bad login type.") + result = yield self._do_other_login(login_submission) + defer.returnValue(result) except KeyError: raise SynapseError(400, "Missing JSON keys.") @defer.inlineCallbacks - def do_password_login(self, login_submission): - if "password" not in login_submission: - raise SynapseError(400, "Missing parameter: password") + def _do_other_login(self, login_submission): + """Handle non-token/saml/jwt logins + Args: + login_submission: + + Returns: + (int, object): HTTP code/response + """ + # Log the request we got, but only certain fields to minimise the chance of + # logging someone's password (even if they accidentally put it in the wrong + # field) + logger.info( + "Got login request with identifier: %r, medium: %r, address: %r, user: %r", + login_submission.get('identifier'), + login_submission.get('medium'), + login_submission.get('address'), + login_submission.get('user'), + ) login_submission_legacy_convert(login_submission) if "identifier" not in login_submission: @@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet): if "user" not in identifier: raise SynapseError(400, "User identifier is missing 'user' key") - user_id = identifier["user"] - - if not user_id.startswith('@'): - user_id = UserID.create( - user_id, self.hs.hostname - ).to_string() - auth_handler = self.auth_handler - user_id = yield auth_handler.validate_password_login( - user_id=user_id, - password=login_submission["password"], + canonical_user_id, callback = yield auth_handler.validate_login( + identifier["user"], + login_submission, + ) + + device_id = yield self._register_device( + canonical_user_id, login_submission, ) - device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name"), + canonical_user_id, device_id, ) + result = { - "user_id": user_id, # may have changed + "user_id": canonical_user_id, "access_token": access_token, "home_server": self.hs.hostname, "device_id": device_id, } + if callback is not None: + yield callback(result) + defer.returnValue((200, result)) @defer.inlineCallbacks @@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet): device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( user_id, device_id, - 
login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed @@ -278,7 +286,7 @@ class LoginRestServlet(ClientV1RestServlet): if user is None: raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) - user_id = UserID.create(user, self.hs.hostname).to_string() + user_id = UserID(user, self.hs.hostname).to_string() auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: @@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token = yield auth_handler.get_access_token_for_user_id( registered_user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { @@ -444,7 +451,7 @@ class CasTicketServlet(ClientV1RestServlet): if required_value != actual_value: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - user_id = UserID.create(user, self.hs.hostname).to_string() + user_id = UserID(user, self.hs.hostname).to_string() auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if not registered_user_id: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 1358d0acab..ca49955935 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.api.auth import get_access_token_from_request +from synapse.api.errors import AuthError from .base import ClientV1RestServlet, client_path_patterns @@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutRestServlet, self).__init__(hs) - self.store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @defer.inlineCallbacks def on_POST(self, request): - access_token = get_access_token_from_request(request) - yield self.store.delete_access_token(access_token) + try: + requester = yield self.auth.get_user_by_req(request) + except AuthError: + # this implies the access token has already been deleted. + pass + else: + if requester.device_id is None: + # the acccess token wasn't associated with a device. + # Just delete the access token + access_token = get_access_token_from_request(request) + yield self._auth_handler.delete_access_token(access_token) + else: + yield self._device_handler.delete_device( + requester.user.to_string(), requester.device_id) + defer.returnValue((200, {})) @@ -47,8 +63,9 @@ class LogoutAllRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutAllRestServlet, self).__init__(hs) - self.store = hs.get_datastore() self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @@ -57,7 +74,13 @@ class LogoutAllRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.store.user_delete_access_tokens(user_id) + + # first delete all of the user's devices + yield self._device_handler.delete_all_devices_for_user(user_id) + + # .. and then delete any access tokens which weren't associated with + # devices. 
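The login rework above replaces the dedicated password branch with validate_login plus an identifier object in the submission. A hedged example of the request body a client would send for a password login under this scheme (all values are placeholders):

login_submission = {
    "type": "m.login.password",
    "identifier": {
        "type": "m.id.user",
        "user": "alice",              # bare localpart or a full @alice:example.com
    },
    "password": "s3cret",
    "initial_device_display_name": "laptop",   # optional
}

# Older submissions that put "user"/"medium"/"address" at the top level are still
# rewritten into this shape by login_submission_legacy_convert, as shown above.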
+ yield self._auth_handler.delete_access_tokens_for_user(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 47b2dc45e7..4a73813c58 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -78,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): raise KeyError() except SynapseError as e: raise e - except: + except Exception: raise SynapseError(400, "Unable to parse state") yield self.presence_handler.set_state(user, state) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 1a5045c9ec..e4e3611a14 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileDisplaynameRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) @@ -52,10 +52,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): try: new_name = content["displayname"] - except: + except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileAvatarURLRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) @@ -94,10 +94,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) try: new_name = content["avatar_url"] - except: + except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_avatar_url( + yield self.profile_handler.set_avatar_url( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ecf7e311a9..32ed1d3ab2 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -359,7 +359,7 @@ class RegisterRestServlet(ClientV1RestServlet): if compare_digest(want_mac, got_mac): handler = self.handlers.registration_handler user_id, token = yield handler.register( - localpart=user, + localpart=user.lower(), password=password, admin=bool(admin), ) diff --git 
a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 4be0fee38d..75b735b47d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -238,7 +238,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): try: content = parse_json_object_from_request(request) - except: + except Exception: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. content = {} @@ -247,7 +247,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): room_id = room_identifier try: remote_room_hosts = request.args["server_name"] - except: + except Exception: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.handlers.room_member_handler @@ -403,10 +403,9 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() users_with_profile = yield self.message_handler.get_joined_members( - user_id, room_id, + requester, room_id, ) defer.returnValue((200, { @@ -588,7 +587,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): try: content = parse_json_object_from_request(request) - except: + except Exception: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. content = {} diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 1f5bc24cc3..77434937ff 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -15,12 +15,13 @@ """This module contains base REST classes for constructing client v1 servlets. """ - -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import logging import re -import logging +from twisted.internet import defer +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX logger = logging.getLogger(__name__) @@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): filter_json['room']['timeline']["limit"] = min( filter_json['room']['timeline']['limit'], filter_timeline_limit) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns a deferred (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. + + Normal usage is: + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + # ... + yield self.auth_handler.check_auth + """ + def wrapped(*args, **kwargs): + res = defer.maybeDeferred(orig, *args, **kwargs) + res.addErrback(_catch_incomplete_interactive_auth) + return res + return wrapped + + +def _catch_incomplete_interactive_auth(f): + """helper for interactive_auth_handler + + Catches InteractiveAuthIncompleteErrors and turns them into 401 responses + + Args: + f (failure.Failure): + """ + f.trap(InteractiveAuthIncompleteError) + return 401, f.value.result diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 4990b22b9f..385a3ad2ec 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -13,21 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
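The interactive_auth_handler decorator added to v2_alpha/_base.py above converts an InteractiveAuthIncompleteError raised inside on_POST into the 401-plus-flows response the client expects. A small sketch of a servlet using it together with validate_user_via_ui_auth (FooRestServlet and its /foo path are hypothetical; the account and device servlets later in this diff follow the same pattern):

from twisted.internet import defer

from synapse.http.servlet import RestServlet, parse_json_object_from_request

from ._base import client_v2_patterns, interactive_auth_handler


class FooRestServlet(RestServlet):
    PATTERNS = client_v2_patterns("/foo$")   # hypothetical endpoint

    def __init__(self, hs):
        super(FooRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.auth_handler = hs.get_auth_handler()

    @interactive_auth_handler
    @defer.inlineCallbacks
    def on_POST(self, request):
        requester = yield self.auth.get_user_by_req(request)
        body = parse_json_object_from_request(request)
        # Raises InteractiveAuthIncompleteError (and hence returns a 401 carrying
        # the available flows) until the client completes user-interactive auth.
        yield self.auth_handler.validate_user_via_ui_auth(
            requester, body, self.hs.get_ip_from_request(request),
        )
        defer.returnValue((200, {}))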
+import logging from twisted.internet import defer +from synapse.api.auth import has_access_token from synapse.api.constants import LoginType -from synapse.api.errors import LoginError, SynapseError, Codes +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, assert_params_in_request + RestServlet, assert_params_in_request, + parse_json_object_from_request, ) from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn - -from ._base import client_v2_patterns - -import logging - +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -99,56 +98,61 @@ class PasswordRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() + self._set_password_handler = hs.get_set_password_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - [LoginType.EMAIL_IDENTITY], - [LoginType.MSISDN], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in + # there are two possibilities here. Either the user does not have an + # access token, and needs to do a password reset; or they have one and + # need to validate their identity. + # + # In the first case, we offer a couple of means of identifying + # themselves (email and msisdn, though it's unclear if msisdn actually + # works). + # + # In the second case, we require a password to confirm their identity. + + if has_access_token(request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - elif LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() - # if using email, we must know about the email they're authing with! 
- threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + params = yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id + user_id = requester.user.to_string() else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = None + result, params, _ = yield self.auth_handler.check_auth( + [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], + body, self.hs.get_ip_from_request(request), + ) + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if 'medium' not in threepid or 'address' not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addreses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + threepid['address'] = threepid['address'].lower() + # if using email, we must know about the email they're authing with! + threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid['medium'], threepid['address'] + ) + if not threepid_user_id: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) if 'new_password' not in params: raise SynapseError(400, "", Codes.MISSING_PARAM) new_password = params['new_password'] - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( user_id, new_password, requester ) @@ -162,42 +166,32 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/deactivate$") def __init__(self, hs): + super(DeactivateAccountRestServlet, self).__init__() self.hs = hs - self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - super(DeactivateAccountRestServlet, self).__init__() + self._deactivate_account_handler = hs.get_deactivate_account_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = yield self.auth.get_user_by_req(request) - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. 
- yield self.store.user_delete_access_tokens(user_id) - yield self.store.user_delete_threepids(user_id) - yield self.store.user_set_password_hash(user_id, None) + # allow ASes to dectivate their own users + if requester.app_service: + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string() + ) + defer.returnValue((200, {})) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), + ) defer.returnValue((200, {})) @@ -373,6 +367,20 @@ class ThreepidDeleteRestServlet(RestServlet): defer.returnValue((200, {})) +class WhoamiRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/whoami$") + + def __init__(self, hs): + super(WhoamiRestServlet, self).__init__() + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + + defer.returnValue((200, {'user_id': requester.user.to_string()})) + + def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) @@ -382,3 +390,4 @@ def register_servlets(hs, http_server): MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) + WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index b57ba95d24..35d58b367a 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -17,15 +17,15 @@ import logging from twisted.internet import defer -from synapse.api import constants, errors +from synapse.api import errors from synapse.http import servlet -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) class DevicesRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) def __init__(self, hs): """ @@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. 
""" - PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False) def __init__(self, hs): super(DeleteDevicesRestServlet, self).__init__() @@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + try: body = servlet.parse_json_object_from_request(request) except errors.SynapseError as e: @@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet): 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM ) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_devices( requester.user.to_string(), body['devices'], @@ -93,8 +92,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): class DeviceRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", - releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) def __init__(self, hs): """ @@ -116,8 +114,11 @@ class DeviceRestServlet(servlet.RestServlet): ) defer.returnValue((200, device)) + @interactive_auth_handler @defer.inlineCallbacks def on_DELETE(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + try: body = servlet.parse_json_object_from_request(request) @@ -129,17 +130,12 @@ class DeviceRestServlet(servlet.RestServlet): else: raise - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_device( - requester.user.to_string(), - device_id, + requester.user.to_string(), device_id, ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index d2b2fd66e6..1b9dc4528d 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -50,7 +50,7 @@ class GetFilterRestServlet(RestServlet): try: filter_id = int(filter_id) - except: + except Exception: raise SynapseError(400, "Invalid filter_id") try: diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py new file mode 100644 index 0000000000..f762dbfa9a --- /dev/null +++ b/synapse/rest/client/v2_alpha/groups.py @@ -0,0 +1,758 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import GroupID + +from ._base import client_v2_patterns + +import logging + +logger = logging.getLogger(__name__) + + +class GroupServlet(RestServlet): + """Get the group profile + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$") + + def __init__(self, hs): + super(GroupServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + group_description = yield self.groups_handler.get_group_profile( + group_id, + requester_user_id, + ) + + defer.returnValue((200, group_description)) + + @defer.inlineCallbacks + def on_POST(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + yield self.groups_handler.update_group_profile( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, {})) + + +class GroupSummaryServlet(RestServlet): + """Get the full group summary + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$") + + def __init__(self, hs): + super(GroupSummaryServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + get_group_summary = yield self.groups_handler.get_group_summary( + group_id, + requester_user_id, + ) + + defer.returnValue((200, get_group_summary)) + + +class GroupSummaryRoomsCatServlet(RestServlet): + """Update/delete a rooms entry in the summary. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/summary" + "(/categories/(?P<category_id>[^/]+))?" 
+ "/rooms/(?P<room_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupSummaryRoomsCatServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, category_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, category_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class GroupCategoryServlet(RestServlet): + """Get/add/update/delete a group category + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" + ) + + def __init__(self, hs): + super(GroupCategoryServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_category( + group_id, requester_user_id, + category_id=category_id, + ) + + defer.returnValue((200, category)) + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_category( + group_id, requester_user_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_category( + group_id, requester_user_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class GroupCategoriesServlet(RestServlet): + """Get all group categories + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/categories/$" + ) + + def __init__(self, hs): + super(GroupCategoriesServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_categories( + group_id, requester_user_id, + ) + + defer.returnValue((200, category)) + + +class GroupRoleServlet(RestServlet): + """Get/add/update/delete a group role + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" + ) + + def __init__(self, hs): + super(GroupRoleServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + 
self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_role( + group_id, requester_user_id, + role_id=role_id, + ) + + defer.returnValue((200, category)) + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_role( + group_id, requester_user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_role( + group_id, requester_user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class GroupRolesServlet(RestServlet): + """Get all group roles + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/roles/$" + ) + + def __init__(self, hs): + super(GroupRolesServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_roles( + group_id, requester_user_id, + ) + + defer.returnValue((200, category)) + + +class GroupSummaryUsersRoleServlet(RestServlet): + """Update/delete a user's entry in the summary. + + Matches both: + - /groups/:group/summary/users/:room_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/summary" + "(/roles/(?P<role_id>[^/]+))?" 
+ "/users/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupSummaryUsersRoleServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, role_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, role_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class GroupRoomServlet(RestServlet): + """Get all rooms in a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$") + + def __init__(self, hs): + super(GroupRoomServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) + + defer.returnValue((200, result)) + + +class GroupUsersServlet(RestServlet): + """Get all users in a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$") + + def __init__(self, hs): + super(GroupUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) + + defer.returnValue((200, result)) + + +class GroupInvitedUsersServlet(RestServlet): + """Get users invited to a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") + + def __init__(self, hs): + super(GroupInvitedUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_invited_users_in_group( + group_id, + requester_user_id, + ) + + defer.returnValue((200, result)) + + +class GroupCreateServlet(RestServlet): + """Create a group + """ + PATTERNS = client_v2_patterns("/create_group$") + + def __init__(self, hs): + super(GroupCreateServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.server_name = hs.hostname + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + # TODO: Create group on remote server + content = 
parse_json_object_from_request(request) + localpart = content.pop("localpart") + group_id = GroupID(localpart, self.server_name).to_string() + + result = yield self.groups_handler.create_group( + group_id, + requester_user_id, + content, + ) + + defer.returnValue((200, result)) + + +class GroupAdminRoomsServlet(RestServlet): + """Add a room to the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminRoomsServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.add_room_to_group( + group_id, requester_user_id, room_id, content, + ) + + defer.returnValue((200, result)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.remove_room_from_group( + group_id, requester_user_id, room_id, + ) + + defer.returnValue((200, result)) + + +class GroupAdminRoomsConfigServlet(RestServlet): + """Update the config of a room in a group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" + "/config/(?P<config_key>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminRoomsConfigServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, room_id, config_key): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content, + ) + + defer.returnValue((200, result)) + + +class GroupAdminUsersInviteServlet(RestServlet): + """Invite a user to the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminUsersInviteServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.store = hs.get_datastore() + self.is_mine_id = hs.is_mine_id + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + config = content.get("config", {}) + result = yield self.groups_handler.invite( + group_id, user_id, requester_user_id, config, + ) + + defer.returnValue((200, result)) + + +class GroupAdminUsersKickServlet(RestServlet): + """Kick a user from the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminUsersKickServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, user_id): + requester = yield 
self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfLeaveServlet(RestServlet): + """Leave a joined group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/leave$" + ) + + def __init__(self, hs): + super(GroupSelfLeaveServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.remove_user_from_group( + group_id, requester_user_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfJoinServlet(RestServlet): + """Attempt to join a group, or knock + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/join$" + ) + + def __init__(self, hs): + super(GroupSelfJoinServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.join_group( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfAcceptInviteServlet(RestServlet): + """Accept a group invite + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/accept_invite$" + ) + + def __init__(self, hs): + super(GroupSelfAcceptInviteServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.accept_invite( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfUpdatePublicityServlet(RestServlet): + """Update whether we publicise a users membership of a group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/update_publicity$" + ) + + def __init__(self, hs): + super(GroupSelfUpdatePublicityServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + publicise = content["publicise"] + yield self.store.update_group_publicity( + group_id, requester_user_id, publicise, + ) + + defer.returnValue((200, {})) + + +class PublicisedGroupsForUserServlet(RestServlet): + """Get the list of groups a user is advertising + """ + PATTERNS = client_v2_patterns( + "/publicised_groups/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(PublicisedGroupsForUserServlet, self).__init__() + self.auth = hs.get_auth() + 
self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + yield self.auth.get_user_by_req(request, allow_guest=True) + + result = yield self.groups_handler.get_publicised_groups_for_user( + user_id + ) + + defer.returnValue((200, result)) + + +class PublicisedGroupsForUsersServlet(RestServlet): + """Get the list of groups a user is advertising + """ + PATTERNS = client_v2_patterns( + "/publicised_groups$" + ) + + def __init__(self, hs): + super(PublicisedGroupsForUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + yield self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + user_ids = content["user_ids"] + + result = yield self.groups_handler.bulk_get_publicised_groups( + user_ids + ) + + defer.returnValue((200, result)) + + +class GroupsForUserServlet(RestServlet): + """Get all groups the logged in user is joined to + """ + PATTERNS = client_v2_patterns( + "/joined_groups$" + ) + + def __init__(self, hs): + super(GroupsForUserServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_joined_groups(requester_user_id) + + defer.returnValue((200, result)) + + +def register_servlets(hs, http_server): + GroupServlet(hs).register(http_server) + GroupSummaryServlet(hs).register(http_server) + GroupInvitedUsersServlet(hs).register(http_server) + GroupUsersServlet(hs).register(http_server) + GroupRoomServlet(hs).register(http_server) + GroupCreateServlet(hs).register(http_server) + GroupAdminRoomsServlet(hs).register(http_server) + GroupAdminRoomsConfigServlet(hs).register(http_server) + GroupAdminUsersInviteServlet(hs).register(http_server) + GroupAdminUsersKickServlet(hs).register(http_server) + GroupSelfLeaveServlet(hs).register(http_server) + GroupSelfJoinServlet(hs).register(http_server) + GroupSelfAcceptInviteServlet(hs).register(http_server) + GroupsForUserServlet(hs).register(http_server) + GroupCategoryServlet(hs).register(http_server) + GroupCategoriesServlet(hs).register(http_server) + GroupSummaryRoomsCatServlet(hs).register(http_server) + GroupRoleServlet(hs).register(http_server) + GroupRolesServlet(hs).register(http_server) + GroupSelfUpdatePublicityServlet(hs).register(http_server) + GroupSummaryUsersRoleServlet(hs).register(http_server) + PublicisedGroupsForUserServlet(hs).register(http_server) + PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 943e87e7fd..3cc87ea63f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERNS = client_v2_patterns( - "/keys/query$", 
- releases=() - ) + PATTERNS = client_v2_patterns("/keys/query$") def __init__(self, hs): """ @@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ - PATTERNS = client_v2_patterns( - "/keys/changes$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/changes$") def __init__(self, hs): """ @@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERNS = client_v2_patterns( - "/keys/claim$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/claim$") def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index fd2a3d69d4..ec170109fe 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): - PATTERNS = client_v2_patterns("/notifications$", releases=()) + PATTERNS = client_v2_patterns("/notifications$") def __init__(self, hs): super(NotificationsServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1421c18152..e9d88a8895 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -17,15 +17,17 @@ from twisted.internet import defer import synapse +import synapse.types from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType +from synapse.types import RoomID, RoomAlias from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string ) from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler import logging import hmac @@ -170,9 +172,11 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.room_member_handler = hs.get_handlers().room_member_handler self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() @@ -221,6 +225,12 @@ class RegisterRestServlet(RestServlet): # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. desired_username = body.get("user", desired_username) + + # XXX we should check that desired_username is valid. Currently + # we give appservices carte blanche for any insanity in mxids, + # because the IRC bridges rely on being able to register stupid + # IDs. + access_token = get_access_token_from_request(request) if isinstance(desired_username, basestring): @@ -230,6 +240,15 @@ class RegisterRestServlet(RestServlet): defer.returnValue((200, result)) # we throw for non 200 responses return + # for either shared secret or regular registration, downcase the + # provided username before attempting to register it. This should mean + # that people who try to register with upper-case in their usernames + # don't get a nasty surprise. 
(Note that we treat username + # case-insenstively in login, so they are free to carry on imagining + # that their username is CrAzYh4cKeR if that keeps them happy) + if desired_username is not None: + desired_username = desired_username.lower() + # == Shared Secret Registration == (e.g. create new user scripts) if 'mac' in body: # FIXME: Should we really be determining if this is shared secret @@ -307,14 +326,10 @@ class RegisterRestServlet(RestServlet): [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], ]) - authed, auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) - if not authed: - defer.returnValue((401, auth_result)) - return - if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", @@ -333,6 +348,9 @@ class RegisterRestServlet(RestServlet): new_password = params.get("password", None) guest_access_token = params.get("guest_access_token", None) + if desired_username is not None: + desired_username = desired_username.lower() + (registered_user_id, _) = yield self.registration_handler.register( localpart=desired_username, password=new_password, @@ -340,6 +358,14 @@ class RegisterRestServlet(RestServlet): generate_token=False, ) + # auto-join the user to any rooms we're supposed to dump them into + fake_requester = synapse.types.create_requester(registered_user_id) + for r in self.hs.config.auto_join_rooms: + try: + yield self._join_user_to_room(fake_requester, r) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) self.auth_handler.set_session_data( @@ -373,6 +399,29 @@ class RegisterRestServlet(RestServlet): return 200, {} @defer.inlineCallbacks + def _join_user_to_room(self, requester, room_identifier): + room_id = None + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = ( + yield self.room_member_handler.lookup_room_alias(room_alias) + ) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) + + yield self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + ) + + @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( username, as_token @@ -383,13 +432,22 @@ class RegisterRestServlet(RestServlet): def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") + if not username: + raise SynapseError( + 400, "username must be specified", errcode=Codes.BAD_JSON, + ) - user = username.encode("utf-8") + # use the username from the original request rather than the + # downcased one in `username` for the mac calculation + user = body["username"].encode("utf-8") # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface got_mac = str(body["mac"]) + # FIXME this is different to the /v1/register endpoint, which + # includes the password and admin flag in the hashed text. Why are + # these different? 
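For reference, a minimal client-side sketch of the mac this endpoint checks, assuming (as the hunk above suggests) a SHA-1 digest over just the username as originally submitted, with the hex digest sent as the "mac" field:

    import hmac
    import hashlib

    def registration_mac(registration_shared_secret, username):
        # Sign the username exactly as it appears in the request body; the
        # server hashes the pre-downcasing value, not the lowercased one.
        return hmac.new(
            key=registration_shared_secret,
            msg=username.encode("utf-8"),
            digestmod=hashlib.sha1,
        ).hexdigest()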
want_mac = hmac.new( key=self.hs.config.registration_shared_secret, msg=user, @@ -523,25 +581,28 @@ class RegisterRestServlet(RestServlet): Args: (str) user_id: full canonical @user:id (object) params: registration parameters, from which we pull - device_id and initial_device_name + device_id, initial_device_name and inhibit_login Returns: defer.Deferred: (object) dictionary for response from /register """ - device_id = yield self._register_device(user_id, params) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + } + if not params.get("inhibit_login", False): + device_id = yield self._register_device(user_id, params) - access_token = ( - yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - initial_display_name=params.get("initial_device_display_name") + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + ) ) - ) - defer.returnValue({ - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - }) + result.update({ + "access_token": access_token, + "device_id": device_id, + }) + defer.returnValue(result) def _register_device(self, user_id, params): """Register a device for a user. diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d607bd2970..90bdb1db15 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns( "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", - releases=[], v2_alpha=False + v2_alpha=False ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 978af9c280..a0a8e4b8e4 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -125,7 +125,7 @@ class SyncRestServlet(RestServlet): filter_object = json.loads(filter_id) set_timeline_upper_limit(filter_object, self.hs.config.filter_timeline_limit) - except: + except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) filter = FilterCollection(filter_object) @@ -200,6 +200,11 @@ class SyncRestServlet(RestServlet): "invite": invited, "leave": archived, }, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, "next_batch": sync_result.next_batch.to_string(), } diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 6fceb23e26..6773b9ba60 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocols") def __init__(self, hs): super(ThirdPartyProtocolsServlet, self).__init__() @@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): super(ThirdPartyProtocolServlet, 
self).__init__() @@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() @@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 6e012da4aa..2d4a43c353 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -65,7 +65,7 @@ class UserDirectorySearchRestServlet(RestServlet): try: search_term = body["search_term"] - except: + except Exception: raise SynapseError(400, "`search_term` is required field") results = yield self.user_directory_handler.search_users( diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index e984ea47db..2ecb15deee 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet): "r0.0.1", "r0.1.0", "r0.2.0", + "r0.3.0", ] }) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9fe2013657..cc2842aa72 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -213,7 +213,7 @@ class RemoteKey(Resource): ) except KeyLookupError as e: logger.info("Failed to fetch key: %s", e) - except: + except Exception: logger.exception("Failed to get key for %r", server_name) yield self.query_keys( request, query, query_remote_on_cache_miss=False diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index b9600f2167..95fa95fce3 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,6 +17,7 @@ from synapse.http.server import respond_with_json, finish_request from synapse.api.errors import ( cs_error, Codes, SynapseError ) +from synapse.util import logcontext from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -44,7 +45,7 @@ def parse_media_id(request): except UnicodeDecodeError: pass return server_name, media_id, file_name - except: + except Exception: raise SynapseError( 404, "Invalid media id token %r" % (request.postpath,), @@ -103,7 +104,9 @@ def respond_with_file(request, media_type, file_path, ) with open(file_path, "rb") as f: - yield FileSender().beginFileTransfer(f, request) + yield logcontext.make_deferred_yieldable( + FileSender().beginFileTransfer(f, request) + ) finish_request(request) else: diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d92b7ff337..d5164e47e0 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -14,78 +14,200 @@ # limitations under the License. 
import os +import re +import functools + +NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") + + +def _wrap_in_base_path(func): + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + @functools.wraps(func) + def _wrapped(self, *args, **kwargs): + path = func(self, *args, **kwargs) + return os.path.join(self.base_path, path) + + return _wrapped class MediaFilePaths(object): + """Describes where files are stored on disk. - def __init__(self, base_path): - self.base_path = base_path + Most of the functions have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ - def default_thumbnail(self, default_top_level, default_sub_type, width, - height, content_type, method): + def __init__(self, primary_base_path): + self.base_path = primary_base_path + + def default_thumbnail_rel(self, default_top_level, default_sub_type, width, + height, content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "default_thumbnails", default_top_level, + "default_thumbnails", default_top_level, default_sub_type, file_name ) - def local_media_filepath(self, media_id): + default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) + + def local_media_filepath_rel(self, media_id): return os.path.join( - self.base_path, "local_content", + "local_content", media_id[0:2], media_id[2:4], media_id[4:] ) - def local_media_thumbnail(self, media_id, width, height, content_type, - method): + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + def local_media_thumbnail_rel(self, media_id, width, height, content_type, + method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "local_thumbnails", + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) - def remote_media_filepath(self, server_name, file_id): + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - self.base_path, "remote_content", server_name, + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) - def remote_media_thumbnail(self, server_name, file_id, width, height, - content_type, method): + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + def remote_media_thumbnail_rel(self, server_name, file_id, width, height, + content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - self.base_path, "remote_thumbnail", server_name, + "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], file_name ) + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( self.base_path, "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], ) - def url_cache_filepath(self, media_id): - return os.path.join( - self.base_path, "url_cache", - media_id[0:2], media_id[2:4], media_id[4:] - ) + def url_cache_filepath_rel(self, media_id): + if NEW_FORMAT_ID_RE.match(media_id): + # 
Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + return os.path.join( + "url_cache", + media_id[:10], media_id[11:] + ) + else: + return os.path.join( + "url_cache", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + + def url_cache_filepath_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id file" + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache", + media_id[:10], + ), + ] + else: + return [ + os.path.join( + self.base_path, "url_cache", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache", + media_id[0:2], + ), + ] + + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, + method): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf - def url_cache_thumbnail(self, media_id, width, height, content_type, - method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) - return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name - ) + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + media_id[:10], media_id[11:], + file_name + ) + else: + return os.path.join( + "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + file_name + ) + + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + + def url_cache_thumbnail_directory(self, media_id): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ) + else: + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + def url_cache_thumbnail_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id thumbnails" + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], + ), + ] + else: + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], + ), + ] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0ea1248ce6..eed9056a2f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \ from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.retryutils import NotRetryingDestination import os @@ -59,7 +59,14 @@ class MediaRepository(object): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - 
self.filepaths = MediaFilePaths(hs.config.media_store_path) + + self.primary_base_path = hs.config.media_store_path + self.filepaths = MediaFilePaths(self.primary_base_path) + + self.backup_base_path = hs.config.backup_media_store_path + + self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store + self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements @@ -87,18 +94,86 @@ class MediaRepository(object): if not os.path.exists(dirname): os.makedirs(dirname) + @staticmethod + def _write_file_synchronously(source, fname): + """Write `source` to the path `fname` synchronously. Should be called + from a thread. + + Args: + source: A file like object to be written + fname (str): Path to write to + """ + MediaRepository._makedirs(fname) + source.seek(0) # Ensure we read from the start of the file + with open(fname, "wb") as f: + shutil.copyfileobj(source, f) + + @defer.inlineCallbacks + def write_to_file_and_backup(self, source, path): + """Write `source` to the on disk media store, and also the backup store + if configured. + + Args: + source: A file like object that should be written + path (str): Relative path to write file to + + Returns: + Deferred[str]: the file path written to in the primary media store + """ + fname = os.path.join(self.primary_base_path, path) + + # Write to the main repository + yield make_deferred_yieldable(threads.deferToThread( + self._write_file_synchronously, source, fname, + )) + + # Write to backup repository + yield self.copy_to_backup(path) + + defer.returnValue(fname) + + @defer.inlineCallbacks + def copy_to_backup(self, path): + """Copy a file from the primary to backup media store, if configured. + + Args: + path(str): Relative path to write file to + """ + if self.backup_base_path: + primary_fname = os.path.join(self.primary_base_path, path) + backup_fname = os.path.join(self.backup_base_path, path) + + # We can either wait for successful writing to the backup repository + # or write in the background and immediately return + if self.synchronous_backup_media_store: + yield make_deferred_yieldable(threads.deferToThread( + shutil.copyfile, primary_fname, backup_fname, + )) + else: + preserve_fn(threads.deferToThread)( + shutil.copyfile, primary_fname, backup_fname, + ) + @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, auth_user): + """Store uploaded content for a local user and return the mxc URL + + Args: + media_type(str): The content type of the file + upload_name(str): The name of the file + content: A file like object that is the content to store + content_length(int): The length of the content + auth_user(str): The user_id of the uploader + + Returns: + Deferred[str]: The mxc url of the stored content + """ media_id = random_string(24) - fname = self.filepaths.local_media_filepath(media_id) - self._makedirs(fname) - - # This shouldn't block for very long because the content will have - # already been uploaded at this point. 
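To make the split between the new *_rel helpers and the two stores concrete, a small sketch (the paths and media id here are made up for illustration):

    paths = MediaFilePaths("/data/media_store")

    rel = paths.local_media_filepath_rel("abcdefghijklmnop")
    # rel == "local_content/ab/cd/efghijklmnop" -- relative, no base path

    # The un-suffixed variant is the same helper wrapped by _wrap_in_base_path:
    #   paths.local_media_filepath("abcdefghijklmnop")
    #       == os.path.join("/data/media_store", rel)
    #
    # write_to_file_and_backup(source, rel) writes the file under
    # primary_base_path and, if backup_media_store_path is configured, copies
    # it to the same relative path under the backup store, either waiting for
    # the copy or kicking it off in the background depending on
    # synchronous_backup_media_store.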
- with open(fname, "wb") as f: - f.write(content) + fname = yield self.write_to_file_and_backup( + content, self.filepaths.local_media_filepath_rel(media_id) + ) logger.info("Stored local media in file %r", fname) @@ -115,7 +190,7 @@ class MediaRepository(object): "media_length": content_length, } - yield self._generate_local_thumbnails(media_id, media_info) + yield self._generate_thumbnails(None, media_id, media_info) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @@ -148,9 +223,10 @@ class MediaRepository(object): def _download_remote_file(self, server_name, media_id): file_id = random_string(24) - fname = self.filepaths.remote_media_filepath( + fpath = self.filepaths.remote_media_filepath_rel( server_name, file_id ) + fname = os.path.join(self.primary_base_path, fpath) self._makedirs(fname) try: @@ -192,6 +268,8 @@ class MediaRepository(object): server_name, media_id) raise SynapseError(502, "Failed to fetch remote media") + yield self.copy_to_backup(fpath) + media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() @@ -232,7 +310,7 @@ class MediaRepository(object): media_length=length, filesystem_id=file_id, ) - except: + except Exception: os.remove(fname) raise @@ -244,7 +322,7 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_remote_thumbnails( + yield self._generate_thumbnails( server_name, media_id, media_info ) @@ -253,9 +331,8 @@ class MediaRepository(object): def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, input_path, t_path, t_width, t_height, + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): - thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width m_height = thumbnailer.height @@ -267,72 +344,105 @@ class MediaRepository(object): return if t_method == "crop": - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.scale(t_width, t_height, t_type) else: - t_len = None + t_byte_source = None - return t_len + return t_byte_source @defer.inlineCallbacks def generate_local_exact_thumbnail(self, media_id, t_width, t_height, t_method, t_type): input_path = self.filepaths.local_media_filepath(media_id) - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + output_path = yield self.write_to_file_and_backup( + t_byte_source, + self.filepaths.local_media_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method + ) + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks def 
generate_remote_exact_thumbnail(self, server_name, file_id, media_id, t_width, t_height, t_method, t_type): input_path = self.filepaths.remote_media_filepath(server_name, file_id) - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + output_path = yield self.write_to_file_and_backup( + t_byte_source, + self.filepaths.remote_media_thumbnail_rel( + server_name, file_id, t_width, t_height, t_type, t_method + ) + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_local_thumbnails(self, media_id, media_info, url_cache=False): + def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False): + """Generate and store thumbnails for an image. + + Args: + server_name(str|None): The server name if remote media, else None if local + media_id(str) + media_info(dict) + url_cache(bool): If we are thumbnailing images downloaded for the URL cache, + used exclusively by the url previewer + + Returns: + Deferred[dict]: Dict with "width" and "height" keys of original image + """ media_type = media_info["media_type"] + file_id = media_info.get("filesystem_id") requirements = self._get_thumbnail_requirements(media_type) if not requirements: return - if url_cache: + if server_name: + input_path = self.filepaths.remote_media_filepath(server_name, file_id) + elif url_cache: input_path = self.filepaths.url_cache_filepath(media_id) else: input_path = self.filepaths.local_media_filepath(media_id) @@ -348,135 +458,72 @@ class MediaRepository(object): ) return - local_thumbnails = [] - - def generate_thumbnails(): - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. 
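A worked example of that deduplication (the sizes are illustrative, not the server's actual thumbnail requirements):

    # For a 640x480 "image/jpeg" source with requirements
    #     (32, 32, "crop", "image/jpeg"),
    #     (320, 240, "crop", "image/jpeg"),
    #     (320, 240, "scale", "image/jpeg"),
    # the loop below ends up with two entries rather than three:
    #     {(32, 32, "image/jpeg"): "crop", (320, 240, "image/jpeg"): "scale"}
    # because the 320x240 crop and the 320x240 scale share a key, and crops
    # only setdefault() theirs, so "scale" wins and a single file is written
    # for that size.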
+ thumbnails = {} + for r_width, r_height, r_method, r_type in requirements: + if r_method == "crop": + thumbnails.setdefault((r_width, r_height, r_type), r_method) + elif r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[(t_width, t_height, r_type)] = r_method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): + # Work out the correct file name for thumbnail + if server_name: + file_path = self.filepaths.remote_media_thumbnail_rel( + server_name, file_id, t_width, t_height, t_type, t_method + ) + elif url_cache: + file_path = self.filepaths.url_cache_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method + ) + else: + file_path = self.filepaths.local_media_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method + ) - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len + # Generate the thumbnail + if t_method == "crop": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.crop, + t_width, t_height, t_type, )) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len + elif t_method == "scale": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.scale, + t_width, t_height, t_type, )) + else: + logger.error("Unrecognized method: %r", t_method) + continue - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for l in local_thumbnails: - yield self.store.store_local_thumbnail(*l) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) - - @defer.inlineCallbacks - def _generate_remote_thumbnails(self, server_name, media_id, media_info): - media_type = media_info["media_type"] - file_id = media_info["filesystem_id"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return + if not t_byte_source: + continue - remote_thumbnails = [] + try: + # Write to disk + output_path = yield self.write_to_file_and_backup( + t_byte_source, file_path, + ) + finally: + t_byte_source.close() - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height + t_len = os.path.getsize(output_path) - def generate_thumbnails(): - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, 
t_type in scales: - t_method = "scale" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ + # Write to database + if server_name: + yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len - ]) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for r in remote_thumbnails: - yield self.store.store_remote_media_thumbnail(*r) + else: + yield self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) defer.returnValue({ "width": m_width, @@ -497,6 +544,8 @@ class MediaRepository(object): logger.info("Deleting: %r", key) + # TODO: Should we delete from the backup store + with (yield self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b81a336c5d..40d2e664eb 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -20,11 +20,13 @@ from twisted.web.resource import Resource from synapse.api.errors import ( SynapseError, Codes, ) +from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.stringutils import random_string from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient from synapse.http.server import ( - request_handler, respond_with_json_bytes + request_handler, respond_with_json_bytes, + respond_with_json, ) from synapse.util.async import ObservableDeferred from synapse.util.stringutils import is_ascii @@ -36,6 +38,9 @@ import cgi import ujson as json import urlparse import itertools +import datetime +import errno +import shutil import logging logger = logging.getLogger(__name__) @@ -56,19 +61,26 @@ class PreviewUrlResource(Resource): self.store = hs.get_datastore() self.client = SpiderHttpClient(hs) self.media_repo = media_repo + self.primary_base_path = media_repo.primary_base_path self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist - # simple memory cache mapping urls to OG metadata - self.cache = ExpiringCache( + # memory cache mapping urls to an ObservableDeferred returning + # JSON-encoded OG metadata + self._cache = ExpiringCache( cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=60 * 60 * 1000, ) - self.cache.start() + self._cache.start() - self.downloads = {} + self._cleaner_loop = self.clock.looping_call( + self._expire_url_cache_data, 10 * 1000 + ) + + def render_OPTIONS(self, request): + return respond_with_json(request, 200, {}, send_cors=True) def render_GET(self, request): self._async_render_GET(request) @@ -86,6 +98,7 @@ 
class PreviewUrlResource(Resource): else: ts = self.clock.time_msec() + # XXX: we could move this into _do_preview if we wanted. url_tuple = urlparse.urlsplit(url) for entry in self.url_preview_url_blacklist: match = True @@ -118,53 +131,59 @@ class PreviewUrlResource(Resource): Codes.UNKNOWN ) - # first check the memory cache - good to handle all the clients on this - # HS thundering away to preview the same URL at the same time. - og = self.cache.get(url) - if og: - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) - return + # the in-memory cache: + # * ensures that only one request is active at a time + # * takes load off the DB for the thundering herds + # * also caches any failures (unlike the DB) so we don't keep + # requesting the same endpoint + + observable = self._cache.get(url) + + if not observable: + download = preserve_fn(self._do_preview)( + url, requester.user, ts, + ) + observable = ObservableDeferred( + download, + consumeErrors=True + ) + self._cache[url] = observable + else: + logger.info("Returning cached response") + + og = yield make_deferred_yieldable(observable.observe()) + respond_with_json_bytes(request, 200, og, send_cors=True) - # then check the URL cache in the DB (which will also provide us with + @defer.inlineCallbacks + def _do_preview(self, url, user, ts): + """Check the db, and download the URL and build a preview + + Args: + url (str): + user (str): + ts (int): + + Returns: + Deferred[str]: json-encoded og data + """ + # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) cache_result = yield self.store.get_url_cache(url, ts) if ( cache_result and - cache_result["download_ts"] + cache_result["expires"] > ts and + cache_result["expires_ts"] > ts and cache_result["response_code"] / 100 == 2 ): - respond_with_json_bytes( - request, 200, cache_result["og"].encode('utf-8'), - send_cors=True - ) + defer.returnValue(cache_result["og"]) return - # Ensure only one download for a given URL is active at a time - download = self.downloads.get(url) - if download is None: - download = self._download_url(url, requester.user) - download = ObservableDeferred( - download, - consumeErrors=True - ) - self.downloads[url] = download - - @download.addBoth - def callback(media_info): - del self.downloads[url] - return media_info - media_info = yield download.observe() - - # FIXME: we should probably update our cache now anyway, so that - # even if the OG calculation raises, we don't keep hammering on the - # remote server. For now, leave it uncached to aid debugging OG - # calculation problems + media_info = yield self._download_url(url, user) logger.debug("got media_info of '%s'" % media_info) if _is_media(media_info['media_type']): - dims = yield self.media_repo._generate_local_thumbnails( - media_info['filesystem_id'], media_info, url_cache=True, + dims = yield self.media_repo._generate_thumbnails( + None, media_info['filesystem_id'], media_info, url_cache=True, ) og = { @@ -204,13 +223,13 @@ class PreviewUrlResource(Resource): # just rely on the caching on the master request to speed things up. 
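For context, the og map being built here is just the Open Graph data as a flat dict; json.dumps(og) is what gets stored in the cache and handed back to the client. An illustrative (made-up) example:

    og = {
        "og:title": "Some page title",
        "og:description": "The first paragraph of text pulled from the page ...",
        "og:image": "mxc://example.com/abcdefghijklmnop",
        "og:image:width": 640,
        "og:image:height": 480,
    }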
if 'og:image' in og and og['og:image']: image_info = yield self._download_url( - _rebase_url(og['og:image'], media_info['uri']), requester.user + _rebase_url(og['og:image'], media_info['uri']), user ) if _is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images - dims = yield self.media_repo._generate_local_thumbnails( - image_info['filesystem_id'], image_info, url_cache=True, + dims = yield self.media_repo._generate_thumbnails( + None, image_info['filesystem_id'], image_info, url_cache=True, ) if dims: og["og:image:width"] = dims['width'] @@ -231,21 +250,20 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - # store OG in ephemeral in-memory cache - self.cache[url] = og + jsonog = json.dumps(og) # store OG in history-aware DB cache yield self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], - media_info["expires"], - json.dumps(og), + media_info["expires"] + media_info["created_ts"], + jsonog, media_info["filesystem_id"], media_info["created_ts"], ) - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + defer.returnValue(jsonog) @defer.inlineCallbacks def _download_url(self, url, user): @@ -253,10 +271,10 @@ class PreviewUrlResource(Resource): # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? - # XXX: horrible duplication with base_resource's _download_remote_file() - file_id = random_string(24) + file_id = datetime.date.today().isoformat() + '_' + random_string(16) - fname = self.filepaths.url_cache_filepath(file_id) + fpath = self.filepaths.url_cache_filepath_rel(file_id) + fname = os.path.join(self.primary_base_path, fpath) self.media_repo._makedirs(fname) try: @@ -267,6 +285,8 @@ class PreviewUrlResource(Resource): ) # FIXME: pass through 404s and other error messages nicely + yield self.media_repo.copy_to_backup(fpath) + media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() @@ -328,6 +348,95 @@ class PreviewUrlResource(Resource): "etag": headers["ETag"][0] if "ETag" in headers else None, }) + @defer.inlineCallbacks + def _expire_url_cache_data(self): + """Clean up expired url cache content, media and thumbnails. + """ + # TODO: Delete from backup media store + + now = self.clock.time_msec() + + logger.info("Running url preview cache expiry") + + if not (yield self.store.has_completed_background_updates()): + logger.info("Still running DB updates; skipping expiry") + return + + # First we delete expired url cache entries + media_ids = yield self.store.get_expired_url_cache(now) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + yield self.store.delete_url_cache(removed_media) + + if removed_media: + logger.info("Deleted %d entries from url cache", len(removed_media)) + + # Now we delete old images associated with the url cache. + # These may be cached for a bit on the client (i.e., they + # may have a room open with a preview url thing open). + # So we wait a couple of days before deleting, just in case. 
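(Concretely, the cut-off below is now minus 2 * 24 * 60 * 60 * 1000 ms = 172,800,000 ms, so url-cache media older than about 48 hours becomes eligible for deletion on a later run of the cleaner loop.)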
+ expire_before = now - 2 * 24 * 60 * 60 * 1000 + media_ids = yield self.store.get_url_cache_media_before(expire_before) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id) + try: + shutil.rmtree(thumbnail_dir) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + yield self.store.delete_url_cache_media(removed_media) + + logger.info("Deleted %d media from url cache", len(removed_media)) + def decode_and_calc_og(body, media_uri, request_encoding=None): from lxml import etree @@ -425,7 +534,14 @@ def _calc_og(tree, media_uri): from lxml import etree TAGS_TO_REMOVE = ( - "header", "nav", "aside", "footer", "script", "style", etree.Comment + "header", + "nav", + "aside", + "footer", + "script", + "noscript", + "style", + etree.Comment ) # Split all the text nodes into paragraphs (by splitting on new diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3868d4f65f..e1ee535b9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -50,12 +50,16 @@ class Thumbnailer(object): else: return ((max_height * self.width) // self.height, max_height) - def scale(self, output_path, width, height, output_type): - """Rescales the image to the given dimensions""" + def scale(self, width, height, output_type): + """Rescales the image to the given dimensions. + + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk + """ scaled = self.image.resize((width, height), Image.ANTIALIAS) - return self.save_image(scaled, output_type, output_path) + return self._encode_image(scaled, output_type) - def crop(self, output_path, width, height, output_type): + def crop(self, width, height, output_type): """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -65,6 +69,9 @@ class Thumbnailer(object): Args: max_width: The largest possible width. max_height: The larget possible height. 
+ + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width @@ -82,13 +89,9 @@ class Thumbnailer(object): crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - return self.save_image(cropped, output_type, output_path) + return self._encode_image(cropped, output_type) - def save_image(self, output_image, output_type, output_path): + def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) - output_bytes = output_bytes_io.getvalue() - with open(output_path, "wb") as output_file: - output_file.write(output_bytes) - logger.info("Stored thumbnail in file %r", output_path) - return len(output_bytes) + return output_bytes_io diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 4ab33f73bf..f6f498cdc5 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -93,7 +93,7 @@ class UploadResource(Resource): # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content.read(), + media_type, upload_name, request.content, content_length, requester.user ) diff --git a/synapse/server.py b/synapse/server.py index 4d44af745e..99693071b6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -39,24 +39,33 @@ from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGeneartor +from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.room_list import RoomListHandler +from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.read_marker import ReadMarkerHandler -from synapse.handlers.user_directory import UserDirectoyHandler +from synapse.handlers.user_directory import UserDirectoryHandler +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.profile import ProfileHandler +from synapse.groups.groups_server import GroupsServerHandler +from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool -from synapse.rest.media.v1.media_repository import MediaRepository +from synapse.rest.media.v1.media_repository import ( + MediaRepository, + MediaRepositoryResource, +) from synapse.state import StateHandler from synapse.storage import DataStore from synapse.streams.events import 
EventSources @@ -86,17 +95,12 @@ class HomeServer(object): """ DEPENDENCIES = [ - 'config', - 'clock', 'http_client', 'db_pool', - 'persistence_service', 'replication_layer', - 'datastore', 'handlers', 'v1auth', 'auth', - 'rest_servlet_factory', 'state_handler', 'presence_handler', 'sync_handler', @@ -112,19 +116,11 @@ class HomeServer(object): 'application_service_scheduler', 'application_service_handler', 'device_message_handler', + 'profile_handler', + 'deactivate_account_handler', + 'set_password_handler', 'notifier', - 'distributor', - 'client_resource', - 'resource_for_federation', - 'resource_for_static_content', - 'resource_for_web_client', - 'resource_for_content_repo', - 'resource_for_server_key', - 'resource_for_server_key_v2', - 'resource_for_media_repository', - 'resource_for_metrics', 'event_sources', - 'ratelimiter', 'keyring', 'pusherpool', 'event_builder_factory', @@ -132,6 +128,7 @@ class HomeServer(object): 'http_client_context_factory', 'simple_http_client', 'media_repository', + 'media_repository_resource', 'federation_transport_client', 'federation_sender', 'receipts_handler', @@ -140,6 +137,10 @@ class HomeServer(object): 'read_marker_handler', 'action_generator', 'user_directory_handler', + 'groups_local_handler', + 'groups_server_handler', + 'groups_attestation_signing', + 'groups_attestation_renewer', 'spam_checker', ] @@ -174,6 +175,21 @@ class HomeServer(object): def is_mine_id(self, string): return string.split(":", 1)[1] == self.hostname + def get_clock(self): + return self.clock + + def get_datastore(self): + return self.datastore + + def get_config(self): + return self.config + + def get_distributor(self): + return self.distributor + + def get_ratelimiter(self): + return self.ratelimiter + def build_replication_layer(self): return initialize_http_replication(self) @@ -253,6 +269,15 @@ class HomeServer(object): def build_initial_sync_handler(self): return InitialSyncHandler(self) + def build_profile_handler(self): + return ProfileHandler(self) + + def build_deactivate_account_handler(self): + return DeactivateAccountHandler(self) + + def build_set_password_handler(self): + return SetPasswordHandler(self) + def build_event_sources(self): return EventSources(self) @@ -282,6 +307,11 @@ class HomeServer(object): **self.db_config.get("args", {}) ) + def build_media_repository_resource(self): + # build the media repo resource. 
This indirects through the HomeServer + # to ensure that we only have a single instance of + return MediaRepositoryResource(self) + def build_media_repository(self): return MediaRepository(self) @@ -309,7 +339,19 @@ class HomeServer(object): return ActionGenerator(self) def build_user_directory_handler(self): - return UserDirectoyHandler(self) + return UserDirectoryHandler(self) + + def build_groups_local_handler(self): + return GroupsLocalHandler(self) + + def build_groups_server_handler(self): + return GroupsServerHandler(self) + + def build_groups_attestation_signing(self): + return GroupAttestationSigning(self) + + def build_groups_attestation_renewer(self): + return GroupAttestionRenewer(self) def build_spam_checker(self): return SpamChecker(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9570df5537..41416ef252 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,10 +1,16 @@ import synapse.api.auth +import synapse.federation.transaction_queue +import synapse.federation.transport.client import synapse.handlers import synapse.handlers.auth +import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys -import synapse.storage +import synapse.handlers.set_password +import synapse.rest.media.v1.media_repository import synapse.state +import synapse.storage + class HomeServer(object): def get_auth(self) -> synapse.api.auth.Auth: @@ -27,3 +33,21 @@ class HomeServer(object): def get_state_handler(self) -> synapse.state.StateHandler: pass + + def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: + pass + + def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler: + pass + + def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue: + pass + + def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: + pass + + def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: + pass + + def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository: + pass diff --git a/synapse/state.py b/synapse/state.py index 390799fbd5..9e624b4937 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -288,6 +288,9 @@ class StateHandler(object): """ logger.debug("resolve_state_groups event_ids %s", event_ids) + # map from state group id to the state in that state group (where + # 'state' is a map from state key to event id) + # dict[int, dict[(str, str), str]] state_groups_ids = yield self.store.get_state_groups_ids( room_id, event_ids ) @@ -320,11 +323,15 @@ class StateHandler(object): "Resolving state for %s with %d groups", room_id, len(state_groups_ids) ) + # build a map from state key to the event_ids which set that state. + # dict[(str, str), set[str]) state = {} for st in state_groups_ids.values(): for key, e_id in st.items(): state.setdefault(key, set()).add(e_id) + # build a map from state key to the event_ids which set that state, + # including only those where there are state keys in conflict. conflicted_state = { k: list(v) for k, v in state.items() @@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state, logger.info("Asking for %d conflicted events", len(needed_events)) + # dict[str, FrozenEvent]: a map from state event id to event. Only includes + # the state events which are in conflict. 
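The comments added to resolve_state_groups can be made concrete with a tiny worked example (toy event ids, invented variable names): starting from per-group maps of (event_type, state_key) to event_id, we collect, for each state key, every event id that sets it; keys with a single candidate are unconflicted, the rest are in conflict and must go through state resolution.

state_groups_ids = {
    1: {("m.room.name", ""): "$A", ("m.room.topic", ""): "$T"},
    2: {("m.room.name", ""): "$B", ("m.room.topic", ""): "$T"},
}

# map from state key to the set of event_ids which set that state
state = {}
for group_state in state_groups_ids.values():
    for key, event_id in group_state.items():
        state.setdefault(key, set()).add(event_id)

conflicted = {k: v for k, v in state.items() if len(v) > 1}
unconflicted = {k: next(iter(v)) for k, v in state.items() if len(v) == 1}

assert conflicted == {("m.room.name", ""): {"$A", "$B"}}
assert unconflicted == {("m.room.topic", ""): "$T"}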
state_map = yield state_map_factory(needed_events) + # get the ids of the auth events which allow us to authenticate the + # conflicted state, picking only from the unconflicting state. + # + # dict[(str, str), str]: a map from state key to event id auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -547,7 +560,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ resolved_state = _resolve_state_events( conflicted_state, auth_events ) - except: + except Exception: logger.exception("Failed to resolve state") raise diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index b92472df33..d01d46338a 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -37,7 +37,7 @@ from .media_repository import MediaRepositoryStore from .rejections import RejectionsStore from .event_push_actions import EventPushActionsStore from .deviceinbox import DeviceInboxStore - +from .group_server import GroupServerStore from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore @@ -88,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore, DeviceStore, DeviceInboxStore, UserDirectoryStore, + GroupServerStore, ): def __init__(self, db_conn, hs): @@ -135,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id", + ) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( @@ -235,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=curr_state_delta_prefill, ) + _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + db_conn, "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", min_group_updates_id, + prefilled_cache=_group_updates_prefill, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", @@ -252,7 +268,7 @@ class DataStore(RoomMemberStore, RoomStore, self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - super(DataStore, self).__init__(hs) + super(DataStore, self).__init__(db_conn, hs) def take_presence_startup_info(self): active_on_startup = self._presence_on_startup diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6f54036d67..b971f0cb18 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,8 +16,6 @@ import logging from synapse.api.errors import StoreError from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.caches import CACHE_SIZE_FACTOR -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -103,7 +101,7 @@ class LoggingTransaction(object): "[SQL values] {%s} %r", self.name, args[0] ) - except: + except Exception: # Don't let logging failures stop SQL from working pass @@ -162,7 +160,7 @@ class PerformanceCounters(object): class SQLBaseStore(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self._db_pool = 
hs.get_db_pool() @@ -180,10 +178,6 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, max_entries=hs.config.event_cache_size) - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR - ) - self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -475,23 +469,53 @@ class SQLBaseStore(object): txn.executemany(sql, vals) + @defer.inlineCallbacks def _simple_upsert(self, table, keyvalues, values, insertion_values={}, desc="_simple_upsert", lock=True): """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + Args: table (str): The table to upsert into keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values - insertion_values (dict): key/values to use when inserting + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. Returns: Deferred(bool): True if a new entry was created, False if an existing one was updated. """ - return self.runInteraction( - desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock - ) + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self._simple_upsert_txn, table, keyvalues, values, insertion_values, + lock=lock + ) + defer.returnValue(result) + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warn( + "IntegrityError when upserting into %s; retrying: %s", + table, e + ) def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, lock=True): @@ -499,7 +523,7 @@ class SQLBaseStore(object): if lock: self.database_engine.lock_table(txn, table) - # Try to update + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), @@ -508,28 +532,29 @@ class SQLBaseStore(object): sqlargs = values.values() + keyvalues.values() txn.execute(sql, sqlargs) - if txn.rowcount == 0: - # We didn't update and rows so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) + if txn.rowcount > 0: + # successfully updated at least one row. + return False - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) - ) - txn.execute(sql, allvalues.values()) + # We didn't update any rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) - return True - else: - return False + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" 
for _ in allvalues) + ) + txn.execute(sql, allvalues.values()) + # successfully inserted + return True def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. + return a single row, returning multiple columns from it. Args: table : string giving the table name @@ -582,20 +607,18 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) - else: - where = "" - sql = ( - "SELECT %(retcol)s FROM %(table)s %(where)s" + "SELECT %(retcol)s FROM %(table)s" ) % { "retcol": retcol, "table": table, - "where": where, } - txn.execute(sql, keyvalues.values()) + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + txn.execute(sql, keyvalues.values()) + else: + txn.execute(sql) return [r[0] for r in txn] @@ -606,7 +629,7 @@ class SQLBaseStore(object): Args: table (str): table name - keyvalues (dict): column names and values to select the rows with + keyvalues (dict|None): column names and values to select the rows with retcol (str): column whos value we wish to retrieve. Returns: @@ -743,6 +766,33 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) + def _simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, + self._simple_update_txn, + table, keyvalues, updatevalues, + ) + + @staticmethod + def _simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute( + update_sql, + updatevalues.values() + keyvalues.values() + ) + + return txn.rowcount + def _simple_update_one(self, table, keyvalues, updatevalues, desc="_simple_update_one"): """Executes an UPDATE query on the named table, setting new values for @@ -768,27 +818,13 @@ class SQLBaseStore(object): table, keyvalues, updatevalues, ) - @staticmethod - def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" 
% (k,) for k in updatevalues), - where, - ) - - txn.execute( - update_sql, - updatevalues.values() + keyvalues.values() - ) + @classmethod + def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - if txn.rowcount == 0: + if rowcount == 0: raise StoreError(404, "No row found") - if txn.rowcount > 1: + if rowcount > 1: raise StoreError(500, "More than one row matched") @staticmethod diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index ff14e54c11..56a0bde549 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -63,7 +63,7 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2) + @cachedInlineCallbacks(num_args=2, max_entries=5000) def get_global_account_data_by_type_for_user(self, data_type, user_id): """ Returns: @@ -222,9 +222,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so _simple_upsert will + # retry if there is a conflict. + yield self._simple_upsert( + desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, @@ -234,19 +237,20 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } - ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, - user_id, next_id, + }, + lock=False, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - self._update_max_stream_id(txn, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_room_account_data", add_account_data_txn, next_id - ) + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @@ -263,9 +267,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so _simple_upsert will retry if + # there is a conflict. + yield self._simple_upsert( + desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, @@ -274,40 +281,46 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } + }, + lock=False, ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. 
That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed( user_id, next_id, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_global_account_data_by_type_for_user.invalidate, + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id,) ) - self._update_max_stream_id(txn, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_user_account_data", add_account_data_txn, next_id - ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) - def _update_max_stream_id(self, txn, next_id): + def _update_max_stream_id(self, next_id): """Update the max stream_id Args: - txn: The database cursor next_id(int): The the revision to advance to. """ - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + return self.runInteraction( + "update_account_data_max_stream_id", + _update, ) - txn.execute(update_max_id_sql, (next_id, next_id)) @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index c63935cb07..d8c84b7141 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -48,8 +48,8 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceStore(SQLBaseStore): - def __init__(self, hs): - super(ApplicationServiceStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(ApplicationServiceStore, self).__init__(db_conn, hs) self.hostname = hs.hostname self.services_cache = load_appservices( hs.hostname, @@ -173,8 +173,8 @@ class ApplicationServiceStore(SQLBaseStore): class ApplicationServiceTransactionStore(SQLBaseStore): - def __init__(self, hs): - super(ApplicationServiceTransactionStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs) @defer.inlineCallbacks def get_appservices_by_state(self, state): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 7157fb1dfb..11a1b942f1 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -80,11 +80,12 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, hs): - super(BackgroundUpdateStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(BackgroundUpdateStore, self).__init__(db_conn, hs) self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} + self._all_done = False @defer.inlineCallbacks def start_doing_background_updates(self): @@ -98,7 +99,7 @@ class BackgroundUpdateStore(SQLBaseStore): result = yield self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) - except: + except Exception: logger.exception("Error doing update") else: if result is None: @@ -106,9 +107,41 @@ class 
BackgroundUpdateStore(SQLBaseStore): "No more background updates to do." " Unscheduling background update task." ) + self._all_done = True defer.returnValue(None) @defer.inlineCallbacks + def has_completed_background_updates(self): + """Check if all the background updates have completed + + Returns: + Deferred[bool]: True if all background updates have completed + """ + # if we've previously determined that there is nothing left to do, that + # is easy + if self._all_done: + defer.returnValue(True) + + # obviously, if we have things in our queue, we're not done. + if self._background_update_queue: + defer.returnValue(False) + + # otherwise, check if there are updates to be run. This is important, + # as we may be running on a worker which doesn't perform the bg updates + # itself, but still wants to wait for them to happen. + updates = yield self._simple_select_onecol( + "background_updates", + keyvalues=None, + retcol="1", + desc="check_background_updates", + ) + if not updates: + self._all_done = True + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks def do_next_background_update(self, desired_duration_ms): """Does some amount of work on the next queued background update @@ -269,7 +302,7 @@ class BackgroundUpdateStore(SQLBaseStore): # Sqlite doesn't support concurrent creation of indexes. # # We don't use partial indices on SQLite as it wasn't introduced - # until 3.8, and wheezy has 3.7 + # until 3.8, and wheezy and CentOS 7 have 3.7 # # We assume that sqlite doesn't give us invalid indices; however # we may still end up with the index existing but the diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 3c95e90eca..a03d1d6104 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -32,14 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpStore(background_updates.BackgroundUpdateStore): - def __init__(self, hs): + def __init__(self, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR, ) - super(ClientIpStore, self).__init__(hs) + super(ClientIpStore, self).__init__(db_conn, hs) self.register_background_index_update( "user_ips_device_index", diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 0b62b493d5..548e795daf 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -29,8 +29,8 @@ logger = logging.getLogger(__name__) class DeviceInboxStore(BackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, hs): - super(DeviceInboxStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceInboxStore, self).__init__(db_conn, hs) self.register_background_index_update( "device_inbox_stream_index", diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index bb27fd1f70..bd2effdf34 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -26,8 +26,8 @@ logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): - def __init__(self, hs): - super(DeviceStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceStore, self).__init__(db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. 
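A recurring mechanical change in these storage hunks is that every store mixin's constructor now takes the database connection as well as the HomeServer and forwards both up the inheritance chain, so stores can prime their caches from the DB at construction time. A minimal sketch of that cooperative pattern, with invented class names standing in for the real store mixins:

class SQLBaseStoreSketch(object):
    def __init__(self, db_conn, hs):
        self.db_conn = db_conn        # e.g. used to prefill caches on startup
        self.hs = hs

class DeviceStoreSketch(SQLBaseStoreSketch):
    def __init__(self, db_conn, hs):
        super(DeviceStoreSketch, self).__init__(db_conn, hs)
        self.device_id_exists_cache = {}

class GroupServerStoreSketch(SQLBaseStoreSketch):
    def __init__(self, db_conn, hs):
        super(GroupServerStoreSketch, self).__init__(db_conn, hs)

class DataStoreSketch(DeviceStoreSketch, GroupServerStoreSketch):
    def __init__(self, db_conn, hs):
        # each mixin passes db_conn and hs along the MRO, ending at the base store
        super(DataStoreSketch, self).__init__(db_conn, hs)

store = DataStoreSketch(db_conn=object(), hs=object())
assert store.device_id_exists_cache == {}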
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index e8133de2fa..55a05c59d5 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -39,8 +39,8 @@ class EventFederationStore(SQLBaseStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, hs): - super(EventFederationStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(EventFederationStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d6d8723b4a..8efe2fd4bb 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -65,8 +65,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsStore(SQLBaseStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, hs): - super(EventPushActionsStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(EventPushActionsStore, self).__init__(db_conn, hs) self.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7002b3752e..d08f7571d7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -21,7 +21,7 @@ from synapse.events.utils import prune_event from synapse.util.async import ObservableDeferred from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, preserve_context_over_deferred + preserve_fn, PreserveLoggingContext, make_deferred_yieldable ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -88,13 +88,23 @@ class _EventPeristenceQueue(object): def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + Args: room_id (str): events_and_contexts (list[(EventBase, EventContext)]): backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. end_item = queue[-1] if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) @@ -113,11 +123,11 @@ class _EventPeristenceQueue(object): def handle_queue(self, room_id, per_item_callback): """Attempts to handle the queue for a room if not already being handled. - The given callback will be invoked with for each item in the queue,1 + The given callback will be invoked with for each item in the queue, of type _EventPersistQueueItem. The per_item_callback will continuously be called with new items, unless the queue becomnes empty. The return value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferres as well. + exceptions will be passed to the deferreds as well. This function should therefore be called whenever anything is added to the queue. 
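The add_to_queue docstring above is easier to follow with a toy version of the coalescing it describes. This sketch (invented names, no Deferreds) shows only the queueing rule: a new batch is merged into the tail item of the per-room queue when its backfilled flag matches, otherwise it starts a new item.

from collections import deque

class PersistQueueSketch(object):
    """Sketch of the per-room persistence queue: batches with the same
    `backfilled` flag as the waiting tail item are merged into it."""

    def __init__(self):
        self._queues = {}   # room_id -> deque of [events, backfilled]

    def add(self, room_id, events, backfilled):
        queue = self._queues.setdefault(room_id, deque())
        if queue and queue[-1][1] == backfilled:
            queue[-1][0].extend(events)    # coalesce with the batch already waiting
        else:
            queue.append([list(events), backfilled])

    def next_batch(self, room_id):
        queue = self._queues.get(room_id)
        return queue.popleft() if queue else None

q = PersistQueueSketch()
q.add("!room:example.org", ["$event1"], backfilled=False)
q.add("!room:example.org", ["$event2"], backfilled=False)   # merged into the same item
q.add("!room:example.org", ["$event3"], backfilled=True)    # different flag: new item
assert q.next_batch("!room:example.org") == [["$event1", "$event2"], False]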
@@ -187,8 +197,8 @@ class EventsStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - def __init__(self, hs): - super(EventsStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(EventsStore, self).__init__(db_conn, hs) self._clock = hs.get_clock() self.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts @@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore): deferreds = [] for room_id, evs_ctxs in partitioned.iteritems(): - d = preserve_fn(self._event_persist_queue.add_to_queue)( + d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled, ) @@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore): for room_id in partitioned: self._maybe_start_persisting(room_id) - return preserve_context_over_deferred( + return make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) @@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore): self._maybe_start_persisting(event.room_id) - yield preserve_context_over_deferred(deferred) + yield make_deferred_yieldable(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) @@ -784,6 +794,9 @@ class EventsStore(SQLBaseStore): self._invalidate_cache_and_stream( txn, self.is_host_joined, (room_id, host) ) + self._invalidate_cache_and_stream( + txn, self.was_host_joined, (room_id, host) + ) self._invalidate_cache_and_stream( txn, self.get_users_in_room, (room_id,) @@ -1468,7 +1481,7 @@ class EventsStore(SQLBaseStore): for i in ids if i in res ]) - except: + except Exception: logger.exception("Failed to callback") with PreserveLoggingContext(): reactor.callFromThread(fire, event_list, row_dict) @@ -1523,7 +1536,7 @@ class EventsStore(SQLBaseStore): if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] - res = yield preserve_context_over_deferred(defer.gatherResults( + res = yield make_deferred_yieldable(defer.gatherResults( [ preserve_fn(self._get_event_from_row)( row["internal_metadata"], row["json"], row["redacts"], diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py new file mode 100644 index 0000000000..8fde1aab8e --- /dev/null +++ b/synapse/storage/group_server.py @@ -0,0 +1,1233 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError + +from ._base import SQLBaseStore + +import ujson as json + + +# The category ID for the "default" category. 
We don't store as null in the +# database to avoid the fun of null != null +_DEFAULT_CATEGORY_ID = "" +_DEFAULT_ROLE_ID = "" + + +class GroupServerStore(SQLBaseStore): + def get_group(self, group_id): + return self._simple_select_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + retcols=( + "name", "short_description", "long_description", "avatar_url", "is_public" + ), + allow_none=True, + desc="is_user_in_group", + ) + + def get_users_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_users", + keyvalues=keyvalues, + retcols=("user_id", "is_public", "is_admin",), + desc="get_users_in_group", + ) + + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self._simple_select_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + }, + retcol="user_id", + desc="get_invited_users_in_group", + ) + + def get_rooms_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_rooms", + keyvalues=keyvalues, + retcols=("room_id", "is_public",), + desc="get_rooms_in_group", + ) + + def get_rooms_for_summary_by_category(self, group_id, include_private=False): + """Get the rooms and categories that should be included in a summary request + + Returns ([rooms], [categories]) + """ + def _get_rooms_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT room_id, is_public, category_id, room_order + FROM group_summary_rooms + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + rooms = [ + { + "room_id": row[0], + "is_public": row[1], + "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT category_id, is_public, profile, cat_order + FROM group_summary_room_categories + INNER JOIN group_room_categories USING (group_id, category_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + categories = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return rooms, categories + return self.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) + + def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): + return self.runInteraction( + "add_room_to_summary", self._add_room_to_summary_txn, + group_id, room_id, category_id, order, is_public, + ) + + def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order, + is_public): + """Add (or update) room's entry in summary. + + Args: + group_id (str) + room_id (str) + category_id (str): If not None then adds the category to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the room at that position, e.g. + an order of 1 will put the room first. Otherwise, the room gets + added to the end. 
+ """ + room_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + retcol="room_id", + allow_none=True, + ) + if not room_in_group: + raise SynapseError(400, "room not in group") + + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + else: + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + raise SynapseError(400, "Category doesn't exist") + + # TODO: Check category is part of summary already + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_room_categories + (group_id, category_id, cat_order) + SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 + FROM group_summary_room_categories + WHERE group_id = ? AND category_id = ? + """, (group_id, category_id, group_id, category_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + "category_id": category_id, + }, + retcols=("room_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other room orders that come after the given order + sql = """ + UPDATE group_summary_rooms SET room_order = room_order + 1 + WHERE group_id = ? AND category_id = ? AND room_order >= ? + """ + txn.execute(sql, (group_id, category_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms + WHERE group_id = ? AND category_id = ? 
+ """ + txn.execute(sql, (group_id, category_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["room_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_rooms", + values={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + "room_order": order, + "is_public": is_public, + }, + ) + + def remove_room_from_summary(self, group_id, room_id, category_id): + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + + return self._simple_delete( + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + desc="remove_room_from_summary", + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self._simple_select_list( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + }, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + defer.returnValue({ + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self._simple_select_one( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = json.loads(category["profile"]) + + defer.returnValue(category) + + def upsert_group_category(self, group_id, category_id, profile, is_public): + """Add/update room category for group + """ + insertion_values = {} + update_values = {"category_id": category_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_category", + ) + + def remove_group_category(self, group_id, category_id): + return self._simple_delete( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + desc="remove_group_category", + ) + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self._simple_select_list( + table="group_roles", + keyvalues={ + "group_id": group_id, + }, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + defer.returnValue({ + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self._simple_select_one( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = json.loads(role["profile"]) + + defer.returnValue(role) + + def upsert_group_role(self, group_id, role_id, profile, 
is_public): + """Add/remove user role + """ + insertion_values = {} + update_values = {"role_id": role_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_role", + ) + + def remove_group_role(self, group_id, role_id): + return self._simple_delete( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + desc="remove_group_role", + ) + + def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): + return self.runInteraction( + "add_user_to_summary", self._add_user_to_summary_txn, + group_id, user_id, role_id, order, is_public, + ) + + def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order, + is_public): + """Add (or update) user's entry in summary. + + Args: + group_id (str) + user_id (str) + role_id (str): If not None then adds the role to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the user at that position, e.g. + an order of 1 will put the user first. Otherwise, the user gets + added to the end. + """ + user_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + if not user_in_group: + raise SynapseError(400, "user not in group") + + if role_id is None: + role_id = _DEFAULT_ROLE_ID + else: + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + raise SynapseError(400, "Role doesn't exist") + + # TODO: Check role is part of the summary already + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_roles + (group_id, role_id, role_order) + SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 + FROM group_summary_roles + WHERE group_id = ? AND role_id = ? + """, (group_id, role_id, group_id, role_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + "role_id": role_id, + }, + retcols=("user_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other users orders that come after the given order + sql = """ + UPDATE group_summary_users SET user_order = user_order + 1 + WHERE group_id = ? AND role_id = ? AND user_order >= ? + """ + txn.execute(sql, (group_id, role_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users + WHERE group_id = ? AND role_id = ? 
+ """ + txn.execute(sql, (group_id, role_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["user_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_users", + values={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + "user_order": order, + "is_public": is_public, + }, + ) + + def remove_user_from_summary(self, group_id, user_id, role_id): + if role_id is None: + role_id = _DEFAULT_ROLE_ID + + return self._simple_delete( + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + desc="remove_user_from_summary", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + def _get_users_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + return self.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def add_group_invite(self, group_id, user_id): + """Record that the group server has invited a user + """ + return self._simple_insert( + table="group_invites", + values={ + "group_id": group_id, + "user_id": user_id, + }, + desc="add_group_invite", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server invited a user? + """ + return self._simple_select_one_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. 
+ + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + def _get_users_membership_in_group_txn(txn): + row = self._simple_select_one_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self._simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + + if row: + return { + "membership": "invite", + } + + return {} + + return self.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn, + ) + + def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True, + local_attestation=None, remote_attestation=None): + """Add a user to the group server. + + Args: + group_id (str) + user_id (str) + is_admin (bool) + is_public (bool) + local_attestation (dict): The attestation the GS created to give + to the remote server. Optional if the user and group are on the + same server + remote_attestation (dict): The attestation given to GS by remote + server. Optional if the user and group are on the same server + """ + def _add_user_to_group_txn(txn): + self._simple_insert_txn( + txn, + table="group_users", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "is_public": is_public, + }, + ) + + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + + return self.runInteraction( + "add_user_to_group", _add_user_to_group_txn + ) + + def remove_user_from_group(self, group_id, user_id): + def _remove_user_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn) + + def add_room_to_group(self, group_id, room_id, is_public): + return self._simple_insert( + table="group_rooms", + values={ + "group_id": group_id, + "room_id": room_id, + "is_public": is_public, + }, + desc="add_room_to_group", + ) + + def update_room_in_group_visibility(self, group_id, room_id, is_public): + return self._simple_update( + 
table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + updatevalues={ + "is_public": is_public, + }, + desc="update_room_in_group_visibility", + ) + + def remove_room_from_group(self, group_id, room_id): + def _remove_room_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + + self._simple_delete_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + return self.runInteraction( + "remove_room_from_group", _remove_room_from_group_txn, + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + "is_publicised": True, + }, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def update_group_publicity(self, group_id, user_id, publicise): + """Update whether the user is publicising their membership of the group + """ + return self._simple_update_one( + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "is_publicised": publicise, + }, + desc="update_group_publicity" + ) + + @defer.inlineCallbacks + def register_user_group_membership(self, group_id, user_id, membership, + is_admin=False, content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): + """Registers that a local user is a member of a (local or remote) group. + + Args: + group_id (str) + user_id (str) + membership (str) + is_admin (bool) + content (dict): Content of the membership, e.g. includes the inviter + if the user has been invited. + local_attestation (dict): If remote group then store the fact that we + have given out an attestation, else None. + remote_attestation (dict): If remote group then store the remote + attestation from the group, else None. + """ + def _register_user_group_membership_txn(txn, next_id): + # TODO: Upsert? + self._simple_delete_txn( + txn, + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_insert_txn( + txn, + table="local_group_membership", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "membership": membership, + "is_publicised": is_publicised, + "content": json.dumps(content), + }, + ) + + self._simple_insert_txn( + txn, + table="local_group_updates", + values={ + "stream_id": next_id, + "group_id": group_id, + "user_id": user_id, + "type": "membership", + "content": json.dumps({"membership": membership, "content": content}), + } + ) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + + # TODO: Insert profile to ensure it comes down stream if its a join. 
+
+            if membership == "join":
+                if local_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_renewals",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": local_attestation["valid_until_ms"],
+                        }
+                    )
+                if remote_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_remote",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": remote_attestation["valid_until_ms"],
+                            "attestation_json": json.dumps(remote_attestation),
+                        }
+                    )
+            else:
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_renewals",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_remote",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+
+            return next_id
+
+        with self._group_updates_id_gen.get_next() as next_id:
+            res = yield self.runInteraction(
+                "register_user_group_membership",
+                _register_user_group_membership_txn, next_id,
+            )
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def create_group(self, group_id, user_id, name, avatar_url, short_description,
+                     long_description,):
+        yield self._simple_insert(
+            table="groups",
+            values={
+                "group_id": group_id,
+                "name": name,
+                "avatar_url": avatar_url,
+                "short_description": short_description,
+                "long_description": long_description,
+                "is_public": True,
+            },
+            desc="create_group",
+        )
+
+    @defer.inlineCallbacks
+    def update_group_profile(self, group_id, profile,):
+        yield self._simple_update_one(
+            table="groups",
+            keyvalues={
+                "group_id": group_id,
+            },
+            updatevalues=profile,
+            desc="update_group_profile",
+        )
+
+    def get_attestations_need_renewals(self, valid_until_ms):
+        """Get all attestations that need to be renewed by the given time
+        """
+        def _get_attestations_need_renewals_txn(txn):
+            sql = """
+                SELECT group_id, user_id FROM group_attestations_renewals
+                WHERE valid_until_ms <= ?
+            """
+            txn.execute(sql, (valid_until_ms,))
+            return self.cursor_to_dict(txn)
+        return self.runInteraction(
+            "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+        )
+
+    def update_attestation_renewal(self, group_id, user_id, attestation):
+        """Update an attestation that we have renewed
+        """
+        return self._simple_update_one(
+            table="group_attestations_renewals",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "valid_until_ms": attestation["valid_until_ms"],
+            },
+            desc="update_attestation_renewal",
+        )
+
+    def update_remote_attestion(self, group_id, user_id, attestation):
+        """Update an attestation that a remote has renewed
+        """
+        return self._simple_update_one(
+            table="group_attestations_remote",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "valid_until_ms": attestation["valid_until_ms"],
+                "attestation_json": json.dumps(attestation)
+            },
+            desc="update_remote_attestion",
+        )
+
+    def remove_attestation_renewal(self, group_id, user_id):
+        """Remove an attestation that we thought we should renew, but actually
+        shouldn't. Ideally this would never get called as we would never
+        incorrectly try to do attestations for local users on local groups.
+ + Args: + group_id (str) + user_id (str) + """ + return self._simple_delete( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + desc="remove_attestation_renewal", + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. + """ + row = yield self._simple_select_one( + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + defer.returnValue(json.loads(row["attestation_json"])) + + defer.returnValue(None) + + def get_joined_groups(self, user_id): + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + }, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token,)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": json.loads(row[3]), + } + for row in txn + ] + return self.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn, + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token, + ) + if not has_changed: + return [] + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (user_id, from_token, to_token,)) + return [{ + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": json.loads(content_json), + } for group_id, membership, gtype, content_json in txn] + return self.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn, + ) + + def get_all_groups_changes(self, from_token, to_token, limit): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_any_entity_changed( + from_token, + ) + if not has_changed: + return [] + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? 
+ """ + txn.execute(sql, (from_token, to_token, limit,)) + return [( + stream_id, + group_id, + user_id, + gtype, + json.loads(content_json), + ) for stream_id, group_id, user_id, gtype, content_json in txn] + return self.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn, + ) + + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 82bb61b811..a66ff7c1e0 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -12,13 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.background_updates import BackgroundUpdateStore -from ._base import SQLBaseStore - -class MediaRepositoryStore(SQLBaseStore): +class MediaRepositoryStore(BackgroundUpdateStore): """Persistence for attachments and avatars""" + def __init__(self, db_conn, hs): + super(MediaRepositoryStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + update_name='local_media_repository_url_idx', + index_name='local_media_repository_url_idx', + table='local_media_repository', + columns=['created_ts'], + where_clause='url_cache IS NOT NULL', + ) + def get_default_thumbnails(self, top_level_type, sub_type): return [] @@ -62,7 +72,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_url_cache_txn(txn): # get the most recently cached result (relative to the given ts) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts <= ?" " ORDER BY download_ts DESC LIMIT 1" @@ -74,7 +84,7 @@ class MediaRepositoryStore(SQLBaseStore): # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts > ?" " ORDER BY download_ts ASC LIMIT 1" @@ -86,14 +96,14 @@ class MediaRepositoryStore(SQLBaseStore): return None return dict(zip(( - 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts' ), row)) return self.runInteraction( "get_url_cache", get_url_cache_txn ) - def store_url_cache(self, url, response_code, etag, expires, og, media_id, + def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id, download_ts): return self._simple_insert( "local_media_repository_url_cache", @@ -101,7 +111,7 @@ class MediaRepositoryStore(SQLBaseStore): "url": url, "response_code": response_code, "etag": etag, - "expires": expires, + "expires_ts": expires_ts, "og": og, "media_id": media_id, "download_ts": download_ts, @@ -238,3 +248,70 @@ class MediaRepositoryStore(SQLBaseStore): }, ) return self.runInteraction("delete_remote_media", delete_remote_media_txn) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" 
+ " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + + def delete_url_cache(self, media_ids): + if len(media_ids) == 0: + return + + sql = ( + "DELETE FROM local_media_repository_url_cache" + " WHERE media_id = ?" + ) + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn, + ) + + def delete_url_cache_media(self, media_ids): + if len(media_ids) == 0: + return + + def _delete_url_cache_media_txn(txn): + sql = ( + "DELETE FROM local_media_repository" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = ( + "DELETE FROM local_media_repository_thumbnails" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn, + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 72b670b83b..d1691bbac2 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 43 +SCHEMA_VERSION = 46 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config): If `config` is None then prepare_database will assert that no upgrade is necessary, *or* will create a fresh database if the database is empty. 
+ + Args: + db_conn: + database_engine: + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already """ try: cur = db_conn.cursor() @@ -64,9 +71,13 @@ def prepare_database(db_conn, database_engine, config): else: _setup_new_database(cur, database_engine) + # check if any of our configured dynamic modules want a database + if config is not None: + _apply_module_schemas(cur, database_engine, config) + cur.close() db_conn.commit() - except: + except Exception: db_conn.rollback() raise @@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) +def _apply_module_schemas(txn, database_engine, config): + """Apply the module schemas for the dynamic modules, if any + + Args: + cur: database cursor + database_engine: synapse database engine class + config (synapse.config.homeserver.HomeServerConfig): + application config + """ + for (mod, _config) in config.password_providers: + if not hasattr(mod, 'get_db_schema_files'): + continue + modname = ".".join((mod.__module__, mod.__name__)) + _apply_module_schema_files( + txn, database_engine, modname, mod.get_db_schema_files(), + ) + + +def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): + """Apply the module schemas for a single module + + Args: + cur: database cursor + database_engine: synapse database engine class + modname (str): fully qualified name of the module + names_and_streams (Iterable[(str, file)]): the names and streams of + schemas to be applied + """ + cur.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_module_schemas WHERE module_name = ?" + ), + (modname,) + ) + applied_deltas = set(d for d, in cur) + for (name, stream) in names_and_streams: + if name in applied_deltas: + continue + + root_name, ext = os.path.splitext(name) + if ext != '.sql': + raise PrepareDatabaseException( + "only .sql files are currently supported for module schemas", + ) + + logger.info("applying schema %s for %s", name, modname) + for statement in get_statements(stream): + cur.execute(statement) + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_module_schemas (module_name, file)" + " VALUES (?,?)", + ), + (modname, name) + ) + + def get_statements(f): statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 26a40905ae..ec02e73bc2 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
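
For the module-schema hook added to prepare_database.py above, a password provider opts in simply by exposing a get_db_schema_files() method. A minimal sketch of such a provider follows; the class name and the module-local "schemas" directory are illustrative, and the only contract taken from the change is that the method yields (filename, stream) pairs and that only ".sql" files are accepted and recorded in applied_module_schemas:

    import os

    class ExampleAuthProvider(object):
        """Sketch of a password provider module that ships its own schema."""

        @staticmethod
        def get_db_schema_files():
            # Yield (filename, readable stream) pairs; prepare_database applies
            # each file once and remembers it in applied_module_schemas.
            schema_dir = os.path.join(os.path.dirname(__file__), "schemas")
            for name in sorted(os.listdir(schema_dir)):
                if name.endswith(".sql"):
                    yield name, open(os.path.join(schema_dir, name))
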
+from twisted.internet import defer + +from synapse.storage.roommember import ProfileInfo +from synapse.api.errors import StoreError + from ._base import SQLBaseStore @@ -24,6 +29,30 @@ class ProfileStore(SQLBaseStore): desc="create_profile", ) + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self._simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + defer.returnValue(ProfileInfo(None, None)) + return + else: + raise + + defer.returnValue( + ProfileInfo( + avatar_url=profile['avatar_url'], + display_name=profile['displayname'], + ) + ) + def get_profile_displayname(self, user_localpart): return self._simple_select_one_onecol( table="profiles", @@ -55,3 +84,99 @@ class ProfileStore(SQLBaseStore): updatevalues={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) + + def get_from_remote_profile_cache(self, user_id): + return self._simple_select_one( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + retcols=("displayname", "avatar_url",), + allow_none=True, + desc="get_from_remote_profile_cache", + ) + + def add_remote_profile_cache(self, user_id, displayname, avatar_url): + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + return self._simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", + ) + + def update_remote_profile_cache(self, user_id, displayname, avatar_url): + return self._simple_update( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="update_remote_profile_cache", + ) + + @defer.inlineCallbacks + def maybe_delete_remote_profile_cache(self, user_id): + """Check if we still care about the remote user's profile, and if we + don't then remove their profile from the cache + """ + subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + if not subscribed: + yield self._simple_delete( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + desc="delete_remote_profile_cache", + ) + + def get_remote_profile_cache_entries_that_expire(self, last_checked): + """Get all users who haven't been checked since `last_checked` + """ + def _get_remote_profile_cache_entries_that_expire_txn(txn): + sql = """ + SELECT user_id, displayname, avatar_url + FROM remote_profile_cache + WHERE last_check < ? + """ + + txn.execute(sql, (last_checked,)) + + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_remote_profile_cache_entries_that_expire", + _get_remote_profile_cache_entries_that_expire_txn, + ) + + @defer.inlineCallbacks + def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. 
+ """ + res = yield self._simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) + + res = yield self._simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 34d2f82b7f..3d8b4d5d5b 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -204,34 +204,35 @@ class PusherStore(SQLBaseStore): pushkey, pushkey_ts, lang, data, last_stream_ordering, profile_tag=""): with self._pushers_id_gen.get_next() as stream_id: - def f(txn): - newly_inserted = self._simple_upsert_txn( - txn, - "pushers", - { - "app_id": app_id, - "pushkey": pushkey, - "user_name": user_id, - }, - { - "access_token": access_token, - "kind": kind, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "ts": pushkey_ts, - "lang": lang, - "data": encode_canonical_json(data), - "last_stream_ordering": last_stream_ordering, - "profile_tag": profile_tag, - "id": stream_id, - }, - ) - if newly_inserted: - # get_if_user_has_pusher only cares if the user has - # at least *one* pusher. - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + # no need to lock because `pushers` has a unique key on + # (app_id, pushkey, user_name) so _simple_upsert will retry + newly_inserted = yield self._simple_upsert( + table="pushers", + keyvalues={ + "app_id": app_id, + "pushkey": pushkey, + "user_name": user_id, + }, + values={ + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": encode_canonical_json(data), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + desc="add_pusher", + lock=False, + ) - yield self.runInteraction("add_pusher", f) + if newly_inserted: + # get_if_user_has_pusher only cares if the user has + # at least *one* pusher. + self.get_if_user_has_pusher.invalidate(user_id,) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): @@ -243,11 +244,19 @@ class PusherStore(SQLBaseStore): "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} ) - self._simple_upsert_txn( + + # it's possible for us to end up with duplicate rows for + # (app_id, pushkey, user_id) at different stream_ids, but that + # doesn't really matter. 
+ self._simple_insert_txn( txn, - "deleted_pushers", - {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, - {"stream_id": stream_id}, + table="deleted_pushers", + values={ + "stream_id": stream_id, + "app_id": app_id, + "pushkey": pushkey, + "user_id": user_id, + }, ) with self._pushers_id_gen.get_next() as stream_id: @@ -310,9 +319,12 @@ class PusherStore(SQLBaseStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so _simple_upsert will retry yield self._simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, - desc="set_throttle_params" + desc="set_throttle_params", + lock=False, ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f42b8014c7..12b3cc7f5f 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) class ReceiptsStore(SQLBaseStore): - def __init__(self, hs): - super(ReceiptsStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(ReceiptsStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 20acd58fcf..3aa810981f 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -24,8 +24,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks class RegistrationStore(background_updates.BackgroundUpdateStore): - def __init__(self, hs): - super(RegistrationStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(RegistrationStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() @@ -36,12 +36,15 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id"], ) - self.register_background_index_update( - "refresh_tokens_device_index", - index_name="refresh_tokens_device_id", - table="refresh_tokens", - columns=["user_id", "device_id"], - ) + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. 
+ @defer.inlineCallbacks + def noop_update(progress, batch_size): + yield self._end_background_update("refresh_tokens_device_index") + defer.returnValue(1) + self.register_background_update_handler( + "refresh_tokens_device_index", noop_update) @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): @@ -177,9 +180,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) if create_profile_with_localpart: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames txn.execute( - "INSERT INTO profiles(user_id) VALUES (?)", - (create_profile_with_localpart,) + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (create_profile_with_localpart, create_profile_with_localpart) ) self._invalidate_cache_and_stream( @@ -236,12 +241,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): "user_set_password_hash", user_set_password_hash_txn ) - @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None, - delete_refresh_tokens=False): + device_id=None): """ - Invalidate access/refresh tokens belonging to a user + Invalidate access tokens belonging to a user Args: user_id (str): ID of user the tokens belong to @@ -250,10 +253,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): device_id (str|None): ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - delete_refresh_tokens (bool): True to delete refresh tokens as - well as access tokens. Returns: - defer.Deferred: + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens """ def f(txn): keyvalues = { @@ -262,13 +264,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - if delete_refresh_tokens: - self._simple_delete_txn( - txn, - table="refresh_tokens", - keyvalues=keyvalues, - ) - items = keyvalues.items() where_clause = " AND ".join(k + " = ?" 
for k, _ in items) values = [v for _, v in items] @@ -277,14 +272,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values.append(except_token_id) txn.execute( - "SELECT token FROM access_tokens WHERE %s" % where_clause, + "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, values ) - rows = self.cursor_to_dict(txn) + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - for row in rows: + for token, _, _ in tokens_and_devices: self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (row["token"],) + txn, self.get_user_by_access_token, (token,) ) txn.execute( @@ -292,7 +287,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values ) - yield self.runInteraction( + return tokens_and_devices + + return self.runInteraction( "user_delete_access_tokens", f, ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 457ca288d0..3e77fd3901 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -49,8 +49,8 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" class RoomMemberStore(SQLBaseStore): - def __init__(self, hs): - super(RoomMemberStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(RoomMemberStore, self).__init__(db_conn, hs) self.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -533,6 +533,46 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(True) + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if '%' in host or '_' in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + + if not rows: + defer.returnValue(False) + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + defer.returnValue(True) + def get_joined_hosts(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: @@ -596,7 +636,7 @@ class RoomMemberStore(SQLBaseStore): room_id = row["room_id"] try: content = json.loads(row["content"]) - except: + except Exception: continue display_name = content.get("displayname", None) diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 5b7d8d1ab5..c53e53c94f 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -22,7 +22,7 @@ def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") - except: + except Exception: # Maybe we already added the column? Hope so... 
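
Returning to the was_host_joined helper added to roommember.py above, a hedged sketch of the kind of federation-side check it is meant to support; the check_host_ever_joined wrapper here is hypothetical, and only was_host_joined itself comes from this changeset:

    from twisted.internet import defer

    from synapse.api.errors import AuthError


    @defer.inlineCallbacks
    def check_host_ever_joined(store, room_id, origin):
        # Refuse to serve room history to a server that is not, and never was,
        # joined to the room. was_host_joined already guards against LIKE
        # wildcards in the host and re-checks the returned user's domain.
        ever_joined = yield store.was_host_joined(room_id, origin)
        if not ever_joined:
            raise AuthError(403, "Host not in room")
        defer.returnValue(True)
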
pass diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql index 4501d90cbb..ee7062abe4 100644 --- a/synapse/storage/schema/delta/43/user_share.sql +++ b/synapse/storage/schema/delta/43/user_share.sql @@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); --- Make sure that we popualte the table initially +-- Make sure that we populate the table initially UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql new file mode 100644 index 0000000000..b12f9b2ebf --- /dev/null +++ b/synapse/storage/schema/delta/44/expire_url_cache.sql @@ -0,0 +1,41 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was +-- removed and replaced with 46/local_media_repository_url_idx.sql. +-- +-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; + +-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support +-- indices on expressions until 3.9. +CREATE TABLE local_media_repository_url_cache_new( + url TEXT, + response_code INTEGER, + etag TEXT, + expires_ts BIGINT, + og TEXT, + media_id TEXT, + download_ts BIGINT +); + +INSERT INTO local_media_repository_url_cache_new + SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; + +DROP TABLE local_media_repository_url_cache; +ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql new file mode 100644 index 0000000000..b2333848a0 --- /dev/null +++ b/synapse/storage/schema/delta/45/group_server.sql @@ -0,0 +1,167 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
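
As an aside on the expire_url_cache delta above: the URL cache previously stored a relative `expires` value, and the migration converts it to an absolute `expires_ts` by adding `download_ts`. A caller of store_url_cache would now be expected to compute the absolute expiry up front, roughly along these lines (a sketch; the helper name is illustrative):

    def absolute_expiry_ms(download_ts_ms, max_age_ms):
        # Mirror the `expires + download_ts` expression in the migration:
        # the cache row now carries the absolute time at which it goes stale,
        # which is what the new expires_ts index is queried against.
        return download_ts_ms + max_age_ms
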
+ */
+
+CREATE TABLE groups (
+    group_id TEXT NOT NULL,
+    name TEXT, -- the display name of the group
+    avatar_url TEXT,
+    short_description TEXT,
+    long_description TEXT
+);
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
+
+
+-- list of users the group server thinks are joined
+CREATE TABLE group_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    is_admin BOOLEAN NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the user's membership can be seen by everyone
+);
+
+
+CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
+CREATE INDEX groups_users_u_idx ON group_users(user_id);
+
+-- list of users the group server thinks are invited
+CREATE TABLE group_invites (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL
+);
+
+CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
+CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
+
+
+CREATE TABLE group_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
+);
+
+CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
+CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
+
+
+-- Rooms to include in the summary
+CREATE TABLE group_summary_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    room_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the room should be shown to everyone
+    UNIQUE (group_id, category_id, room_id, room_order),
+    CHECK (room_order > 0)
+);
+
+CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
+
+
+-- Categories to include in the summary
+CREATE TABLE group_summary_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    cat_order BIGINT NOT NULL,
+    UNIQUE (group_id, category_id, cat_order),
+    CHECK (cat_order > 0)
+);
+
+-- The categories in the group
+CREATE TABLE group_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the category should be shown to everyone
+    UNIQUE (group_id, category_id)
+);
+
+-- The users to include in the group summary
+CREATE TABLE group_summary_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    user_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the user should be shown to everyone
+);
+
+CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
+
+-- The roles to include in the group summary
+CREATE TABLE group_summary_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    role_order BIGINT NOT NULL,
+    UNIQUE (group_id, role_id, role_order),
+    CHECK (role_order > 0)
+);
+
+
+-- The roles in a group
+CREATE TABLE group_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the role should be shown to everyone
+    UNIQUE (group_id, role_id)
+);
+
+
+-- List of attestations we've given out and need to renew
+CREATE TABLE group_attestations_renewals (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    valid_until_ms BIGINT NOT NULL
+);
+
+CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
+CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
+CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
+
+
+-- List of attestations we've received from remotes and are interested in.
+CREATE TABLE group_attestations_remote ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL, + attestation_json TEXT NOT NULL +); + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); + + +-- The group membership for the HS's users +CREATE TABLE local_group_membership ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + membership TEXT NOT NULL, + is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership + content TEXT NOT NULL +); + +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); + + +CREATE TABLE local_group_updates ( + stream_id BIGINT NOT NULL, + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL +); diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/45/profile_cache.sql index 34db0cf12b..e5ddc84df0 100644 --- a/synapse/storage/schema/delta/23/refresh_tokens.sql +++ b/synapse/storage/schema/delta/45/profile_cache.sql @@ -1,4 +1,4 @@ -/* Copyright 2015, 2016 OpenMarket Ltd +/* Copyright 2017 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,16 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS refresh_tokens( - id INTEGER PRIMARY KEY, - token TEXT NOT NULL, + +-- A subset of remote users whose profiles we have cached. +-- Whether a user is in this table or not is defined by the storage function +-- `is_subscribed_remote_profile_for_user` +CREATE TABLE remote_profile_cache ( user_id TEXT NOT NULL, - UNIQUE (token) + displayname TEXT, + avatar_url TEXT, + last_check BIGINT NOT NULL ); + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql index 290bd6da86..68c48a89a9 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device.sql +++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,4 +13,5 @@ * limitations under the License. */ -ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql new file mode 100644 index 0000000000..bb307889c1 --- /dev/null +++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop the unique constraint on deleted_pushers so that we can just insert +-- into it rather than upserting. + +CREATE TABLE deleted_pushers2 ( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL +); + +INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) + SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; + +DROP TABLE deleted_pushers; +ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; + +-- create the index after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); + diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..097679bc9a --- /dev/null +++ b/synapse/storage/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql index bb225dafbf..bbfc7f5d1a 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql +++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,5 +13,12 @@ * limitations under the License. */ +-- register a background update which will recreate the +-- local_media_repository_url_idx index. +-- +-- We do this as a bg update not because it is a particularly onerous +-- operation, but because we'd like it to be a partial index if possible, and +-- the background_index_update code will understand whether we are on +-- postgres or sqlite and behave accordingly. 
INSERT INTO background_updates (update_name, progress_json) VALUES - ('refresh_tokens_device_index', '{}'); + ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql new file mode 100644 index 0000000000..d9505f8da1 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_typos.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this is just embarassing :| +ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; + +-- this is only 300K rows on matrix.org and takes ~3s to generate the index, +-- so is hopefully not going to block anyone else for that long... 
+CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); +CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); +DROP INDEX users_in_pubic_room_room_idx; +DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index a7ade69986..42e5cb6df5 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas( file TEXT NOT NULL, UNIQUE(version, file) ); + +-- a list of schema files we have loaded on behalf of dynamic modules +CREATE TABLE IF NOT EXISTS applied_module_schemas( + module_name TEXT NOT NULL, + file TEXT NOT NULL, + UNIQUE(module_name, file) +); diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 8f2b3c4435..479b04c636 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -33,8 +33,8 @@ class SearchStore(BackgroundUpdateStore): EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" - def __init__(self, hs): - super(SearchStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(SearchStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) @@ -81,7 +81,7 @@ class SearchStore(BackgroundUpdateStore): etype = row["type"] try: content = json.loads(row["content"]) - except: + except Exception: continue if etype == "m.room.message": @@ -407,7 +407,7 @@ class SearchStore(BackgroundUpdateStore): origin_server_ts, stream = pagination_token.split(",") origin_server_ts = int(origin_server_ts) stream = int(stream) - except: + except Exception: raise SynapseError(400, "Invalid pagination token") clauses.append( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5673e4aa96..360e3e4355 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches import intern_string -from synapse.util.stringutils import to_ascii -from synapse.storage.engines import PostgresEngine +from collections import namedtuple +import logging from twisted.internet import defer -from collections import namedtuple -import logging +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -40,45 +42,22 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateStore(SQLBaseStore): - """ Keeps track of the state at a given event. +class StateGroupReadStore(SQLBaseStore): + """The read-only parts of StateGroupStore - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. 
The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. + None of these functions write to the state tables, so are suitable for + including in the SlavedStores. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, hs): - super(StateStore, self).__init__(hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", + def __init__(self, db_conn, hs): + super(StateGroupReadStore, self).__init__(db_conn, hs) + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) @cached(max_entries=100000, iterable=True) @@ -190,178 +169,6 @@ class StateStore(SQLBaseStore): for group, event_id_map in group_to_ids.iteritems() }) - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - if context.current_state_ids is None: - # AFAIK, this can never happen - logger.error( - "Non-outlier event %s had current_state_ids==None", - event.event_id) - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue - - self._simple_insert_txn( - txn, - table="state_groups", - values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, - }, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. 
- if context.prev_group: - is_in_db = self._simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": context.prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (context.prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group - ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, - }, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.delta_ids.iteritems() - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.current_state_ids.iteritems() - ], - ) - - # Prefill the state group cache with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=context.state_group, - value=dict(context.current_state_ids), - full=True, - ) - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } - for event_id, state_group_id in state_groups.iteritems() - ], - ) - - for event_id, state_group_id in state_groups.iteritems(): - txn.call_after( - self._get_state_group_for_event.prefill, - (event_id,), state_group_id - ) - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = (""" - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """) - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) @@ -742,6 +549,220 @@ class StateStore(SQLBaseStore): defer.returnValue(results) + +class StateStore(StateGroupReadStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. 
The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups WHERE id = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] + + def _store_mult_state_groups_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + if context.current_state_ids is None: + # AFAIK, this can never happen + logger.error( + "Non-outlier event %s had current_state_ids==None", + event.event_id) + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + if self._have_persisted_state_group_txn(txn, context.state_group): + continue + + self._simple_insert_txn( + txn, + table="state_groups", + values={ + "id": context.state_group, + "room_id": event.room_id, + "event_id": event.event_id, + }, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. 
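The block that follows persists either a delta (a row in state_group_edges plus only the changed keys) or a full snapshot, depending on how long the existing chain already is. A toy in-memory version of just that decision, as a companion to the read sketch further up (MAX_HOPS is a stand-in for MAX_STATE_DELTA_HOPS; this is not Synapse code):

    MAX_HOPS = 100  # stands in for MAX_STATE_DELTA_HOPS

    def count_hops(group, edges):
        hops = 0
        while edges.get(group) is not None:
            group = edges[group]
            hops += 1
        return hops

    def persist_group(new_group, prev_group, delta, full_state, edges, deltas):
        if prev_group is not None and count_hops(prev_group, edges) < MAX_HOPS:
            edges[new_group] = prev_group      # like a state_group_edges row
            deltas[new_group] = dict(delta)    # only the keys that changed
        else:
            edges[new_group] = None            # start a fresh full snapshot
            deltas[new_group] = dict(full_state)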
+ if context.prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": context.prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (context.prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn( + txn, context.prev_group + ) + if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": context.state_group, + "prev_state_group": context.prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.delta_ids.iteritems() + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.current_state_ids.iteritems() + ], + ) + + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=context.state_group, + value=dict(context.current_state_ids), + full=True, + ) + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + { + "state_group": state_group_id, + "event_id": event_id, + } + for event_id, state_group_id in state_groups.iteritems() + ], + ) + + for event_id, state_group_id in state_groups.iteritems(): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. 
wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + def get_next_state_group(self): return self._state_groups_id_gen.get_next() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index dddd5fc0e7..52bdce5be2 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -39,7 +39,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.storage.engines import PostgresEngine, Sqlite3Engine import logging @@ -234,7 +234,7 @@ class StreamStore(SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): - res = yield preserve_context_over_deferred(defer.gatherResults([ + res = yield make_deferred_yieldable(defer.gatherResults([ preserve_fn(self.get_room_events_stream_for_room)( room_id, from_key, to_key, limit, order=order, ) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 809fdd311f..8f61f7ffae 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, hs): - super(TransactionStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(TransactionStore, self).__init__(db_conn, hs) self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 2a4db3f03c..c9bff408ef 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -63,7 +63,7 @@ class UserDirectoryStore(SQLBaseStore): user_ids (list(str)): Users to add """ yield self._simple_insert_many( - table="users_in_pubic_room", + table="users_in_public_rooms", values=[ { "user_id": user_id, @@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore): ) if isinstance(self.database_engine, PostgresEngine): - # We weight the loclpart most highly, then display name and finally + # We weight the localpart most highly, then display name and finally # server name if new_entry: sql = """ @@ -219,7 +219,7 @@ class UserDirectoryStore(SQLBaseStore): @defer.inlineCallbacks def update_user_in_public_user_list(self, user_id, room_id): yield self._simple_update_one( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, updatevalues={"room_id": room_id}, desc="update_user_in_public_user_list", @@ -240,7 +240,7 @@ class UserDirectoryStore(SQLBaseStore): ) self._simple_delete_txn( txn, - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, ) txn.call_after( @@ -256,7 +256,7 @@ class UserDirectoryStore(SQLBaseStore): @defer.inlineCallbacks def remove_from_user_in_public_room(self, user_id): yield self._simple_delete( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, desc="remove_from_user_in_public_room", ) @@ -267,7 +267,7 @@ class UserDirectoryStore(SQLBaseStore): in the given room_id """ return 
self._simple_select_onecol( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_public_due_to_room", @@ -286,7 +286,7 @@ class UserDirectoryStore(SQLBaseStore): ) user_ids_pub = yield self._simple_select_onecol( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", @@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore): rows = yield self._execute("get_all_rooms", None, sql) defer.returnValue([room_id for room_id, in rows]) + @defer.inlineCallbacks + def get_all_local_users(self): + """Get all local users + """ + sql = """ + SELECT name FROM users + """ + rows = yield self._execute("get_all_local_users", None, sql) + defer.returnValue([name for name, in rows]) + def add_users_who_share_room(self, room_id, share_private, user_id_tuples): """Insert entries into the users_who_share_rooms table. The first user should be a local user. @@ -514,7 +524,7 @@ class UserDirectoryStore(SQLBaseStore): def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM user_directory") txn.execute("DELETE FROM user_directory_search") - txn.execute("DELETE FROM users_in_pubic_room") + txn.execute("DELETE FROM users_in_public_rooms") txn.execute("DELETE FROM users_who_share_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) txn.call_after(self.get_user_in_public_room.invalidate_all) @@ -537,7 +547,7 @@ class UserDirectoryStore(SQLBaseStore): @cached() def get_user_in_public_room(self, user_id): return self._simple_select_one( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcols=("room_id",), allow_none=True, @@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore): ] } """ + + if self.hs.config.user_directory_search_all_users: + join_clause = "" + where_clause = "?<>''" # naughty hack to keep the same number of binds + else: + join_clause = """ + LEFT JOIN users_in_public_rooms AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + """ + where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)" + if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) @@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore): SELECT d.user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_pubic_room AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND vector @@ to_tsquery('english', ?) ORDER BY (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) @@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore): display_name IS NULL, avatar_url IS NULL LIMIT ? 
- """ + """ % (join_clause, where_clause) args = (user_id, full_query, exact_query, prefix_query, limit + 1,) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) @@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore): SELECT d.user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_pubic_room AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? - """ + """ % (join_clause, where_clause) args = (user_id, search_query, limit + 1) else: # This should be unreachable. @@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term): # Pull out the individual words, discarding any non-word characters. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - return " & ".join("(%s* | %s)" % (result, result,) for result in results) + return " & ".join("(%s* OR %s)" % (result, result,) for result in results) def _parse_query_postgres(search_term): diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 4f089bfb94..ca78e551cb 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -80,13 +80,13 @@ class PaginationConfig(object): from_tok = None # For backwards compat. elif from_tok: from_tok = StreamToken.from_string(from_tok) - except: + except Exception: raise SynapseError(400, "'from' paramater is invalid") try: if to_tok: to_tok = StreamToken.from_string(to_tok) - except: + except Exception: raise SynapseError(400, "'to' paramater is invalid") limit = get_param("limit", None) @@ -98,7 +98,7 @@ class PaginationConfig(object): try: return PaginationConfig(from_tok, to_tok, direction, limit) - except: + except Exception: logger.exception("Failed to create pagination config") raise SynapseError(400, "Invalid request.") diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 91a59b0bae..f03ad99118 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -45,6 +45,7 @@ class EventSources(object): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=( @@ -65,6 +66,7 @@ class EventSources(object): push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, + groups_key=groups_key, ) defer.returnValue(token) @@ -73,6 +75,7 @@ class EventSources(object): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=( @@ -93,5 +96,6 @@ class EventSources(object): push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, + groups_key=groups_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 111948540d..6e76c016d9 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import string from synapse.api.errors import SynapseError @@ -126,15 +127,11 @@ class DomainSpecificString( try: cls.from_string(s) return True - except: + except Exception: return False __str__ = to_string - @classmethod - def create(cls, localpart, domain,): - return cls(localpart=localpart, domain=domain) - class UserID(DomainSpecificString): """Structure representing a user ID.""" @@ -156,6 +153,43 @@ class EventID(DomainSpecificString): SIGIL = "$" +class GroupID(DomainSpecificString): + """Structure representing a group ID.""" + SIGIL = "+" + + @classmethod + def from_string(cls, s): + group_id = super(GroupID, cls).from_string(s) + if not group_id.localpart: + raise SynapseError( + 400, + "Group ID cannot be empty", + ) + + if contains_invalid_mxid_characters(group_id.localpart): + raise SynapseError( + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + ) + + return group_id + + +mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits) + + +def contains_invalid_mxid_characters(localpart): + """Check for characters not allowed in an mxid or groupid localpart + + Args: + localpart (basestring): the localpart to be checked + + Returns: + bool: True if there are any naughty characters + """ + return any(c not in mxid_localpart_allowed_characters for c in localpart) + + class StreamToken( namedtuple("Token", ( "room_key", @@ -166,6 +200,7 @@ class StreamToken( "push_rules_key", "to_device_key", "device_list_key", + "groups_key", )) ): _SEPARATOR = "_" @@ -178,7 +213,7 @@ class StreamToken( # i.e. old token from before receipt_key keys.append("0") return cls(*keys) - except: + except Exception: raise SynapseError(400, "Invalid Token") def to_string(self): @@ -204,6 +239,7 @@ class StreamToken( or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.device_list_key) < int(self.device_list_key)) + or (int(other.groups_key) < int(self.groups_key)) ) def copy_and_advance(self, key, new_value): @@ -263,7 +299,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): if string[0] == 't': parts = string[1:].split('-', 1) return cls(topological=int(parts[0]), stream=int(parts[1])) - except: + except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -272,7 +308,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): try: if string[0] == 's': return cls(topological=None, stream=int(string[1:])) - except: + except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 2a2360ab5d..756d8ffa32 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -59,9 +59,9 @@ class Clock(object): f(function): The function to call repeatedly. msec(float): How long to wait between calls in milliseconds. 
""" - l = task.LoopingCall(f) - l.start(msec / 1000.0, now=False) - return l + call = task.LoopingCall(f) + call.start(msec / 1000.0, now=False) + return call def call_later(self, delay, callback, *args, **kwargs): """Call something later @@ -82,7 +82,7 @@ class Clock(object): def cancel_call_later(self, timer, ignore_errs=False): try: timer.cancel() - except: + except Exception: if not ignore_errs: raise @@ -97,12 +97,12 @@ class Clock(object): try: ret_deferred.errback(e) - except: + except Exception: pass try: given_deferred.cancel() - except: + except Exception: pass timer = None @@ -110,7 +110,7 @@ class Clock(object): def cancel(res): try: self.cancel_call_later(timer) - except: + except Exception: pass return res @@ -119,7 +119,7 @@ class Clock(object): def success(res): try: ret_deferred.callback(res) - except: + except Exception: pass return res @@ -127,7 +127,7 @@ class Clock(object): def err(res): try: ret_deferred.errback(res) - except: + except Exception: pass given_deferred.addCallbacks(callback=success, errback=err) diff --git a/synapse/util/async.py b/synapse/util/async.py index 1453faf0ef..0729bb2863 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -17,9 +17,9 @@ from twisted.internet import defer, reactor from .logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, + PreserveLoggingContext, make_deferred_yieldable, preserve_fn ) -from synapse.util import unwrapFirstError +from synapse.util import logcontext, unwrapFirstError from contextlib import contextmanager @@ -53,6 +53,11 @@ class ObservableDeferred(object): Cancelling or otherwise resolving an observer will not affect the original ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. """ __slots__ = ["_deferred", "_observers", "_result"] @@ -68,7 +73,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().callback(r) - except: + except Exception: pass return r @@ -78,7 +83,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().errback(f) - except: + except Exception: pass if consumeErrors: @@ -155,7 +160,7 @@ def concurrently_execute(func, args, limit): except StopIteration: pass - return preserve_context_over_deferred(defer.gatherResults([ + return logcontext.make_deferred_yieldable(defer.gatherResults([ preserve_fn(_concurrently_execute_inner)() for _ in xrange(limit) ], consumeErrors=True)).addErrback(unwrapFirstError) @@ -200,10 +205,29 @@ class Linearizer(object): try: with PreserveLoggingContext(): yield current_defer - except: + except Exception: logger.exception("Unexpected exception in Linearizer") - logger.info("Acquired linearizer lock %r for key %r", self.name, key) + logger.info("Acquired linearizer lock %r for key %r", self.name, + key) + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. 
+ # + # (There's no particular need for it to happen before we return + # the context manager, but it needs to happen while we hold the + # lock, and the context manager's exit code must be synchronous, + # so actually this is the only sensible place. + yield run_on_reactor() + + else: + logger.info("Acquired uncontended linearizer lock %r for key %r", + self.name, key) @contextmanager def _ctx_manager(): @@ -211,7 +235,8 @@ class Linearizer(object): yield finally: logger.info("Releasing linearizer lock %r for key %r", self.name, key) - new_defer.callback(None) + with PreserveLoggingContext(): + new_defer.callback(None) current_d = self.key_to_defer.get(key) if current_d is new_defer: self.key_to_defer.pop(key, None) @@ -253,8 +278,13 @@ class Limiter(object): if entry[0] >= self.max_count: new_defer = defer.Deferred() entry[1].append(new_defer) + + logger.info("Waiting to acquire limiter lock for key %r", key) with PreserveLoggingContext(): yield new_defer + logger.info("Acquired limiter lock for key %r", key) + else: + logger.info("Acquired uncontended limiter lock for key %r", key) entry[0] += 1 @@ -263,16 +293,21 @@ class Limiter(object): try: yield finally: + logger.info("Releasing limiter lock for key %r", key) + # We've finished executing so check if there are any things # blocked waiting to execute and start one of them entry[0] -= 1 - try: - entry[1].pop(0).callback(None) - except IndexError: - # If nothing else is executing for this key then remove it - # from the map - if entry[0] == 0: - self.key_to_defer.pop(key, None) + + if entry[1]: + next_def = entry[1].pop(0) + + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] defer.returnValue(_ctx_manager()) @@ -316,7 +351,7 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. - yield curr_writer + yield make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -345,7 +380,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) + yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e68f94ce77..734331caaa 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,32 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
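For reference, the Limiter rewritten above caps how many callers may hold the lock for a given key at once. A minimal usage sketch follows; it assumes the acquisition method is `queue(key)` (not shown in this hunk) and uses a hypothetical fetch helper:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def fetch_with_cap(limiter, destination):
        # At most max_count callers run this body concurrently for a given
        # destination; later callers wait on a deferred queued by the Limiter.
        with (yield limiter.queue(destination)):
            result = yield fetch_from_destination(destination)  # hypothetical helper
        defer.returnValue(result)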
-from twisted.internet import defer +import logging -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_fn -) +from twisted.internet import defer from synapse.util import unwrapFirstError - -import logging - +from synapse.util.logcontext import PreserveLoggingContext logger = logging.getLogger(__name__) def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) + with PreserveLoggingContext(): + distributor.fire("user_left_room", user=user, room_id=room_id) def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) + with PreserveLoggingContext(): + distributor.fire("user_joined_room", user=user, room_id=room_id) class Distributor(object): diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 990216145e..48c9f6802d 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -42,7 +42,7 @@ try: def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) -except: +except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. def get_thread_resource_usage(): @@ -261,67 +261,6 @@ class PreserveLoggingContext(object): ) -class _PreservingContextDeferred(defer.Deferred): - """A deferred that ensures that all callbacks and errbacks are called with - the given logging context. - """ - def __init__(self, context): - self._log_context = context - defer.Deferred.__init__(self) - - def addCallbacks(self, callback, errback=None, - callbackArgs=None, callbackKeywords=None, - errbackArgs=None, errbackKeywords=None): - callback = self._wrap_callback(callback) - errback = self._wrap_callback(errback) - return defer.Deferred.addCallbacks( - self, callback, - errback=errback, - callbackArgs=callbackArgs, - callbackKeywords=callbackKeywords, - errbackArgs=errbackArgs, - errbackKeywords=errbackKeywords, - ) - - def _wrap_callback(self, f): - def g(res, *args, **kwargs): - with PreserveLoggingContext(self._log_context): - res = f(res, *args, **kwargs) - return res - return g - - -def preserve_context_over_fn(fn, *args, **kwargs): - """Takes a function and invokes it with the given arguments, but removes - and restores the current logging context while doing so. - - If the result is a deferred, call preserve_context_over_deferred before - returning it. - """ - with PreserveLoggingContext(): - res = fn(*args, **kwargs) - - if isinstance(res, defer.Deferred): - return preserve_context_over_deferred(res) - else: - return res - - -def preserve_context_over_deferred(deferred, context=None): - """Given a deferred wrap it such that any callbacks added later to it will - be invoked with the current context. - - Deprecated: this almost certainly doesn't do want you want, ie make - the deferred follow the synapse logcontext rules: try - ``make_deferred_yieldable`` instead. 
- """ - if context is None: - context = LoggingContext.current_context() - d = _PreservingContextDeferred(context) - deferred.chainDeferred(d) - return d - - def preserve_fn(f): """Wraps a function, to ensure that the current context is restored after return from the function, and that the sentinel context is set once the diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py new file mode 100644 index 0000000000..cdbc4bffd7 --- /dev/null +++ b/synapse/util/logformatter.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import StringIO +import logging +import traceback + + +class LogFormatter(logging.Formatter): + """Log formatter which gives more detail for exceptions + + This is the same as the standard log formatter, except that when logging + exceptions [typically via log.foo("msg", exc_info=1)], it prints the + sequence that led up to the point at which the exception was caught. + (Normally only stack frames between the point the exception was raised and + where it was caught are logged). + """ + def __init__(self, *args, **kwargs): + super(LogFormatter, self).__init__(*args, **kwargs) + + def formatException(self, ei): + sio = StringIO.StringIO() + (typ, val, tb) = ei + + # log the stack above the exception capture point if possible, but + # check that we actually have an f_back attribute to work around + # https://twistedmatrix.com/trac/ticket/9305 + + if tb and hasattr(tb.tb_frame, 'f_back'): + sio.write("Capture point (most recent call last):\n") + traceback.print_stack(tb.tb_frame.f_back, None, sio) + + traceback.print_exception(typ, val, tb, None, sio) + s = sio.getvalue() + sio.close() + if s[-1:] == "\n": + s = s[:-1] + return s diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 4fa9d1a03c..1adedbb361 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -189,7 +189,7 @@ class RetryDestinationLimiter(object): yield self.store.set_destination_retry_timings( self.destination, retry_last_ts, self.retry_interval ) - except: + except Exception: logger.exception( "Failed to store set_destination_retry_timings", ) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7412fc57a4..b70f9a6b0a 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -91,7 +91,4 @@ class WheelTimer(object): return ret def __len__(self): - l = 0 - for entry in self.entries: - l += len(entry.queue) - return l + return sum(len(entry.queue) for entry in self.entries) diff --git a/synapse/visibility.py b/synapse/visibility.py index d7dbdc77ff..aaca2c584c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.constants import Membership, EventTypes -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn import logging @@ -58,7 +58,7 @@ def 
filter_events_for_clients(store, user_tuples, events, event_id_to_state,
         always_include_ids (set(event_id)): set of event ids to specifically
             include (unless sender is ignored)
     """
-    forgotten = yield preserve_context_over_deferred(defer.gatherResults([
+    forgotten = yield make_deferred_yieldable(defer.gatherResults([
         defer.maybeDeferred(
             preserve_fn(store.who_forgot_in_room),
             room_id,
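This final hunk is the same logcontext change applied throughout the diff: start each parallel call via preserve_fn, then wait on the gathered result with make_deferred_yieldable. A minimal sketch of the pattern, using a hypothetical per-room helper:

    from twisted.internet import defer
    from synapse.util.logcontext import make_deferred_yieldable, preserve_fn

    @defer.inlineCallbacks
    def fetch_all_rooms(room_ids):
        # Kick off one call per room under preserve_fn, then wait for the
        # gathered deferred with make_deferred_yieldable so this function's
        # logcontext is restored correctly when the yield completes.
        results = yield make_deferred_yieldable(defer.gatherResults(
            [preserve_fn(fetch_one_room)(room_id) for room_id in room_ids],
            consumeErrors=True,
        ))  # fetch_one_room is a hypothetical per-room helper
        defer.returnValue(results)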