diff options
Diffstat (limited to 'synapse')
41 files changed, 855 insertions, 593 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d0dfa959dc..79b35b3e7c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,22 @@ class RegistrationError(SynapseError): pass +class InteractiveAuthIncompleteError(Exception): + """An error raised when UI auth is not yet complete + + (This indicates we should return a 401 with 'result' as the body) + + Attributes: + result (dict): the server response to the request, which should be + passed back to the client + """ + def __init__(self, result): + super(InteractiveAuthIncompleteError, self).__init__( + "Interactive auth not yet complete", + ) + self.result = result + + class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9477737759..e4318cdfc3 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -25,7 +25,9 @@ except Exception: from daemonize import Daemonize from synapse.util import PreserveLoggingContext from synapse.util.rlimit import change_resource_limit -from twisted.internet import reactor +from twisted.internet import error, reactor + +logger = logging.getLogger(__name__) def start_worker_reactor(appname, config): @@ -120,3 +122,57 @@ def quit_with_error(error_string): sys.stderr.write(" %s\n" % (line.rstrip(),)) sys.stderr.write("*" * line_length + '\n') sys.exit(1) + + +def listen_tcp(bind_addresses, port, factory, backlog=50): + """ + Create a TCP socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenTCP( + port, + factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): + """ + Create an SSL socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenSSL( + port, + factory, + context_factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def check_bind_error(e, address, bind_addresses): + """ + This method checks an exception occurred while binding on 0.0.0.0. + If :: is specified in the bind addresses a warning is shown. + The exception is still raised otherwise. + + Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS + because :: binds on both IPv4 and IPv6 (as per RFC 3493). + When binding on 0.0.0.0 after :: this can safely be ignored. + + Args: + e (Exception): Exception that was caught. + address (str): Address on which binding was attempted. + bind_addresses (list): Addresses on which the service listens. + """ + if address == '0.0.0.0' and '::' in bind_addresses: + logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]') + else: + raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index ba2657bbad..7d0c2879ae 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -79,17 +79,16 @@ class AppserviceServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse appservice now listening on port %d", port) @@ -98,18 +97,15 @@ class AppserviceServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 129cfa901f..dc3f6efd43 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -103,17 +103,16 @@ class ClientReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse client reader now listening on port %d", port) @@ -122,18 +121,16 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 40cebe6f4a..a072291e1f 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -92,17 +92,16 @@ class FederationReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse federation reader now listening on port %d", port) @@ -111,18 +110,15 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 389e3909d1..09e9488f06 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -106,17 +106,16 @@ class FederationSenderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse federation_sender now listening on port %d", port) @@ -125,18 +124,15 @@ class FederationSenderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index abc7ef5725..ae531c0aa4 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -157,17 +157,16 @@ class FrontendProxyServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse client reader now listening on port %d", port) @@ -176,18 +175,15 @@ class FrontendProxyServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6b8875afb4..92ab3b311b 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -25,7 +25,7 @@ from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \ LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \ STATIC_PREFIX, WEB_CLIENT_PREFIX from synapse.app import _base -from synapse.app._base import quit_with_error +from synapse.app._base import quit_with_error, listen_ssl, listen_tcp from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory @@ -130,30 +130,29 @@ class SynapseHomeServer(HomeServer): root_resource = create_resource_tree(resources, root_resource) if tls: - for address in bind_addresses: - reactor.listenSSL( - port, - SynapseSite( - "synapse.access.https.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - self.tls_server_context_factory, - interface=address - ) + listen_ssl( + bind_addresses, + port, + SynapseSite( + "synapse.access.https.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + self.tls_server_context_factory, + ) + else: - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse now listening on port %d", port) def _configure_named_resource(self, name, compress=False): @@ -229,18 +228,15 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) elif listener["type"] == "replication": bind_addresses = listener["bind_addresses"] for address in bind_addresses: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index c4e5f0965d..eab1597aaa 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -99,17 +99,16 @@ class MediaRepositoryServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse media repository now listening on port %d", port) @@ -118,18 +117,15 @@ class MediaRepositoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index db9a4d16f4..7fbbb0b0e1 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -114,17 +114,16 @@ class PusherServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse pusher now listening on port %d", port) @@ -133,18 +132,15 @@ class PusherServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 323fddee21..0abba3016e 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -288,17 +288,16 @@ class SynchrotronServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse synchrotron now listening on port %d", port) @@ -307,18 +306,15 @@ class SynchrotronServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index be661a70c7..a48c4a2ae6 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -131,17 +131,16 @@ class UserDirectoryServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse user_dir now listening on port %d", port) @@ -150,18 +149,15 @@ class UserDirectoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 05e242aef6..bf19cfee29 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -36,6 +36,7 @@ from .workers import WorkerConfig from .push import PushConfig from .spam_checker import SpamCheckerConfig from .groups import GroupsConfig +from .user_directory import UserDirectoryConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig, - SpamCheckerConfig, GroupsConfig,): + SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,): pass diff --git a/synapse/config/logger.py b/synapse/config/logger.py index a1d6e4d4f7..3f70039acd 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template(""" version: 1 formatters: - precise: - format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\ -- %(message)s' + precise: + format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ +%(request)s - %(message)s' filters: - context: - (): synapse.util.logcontext.LoggingContextFilter - request: "" + context: + (): synapse.util.logcontext.LoggingContextFilter + request: "" handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: precise - filename: ${log_file} - maxBytes: 104857600 - backupCount: 10 - filters: [context] - console: - class: logging.StreamHandler - formatter: precise - filters: [context] + file: + class: logging.handlers.RotatingFileHandler + formatter: precise + filename: ${log_file} + maxBytes: 104857600 + backupCount: 10 + filters: [context] + console: + class: logging.StreamHandler + formatter: precise + filters: [context] loggers: synapse: @@ -74,17 +74,10 @@ class LoggingConfig(Config): self.log_file = self.abspath(config.get("log_file")) def default_config(self, config_dir_path, server_name, **kwargs): - log_file = self.abspath("homeserver.log") log_config = self.abspath( os.path.join(config_dir_path, server_name + ".log.config") ) return """ - # Logging verbosity level. Ignored if log_config is specified. - verbose: 0 - - # File to write logging to. Ignored if log_config is specified. - log_file: "%(log_file)s" - # A yaml python logging config file log_config: "%(log_config)s" """ % locals() @@ -123,9 +116,10 @@ class LoggingConfig(Config): def generate_files(self, config): log_config = config.get("log_config") if log_config and not os.path.exists(log_config): + log_file = self.abspath("homeserver.log") with open(log_config, "wb") as log_config_file: log_config_file.write( - DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) + DEFAULT_LOG_CONFIG.substitute(log_file=log_file) ) @@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False): ) if log_config is None: + # We don't have a logfile, so fall back to the 'verbosity' param from + # the config or cmdline. (Note that we generate a log config for new + # installs, so this will be an unusual case) level = logging.INFO level_for_storage = logging.INFO if config.verbosity: @@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False): if config.verbosity > 1: level_for_storage = logging.DEBUG - # FIXME: we need a logging.WARN for a -q quiet option logger = logging.getLogger('') logger.setLevel(level) - logging.getLogger('synapse.storage').setLevel(level_for_storage) + logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage) formatter = logging.Formatter(log_format) if log_file: diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index e9828fac17..6602c5b4c7 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -29,10 +29,10 @@ class PasswordAuthProviderConfig(Config): # param. ldap_config = config.get("ldap_config", {}) if ldap_config.get("enabled", False): - providers.append[{ + providers.append({ 'module': LDAP_PROVIDER, 'config': ldap_config, - }] + }) providers.extend(config.get("password_providers", [])) for provider in providers: diff --git a/synapse/config/server.py b/synapse/config/server.py index edb90a1348..436dd8a6fe 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -220,13 +220,12 @@ class ServerConfig(Config): port: %(bind_port)s # Local addresses to listen on. - # This will listen on all IPv4 addresses by default. + # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6 + # addresses by default. For most other OSes, this will only listen + # on IPv6. bind_addresses: + - '::' - '0.0.0.0' - # Uncomment to listen on all IPv6 interfaces - # N.B: On at least Linux this will also listen on all IPv4 - # addresses, so you will need to comment out the line above. - # - '::' # This is a 'http' listener, allows us to specify 'resources'. type: http @@ -264,7 +263,7 @@ class ServerConfig(Config): # For when matrix traffic passes through loadbalancer that unwraps TLS. - port: %(unsecure_port)s tls: false - bind_addresses: ['0.0.0.0'] + bind_addresses: ['::', '0.0.0.0'] type: http x_forwarded: false @@ -278,7 +277,7 @@ class ServerConfig(Config): # Turn on the twisted ssh manhole service on localhost on the given # port. # - port: 9000 - # bind_address: 127.0.0.1 + # bind_addresses: ['::1', '127.0.0.1'] # type: manhole """ % locals() diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py new file mode 100644 index 0000000000..38e8947843 --- /dev/null +++ b/synapse/config/user_directory.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class UserDirectoryConfig(Config): + """User Directory Configuration + Configuration for the behaviour of the /user_directory API + """ + + def read_config(self, config): + self.user_directory_search_all_users = False + user_directory_config = config.get("user_directory", None) + if user_directory_config: + self.user_directory_search_all_users = ( + user_directory_config.get("search_all_users", False) + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # User Directory configuration + # + # 'search_all_users' defines whether to search all users visible to your HS + # when searching the user directory, rather than limiting to users visible + # in public rooms. Defaults to false. If you set it True, you'll have to run + # UPDATE user_directory_stream_pos SET stream_id = NULL; + # on your database to tell it to rebuild the user_directory search indexes. + # + #user_directory: + # search_all_users: false + """ diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 061ee86b16..cd5627e36a 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events): # TODO (erikj): Implement kicks. if target_banned and user_level < ban_level: raise AuthError( - 403, "You cannot unban user &s." % (target_user_id,) + 403, "You cannot unban user %s." % (target_user_id,) ) elif target_user_id != event.user_id: kick_level = _get_named_level(auth_events, "kick", 50) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a0f5d40eb3..7918d3e442 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -16,7 +16,9 @@ import logging from synapse.api.errors import SynapseError from synapse.crypto.event_signing import check_event_content_hash +from synapse.events import FrozenEvent from synapse.events.utils import prune_event +from synapse.http.servlet import assert_params_in_request from synapse.util import unwrapFirstError, logcontext from twisted.internet import defer @@ -169,3 +171,28 @@ class FederationBase(object): ) return deferreds + + +def event_from_pdu_json(pdu_json, outlier=False): + """Construct a FrozenEvent from an event json received over federation + + Args: + pdu_json (object): pdu as received over federation + outlier (bool): True to mark this event as an outlier + + Returns: + FrozenEvent + + Raises: + SynapseError: if the pdu is missing required fields + """ + # we could probably enforce a bunch of other fields here (room_id, sender, + # origin, etc etc) + assert_params_in_request(pdu_json, ('event_id', 'type')) + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b8f02f5391..b1fe03f702 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -14,29 +14,29 @@ # limitations under the License. +import copy +import itertools +import logging +import random + from twisted.internet import defer -from .federation_base import FederationBase from synapse.api.constants import Membership - from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) -from synapse.util import unwrapFirstError, logcontext +from synapse.events import builder +from synapse.federation.federation_base import ( + FederationBase, + event_from_pdu_json, +) +import synapse.metrics +from synapse.util import logcontext, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logutils import log_function from synapse.util.logcontext import make_deferred_yieldable, preserve_fn -from synapse.events import FrozenEvent, builder -import synapse.metrics - +from synapse.util.logutils import log_function from synapse.util.retryutils import NotRetryingDestination -import copy -import itertools -import logging -import random - - logger = logging.getLogger(__name__) @@ -184,7 +184,7 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%s", repr(transaction_data)) pdus = [ - self.event_from_pdu_json(p, outlier=False) + event_from_pdu_json(p, outlier=False) for p in transaction_data["pdus"] ] @@ -244,7 +244,7 @@ class FederationClient(FederationBase): logger.debug("transaction_data %r", transaction_data) pdu_list = [ - self.event_from_pdu_json(p, outlier=outlier) + event_from_pdu_json(p, outlier=outlier) for p in transaction_data["pdus"] ] @@ -336,11 +336,11 @@ class FederationClient(FederationBase): ) pdus = [ - self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] + event_from_pdu_json(p, outlier=True) for p in result["pdus"] ] auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in result.get("auth_chain", []) ] @@ -441,7 +441,7 @@ class FederationClient(FederationBase): ) auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in res["auth_chain"] ] @@ -570,12 +570,12 @@ class FederationClient(FederationBase): logger.debug("Got content: %s", content) state = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in content.get("state", []) ] auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in content.get("auth_chain", []) ] @@ -650,7 +650,7 @@ class FederationClient(FederationBase): logger.debug("Got response to send_invite: %s", pdu_dict) - pdu = self.event_from_pdu_json(pdu_dict) + pdu = event_from_pdu_json(pdu_dict) # Check signatures are correct. pdu = yield self._check_sigs_and_hash(pdu) @@ -740,7 +740,7 @@ class FederationClient(FederationBase): ) auth_chain = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content["auth_chain"] ] @@ -788,7 +788,7 @@ class FederationClient(FederationBase): ) events = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content.get("events", []) ] @@ -805,15 +805,6 @@ class FederationClient(FederationBase): defer.returnValue(signed_events) - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - @defer.inlineCallbacks def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index a2327f24b6..9849953c9b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -12,25 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging -from .federation_base import FederationBase -from .units import Transaction, Edu +import simplejson as json +from twisted.internet import defer +from synapse.api.errors import AuthError, FederationError, SynapseError +from synapse.crypto.event_signing import compute_event_signature +from synapse.federation.federation_base import ( + FederationBase, + event_from_pdu_json, +) +from synapse.federation.units import Edu, Transaction +import synapse.metrics +from synapse.types import get_domain_from_id from synapse.util import async +from synapse.util.caches.response_cache import ResponseCache from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logutils import log_function -from synapse.util.caches.response_cache import ResponseCache -from synapse.events import FrozenEvent -from synapse.types import get_domain_from_id -import synapse.metrics - -from synapse.api.errors import AuthError, FederationError, SynapseError - -from synapse.crypto.event_signing import compute_event_signature - -import simplejson as json -import logging # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. @@ -172,7 +171,7 @@ class FederationServer(FederationBase): p["age_ts"] = request_time - int(p["age"]) del p["age"] - event = self.event_from_pdu_json(p) + event = event_from_pdu_json(p) room_id = event.room_id pdus_by_room.setdefault(room_id, []).append(event) @@ -346,7 +345,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_invite_request(self, origin, content): - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) @@ -354,7 +353,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_send_join_request(self, origin, content): logger.debug("on_send_join_request: content: %s", content) - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -374,7 +373,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_send_leave_request(self, origin, content): logger.debug("on_send_leave_request: content: %s", content) - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) defer.returnValue((200, {})) @@ -411,7 +410,7 @@ class FederationServer(FederationBase): """ with (yield self._server_linearizer.queue((origin, room_id))): auth_chain = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content["auth_chain"] ] @@ -586,15 +585,6 @@ class FederationServer(FederationBase): def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - @defer.inlineCallbacks def exchange_third_party_invite( self, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index af9060fdbf..258cc345dc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -17,7 +17,10 @@ from twisted.internet import defer, threads from ._base import BaseHandler from synapse.api.constants import LoginType -from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError +from synapse.api.errors import ( + AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError, + SynapseError, +) from synapse.module_api import ModuleApi from synapse.types import UserID from synapse.util.async import run_on_reactor @@ -47,7 +50,6 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) self.checkers = { - LoginType.PASSWORD: self._check_password_auth, LoginType.RECAPTCHA: self._check_recaptcha, LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.MSISDN: self._check_msisdn, @@ -76,15 +78,76 @@ class AuthHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - login_types = set() + # we keep this as a list despite the O(N^2) implication so that we can + # keep PASSWORD first and avoid confusing clients which pick the first + # type in the list. (NB that the spec doesn't require us to do so and + # clients which favour types that they don't understand over those that + # they do are technically broken) + login_types = [] if self._password_enabled: - login_types.add(LoginType.PASSWORD) + login_types.append(LoginType.PASSWORD) for provider in self.password_providers: if hasattr(provider, "get_supported_login_types"): - login_types.update( - provider.get_supported_login_types().keys() - ) - self._supported_login_types = frozenset(login_types) + for t in provider.get_supported_login_types().keys(): + if t not in login_types: + login_types.append(t) + self._supported_login_types = login_types + + @defer.inlineCallbacks + def validate_user_via_ui_auth(self, requester, request_body, clientip): + """ + Checks that the user is who they claim to be, via a UI auth. + + This is used for things like device deletion and password reset where + the user already has a valid access token, but we want to double-check + that it isn't stolen by re-authenticating them. + + Args: + requester (Requester): The user, as given by the access token + + request_body (dict): The body of the request sent by the client + + clientip (str): The IP address of the client. + + Returns: + defer.Deferred[dict]: the parameters for this request (which may + have been given only in a previous call). + + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + any of the permitted login flows + + AuthError if the client has completed a login flow, and it gives + a different user to `requester` + """ + + # build a list of supported flows + flows = [ + [login_type] for login_type in self._supported_login_types + ] + + result, params, _ = yield self.check_auth( + flows, request_body, clientip, + ) + + # find the completed login type + for login_type in self._supported_login_types: + if login_type not in result: + continue + + user_id = result[login_type] + break + else: + # this can't happen + raise Exception( + "check_auth returned True but no successful login type", + ) + + # check that the UI auth matched the access token + if user_id != requester.user.to_string(): + raise AuthError(403, "Invalid auth") + + defer.returnValue(params) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -96,26 +159,36 @@ class AuthHandler(BaseHandler): session with a map, which maps each auth-type (str) to the relevant identity authenticated by that auth-type (mostly str, but for captcha, bool). + If no auth flows have been completed successfully, raises an + InteractiveAuthIncompleteError. To handle this, you can use + synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + decorator. + Args: flows (list): A list of login flows. Each flow is an ordered list of strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. + Returns: - A tuple of (authed, dict, dict, session_id) where authed is true if - the client has successfully completed an auth flow. If it is true - the first dict contains the authenticated credentials of each stage. + defer.Deferred[dict, dict, str]: a deferred tuple of + (creds, params, session_id). + + 'creds' contains the authenticated credentials of each stage. - If authed is false, the first dictionary is the server response to - the login request and should be passed back to the client. + 'params' contains the parameters for this request (which may + have been given only in a previous call). - In either case, the second dict contains the parameters for this - request (which may have been given only in a previous call). + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call - session_id is the ID of this session, either passed in by the client - or assigned by the call to check_auth + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + all the stages in any of the permitted flows. """ authdict = None @@ -143,11 +216,8 @@ class AuthHandler(BaseHandler): clientdict = session['clientdict'] if not authdict: - defer.returnValue( - ( - False, self._auth_dict_for_flows(flows, session), - clientdict, session['id'] - ) + raise InteractiveAuthIncompleteError( + self._auth_dict_for_flows(flows, session), ) if 'creds' not in session: @@ -158,14 +228,12 @@ class AuthHandler(BaseHandler): errordict = {} if 'type' in authdict: login_type = authdict['type'] - if login_type not in self.checkers: - raise LoginError(400, "", Codes.UNRECOGNIZED) try: - result = yield self.checkers[login_type](authdict, clientip) + result = yield self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) - except LoginError, e: + except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new # validation token (thus sending a new email) each time it @@ -174,7 +242,7 @@ class AuthHandler(BaseHandler): # # Grandfather in the old behaviour for now to avoid # breaking old riot deployments. - raise e + raise # this step failed. Merge the error dict into the response # so that the client can have another go. @@ -191,12 +259,14 @@ class AuthHandler(BaseHandler): "Auth completed with creds: %r. Client dict has keys: %r", creds, clientdict.keys() ) - defer.returnValue((True, creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() ret.update(errordict) - defer.returnValue((False, ret, clientdict, session['id'])) + raise InteractiveAuthIncompleteError( + ret, + ) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -269,17 +339,35 @@ class AuthHandler(BaseHandler): return sess.setdefault('serverdict', {}).get(key, default) @defer.inlineCallbacks - def _check_password_auth(self, authdict, _): - if "user" not in authdict or "password" not in authdict: - raise LoginError(400, "", Codes.MISSING_PARAM) + def _check_auth_dict(self, authdict, clientip): + """Attempt to validate the auth dict provided by a client + + Args: + authdict (object): auth dict provided by the client + clientip (str): IP address of the client + + Returns: + Deferred: result of the stage verification. + + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + login_type = authdict['type'] + checker = self.checkers.get(login_type) + if checker is not None: + res = yield checker(authdict, clientip) + defer.returnValue(res) + + # build a v1-login-style dict out of the authdict and fall back to the + # v1 code + user_id = authdict.get("user") - user_id = authdict["user"] - password = authdict["password"] + if user_id is None: + raise SynapseError(400, "", Codes.MISSING_PARAM) - (canonical_id, callback) = yield self.validate_login(user_id, { - "type": LoginType.PASSWORD, - "password": password, - }) + (canonical_id, callback) = yield self.validate_login(user_id, authdict) defer.returnValue(canonical_id) @defer.inlineCallbacks diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 7e5d3f148d..e4d0cc8b02 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -383,11 +383,12 @@ class GroupsLocalHandler(object): defer.returnValue({"groups": result}) else: - result = yield self.transport_client.get_publicised_groups_for_user( - get_domain_from_id(user_id), user_id + bulk_result = yield self.transport_client.bulk_get_publicised_groups( + get_domain_from_id(user_id), [user_id], ) + result = bulk_result.get("users", {}).get(user_id) # TODO: Verify attestations - defer.returnValue(result) + defer.returnValue({"groups": result}) @defer.inlineCallbacks def bulk_get_publicised_groups(self, user_ids, proxy=True): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 5e5b1952dd..9800e24453 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) + self.user_directory_handler = hs.get_user_directory_handler() + self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS) @defer.inlineCallbacks @@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks @@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 132e2052c0..5b808beac1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ class RegistrationHandler(BaseHandler): self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() + self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler): ), admin=admin, ) + + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(localpart) + yield self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + else: # autogen a sequential user ID attempts = 0 diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 496f1fc39b..d1cc87a016 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -205,12 +205,12 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) - content = {} - is_direct = config.get("is_direct", None) - if is_direct: - content["is_direct"] = is_direct - for invitee in invite_list: + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + yield room_member_handler.update_membership( requester, UserID.from_string(invitee), diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 970fec0666..7e6467cd1d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -189,6 +189,10 @@ class RoomMemberHandler(BaseHandler): content_specified = bool(content) if content is None: content = {} + else: + # We do a copy here as we potentially change some keys + # later on. + content = dict(content) effective_membership_state = action if action in ["kick", "unban"]: diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index b5be5d9623..714f0195c8 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.storage.roommember import ProfileInfo from synapse.util.metrics import Measure from synapse.util.async import sleep +from synapse.types import get_localpart_from_id logger = logging.getLogger(__name__) -class UserDirectoyHandler(object): +class UserDirectoryHandler(object): """Handles querying of and keeping updated the user_directory. N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY @@ -41,9 +42,10 @@ class UserDirectoyHandler(object): one public room. """ - INITIAL_SLEEP_MS = 50 - INITIAL_SLEEP_COUNT = 100 - INITIAL_BATCH_SIZE = 100 + INITIAL_ROOM_SLEEP_MS = 50 + INITIAL_ROOM_SLEEP_COUNT = 100 + INITIAL_ROOM_BATCH_SIZE = 100 + INITIAL_USER_SLEEP_MS = 10 def __init__(self, hs): self.store = hs.get_datastore() @@ -53,6 +55,7 @@ class UserDirectoyHandler(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory + self.search_all_users = hs.config.user_directory_search_all_users # When start up for the first time we need to populate the user_directory. # This is a set of user_id's we've inserted already @@ -111,6 +114,15 @@ class UserDirectoyHandler(object): self._is_processing = False @defer.inlineCallbacks + def handle_local_profile_change(self, user_id, profile): + """Called to update index of our local user profiles when they change + irrespective of any rooms the user may be in. + """ + yield self.store.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url, None, + ) + + @defer.inlineCallbacks def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: @@ -148,16 +160,30 @@ class UserDirectoyHandler(object): room_ids = yield self.store.get_all_rooms() logger.info("Doing initial update of user directory. %d rooms", len(room_ids)) - num_processed_rooms = 1 + num_processed_rooms = 0 for room_id in room_ids: - logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids)) + logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids)) yield self._handle_initial_room(room_id) num_processed_rooms += 1 - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) logger.info("Processed all rooms.") + if self.search_all_users: + num_processed_users = 0 + user_ids = yield self.store.get_all_local_users() + logger.info("Doing initial update of user directory. %d users", len(user_ids)) + for user_id in user_ids: + # We add profiles for all users even if they don't match the + # include pattern, just in case we want to change it in future + logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids)) + yield self._handle_local_user(user_id) + num_processed_users += 1 + yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.) + + logger.info("Processed all users") + self.initially_handled_users = None self.initially_handled_users_in_public = None self.initially_handled_users_share = None @@ -201,8 +227,8 @@ class UserDirectoyHandler(object): to_update = set() count = 0 for user_id in user_ids: - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) if not self.is_mine_id(user_id): count += 1 @@ -216,8 +242,8 @@ class UserDirectoyHandler(object): if user_id == other_user_id: continue - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) count += 1 user_set = (user_id, other_user_id) @@ -237,13 +263,13 @@ class UserDirectoyHandler(object): else: self.initially_handled_users_share_private_room.add(user_set) - if len(to_insert) > self.INITIAL_BATCH_SIZE: + if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.add_users_who_share_room( room_id, not is_public, to_insert, ) to_insert.clear() - if len(to_update) > self.INITIAL_BATCH_SIZE: + if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.update_users_who_share_room( room_id, not is_public, to_update, ) @@ -385,14 +411,28 @@ class UserDirectoyHandler(object): yield self._handle_remove_user(room_id, user_id) @defer.inlineCallbacks + def _handle_local_user(self, user_id): + """Adds a new local roomless user into the user_directory_search table. + Used to populate up the user index when we have an + user_directory_search_all_users specified. + """ + logger.debug("Adding new local user to dir, %r", user_id) + + profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) + + row = yield self.store.get_user_in_directory(user_id) + if not row: + yield self.store.add_profiles_to_user_dir(None, {user_id: profile}) + + @defer.inlineCallbacks def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: - room_id (str): room_id that user joined or started being public that + room_id (str): room_id that user joined or started being public user_id (str) """ - logger.debug("Adding user to dir, %r", user_id) + logger.debug("Adding new user to dir, %r", user_id) row = yield self.store.get_user_in_directory(user_id) if not row: @@ -407,7 +447,7 @@ class UserDirectoyHandler(object): if not row: yield self.store.add_users_to_public_room(room_id, [user_id]) else: - logger.debug("Not adding user to public dir, %r", user_id) + logger.debug("Not adding new user to public dir, %r", user_id) # Now we update users who share rooms with users. We do this by getting # all the current users in the room and seeing which aren't already diff --git a/synapse/http/server.py b/synapse/http/server.py index 25466cd292..6e8f4c9c5f 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -107,6 +107,10 @@ def wrap_request_handler(request_handler, include_metrics=False): with LoggingContext(request_id) as request_context: with Measure(self.clock, "wrapped_request_handler"): request_metrics = RequestMetrics() + # we start the request metrics timer here with an initial stab + # at the servlet name. For most requests that name will be + # JsonResource (or a subclass), and JsonResource._async_render + # will update it once it picks a servlet. request_metrics.start(self.clock, name=self.__class__.__name__) request_context.request = request_id @@ -249,12 +253,23 @@ class JsonResource(HttpServer, resource.Resource): if not m: continue - # We found a match! Trigger callback and then return the - # returned response. We pass both the request and any - # matched groups from the regex to the callback. + # We found a match! First update the metrics object to indicate + # which servlet is handling the request. callback = path_entry.callback + servlet_instance = getattr(callback, "__self__", None) + if servlet_instance is not None: + servlet_classname = servlet_instance.__class__.__name__ + else: + servlet_classname = "%r" % callback + + request_metrics.name = servlet_classname + + # Now trigger the callback. If it returns a response, we send it + # here. If it throws an exception, that is handled by the wrapper + # installed by @request_handler. + kwargs = intern_dict({ name: urllib.unquote(value).decode("UTF-8") if value else value for name, value in m.groupdict().items() @@ -265,17 +280,10 @@ class JsonResource(HttpServer, resource.Resource): code, response = callback_return self._send_response(request, code, response) - servlet_instance = getattr(callback, "__self__", None) - if servlet_instance is not None: - servlet_classname = servlet_instance.__class__.__name__ - else: - servlet_classname = "%r" % callback - - request_metrics.name = servlet_classname - return # Huh. No one wanted to handle that? Fiiiiiine. Send 400. + request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest" raise UnrecognizedRequestError() def _send_response(self, request, code, response_json_object, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5669ecb724..45844aa2d2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet): # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": - if 'medium' not in identifier or 'address' not in identifier: + address = identifier.get('address') + medium = identifier.get('medium') + + if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") - address = identifier['address'] - if identifier['medium'] == 'email': + if medium == 'email': # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) address = address.lower() user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - identifier['medium'], address + medium, address, ) if not user_id: + logger.warn( + "unknown 3pid identifier medium %s, address %r", + medium, address, + ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = { diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 1f5bc24cc3..77434937ff 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -15,12 +15,13 @@ """This module contains base REST classes for constructing client v1 servlets. """ - -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import logging import re -import logging +from twisted.internet import defer +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX logger = logging.getLogger(__name__) @@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): filter_json['room']['timeline']["limit"] = min( filter_json['room']['timeline']['limit'], filter_timeline_limit) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns a deferred (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. + + Normal usage is: + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + # ... + yield self.auth_handler.check_auth + """ + def wrapped(*args, **kwargs): + res = defer.maybeDeferred(orig, *args, **kwargs) + res.addErrback(_catch_incomplete_interactive_auth) + return res + return wrapped + + +def _catch_incomplete_interactive_auth(f): + """helper for interactive_auth_handler + + Catches InteractiveAuthIncompleteErrors and turns them into 401 responses + + Args: + f (failure.Failure): + """ + f.trap(InteractiveAuthIncompleteError) + return 401, f.value.result diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c26ce63bcf..385a3ad2ec 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -19,14 +19,14 @@ from twisted.internet import defer from synapse.api.auth import has_access_token from synapse.api.constants import LoginType -from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_request, parse_json_object_from_request, ) from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -100,49 +100,53 @@ class PasswordRestServlet(RestServlet): self.datastore = self.hs.get_datastore() self._set_password_handler = hs.get_set_password_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - [LoginType.EMAIL_IDENTITY], - [LoginType.MSISDN], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + # there are two possibilities here. Either the user does not have an + # access token, and needs to do a password reset; or they have one and + # need to validate their identity. + # + # In the first case, we offer a couple of means of identifying + # themselves (email and msisdn, though it's unclear if msisdn actually + # works). + # + # In the second case, we require a password to confirm their identity. - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in + if has_access_token(request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - elif LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() - # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + params = yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id + user_id = requester.user.to_string() else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = None + result, params, _ = yield self.auth_handler.check_auth( + [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], + body, self.hs.get_ip_from_request(request), + ) + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if 'medium' not in threepid or 'address' not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addreses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + threepid['address'] = threepid['address'].lower() + # if using email, we must know about the email they're authing with! + threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid['medium'], threepid['address'] + ) + if not threepid_user_id: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) if 'new_password' not in params: raise SynapseError(400, "", Codes.MISSING_PARAM) @@ -168,47 +172,26 @@ class DeactivateAccountRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - # if the caller provides an access token, it ought to be valid. - requester = None - if has_access_token(request): - requester = yield self.auth.get_user_by_req( - request, - ) # type: synapse.types.Requester + requester = yield self.auth.get_user_by_req(request) # allow ASes to dectivate their own users - if requester and requester.app_service: + if requester.app_service: yield self._deactivate_account_handler.deactivate_account( requester.user.to_string() ) defer.returnValue((200, {})) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - if LoginType.PASSWORD in result: - user_id = result[LoginType.PASSWORD] - # if using password, they should also be logged in - if requester is None: - raise SynapseError( - 400, - "Deactivate account requires an access_token", - errcode=Codes.MISSING_TOKEN - ) - if requester.user.to_string() != user_id: - raise LoginError(400, "", Codes.UNKNOWN) - else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) - - yield self._deactivate_account_handler.deactivate_account(user_id) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), + ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 5321e5abbb..35d58b367a 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -17,9 +17,9 @@ import logging from twisted.internet import defer -from synapse.api import constants, errors +from synapse.api import errors from synapse.http import servlet -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + try: body = servlet.parse_json_object_from_request(request) except errors.SynapseError as e: @@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet): 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM ) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_devices( requester.user.to_string(), body['devices'], @@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet): ) defer.returnValue((200, device)) + @interactive_auth_handler @defer.inlineCallbacks def on_DELETE(self, request, device_id): requester = yield self.auth.get_user_by_req(request) @@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet): else: raise - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - # check that the UI auth matched the access token - user_id = result[constants.LoginType.PASSWORD] - if user_id != requester.user.to_string(): - raise errors.AuthError(403, "Invalid auth") + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - yield self.device_handler.delete_device(user_id, device_id) + yield self.device_handler.delete_device( + requester.user.to_string(), device_id, + ) defer.returnValue((200, {})) @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 9e2f7308ce..e9d88a8895 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -27,7 +27,7 @@ from synapse.http.servlet import ( ) from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler import logging import hmac @@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() @@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet): [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], ]) - authed, auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) - if not authed: - defer.returnValue((401, auth_result)) - return - if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 68d56b2b10..70dbf7f5c9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -85,11 +85,6 @@ class ThumbnailResource(Resource): respond_404(request) return - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.local_media_filepath(media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: @@ -114,9 +109,7 @@ class ThumbnailResource(Resource): yield respond_with_file(request, t_type, file_path) else: - yield self._respond_default_thumbnail( - request, media_info, width, height, method, m_type, - ) + respond_404(request) @defer.inlineCallbacks def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, @@ -128,11 +121,6 @@ class ThumbnailResource(Resource): respond_404(request) return - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.local_media_filepath(media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width @@ -166,10 +154,7 @@ class ThumbnailResource(Resource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - yield self._respond_default_thumbnail( - request, media_info, desired_width, desired_height, - desired_method, desired_type, - ) + respond_404(request) @defer.inlineCallbacks def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, @@ -177,11 +162,6 @@ class ThumbnailResource(Resource): desired_method, desired_type): media_info = yield self.media_repo.get_remote_media(server_name, media_id) - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return - thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, ) @@ -213,23 +193,11 @@ class ThumbnailResource(Resource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - yield self._respond_default_thumbnail( - request, media_info, desired_width, desired_height, - desired_method, desired_type, - ) + respond_404(request) @defer.inlineCallbacks def _respond_remote_thumbnail(self, request, server_name, media_id, width, height, method, m_type): - # TODO: Don't download the whole remote file - # We should proxy the thumbnail from the remote server instead. - media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return - thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, ) @@ -250,47 +218,7 @@ class ThumbnailResource(Resource): ) yield respond_with_file(request, t_type, file_path, t_length) else: - yield self._respond_default_thumbnail( - request, media_info, width, height, method, m_type, - ) - - @defer.inlineCallbacks - def _respond_default_thumbnail(self, request, media_info, width, height, - method, m_type): - # XXX: how is this meant to work? store.get_default_thumbnails - # appears to always return [] so won't this always 404? - media_type = media_info["media_type"] - top_level_type = media_type.split("/")[0] - sub_type = media_type.split("/")[-1].split(";")[0] - thumbnail_infos = yield self.store.get_default_thumbnails( - top_level_type, sub_type, - ) - if not thumbnail_infos: - thumbnail_infos = yield self.store.get_default_thumbnails( - top_level_type, "_default", - ) - if not thumbnail_infos: - thumbnail_infos = yield self.store.get_default_thumbnails( - "_default", "_default", - ) - if not thumbnail_infos: respond_404(request) - return - - thumbnail_info = self._select_thumbnail( - width, height, "crop", m_type, thumbnail_infos - ) - - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - t_length = thumbnail_info["thumbnail_length"] - - file_path = self.filepaths.default_thumbnail( - top_level_type, sub_type, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path, t_length) def _select_thumbnail(self, desired_width, desired_height, desired_method, desired_type, thumbnail_infos): diff --git a/synapse/server.py b/synapse/server.py index 18c72d21a8..99693071b6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -52,7 +52,7 @@ from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.read_marker import ReadMarkerHandler -from synapse.handlers.user_directory import UserDirectoyHandler +from synapse.handlers.user_directory import UserDirectoryHandler from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.profile import ProfileHandler from synapse.groups.groups_server import GroupsServerHandler @@ -339,7 +339,7 @@ class HomeServer(object): return ActionGenerator(self) def build_user_directory_handler(self): - return UserDirectoyHandler(self) + return UserDirectoryHandler(self) def build_groups_local_handler(self): return GroupsLocalHandler(self) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 20fa25895d..b971f0cb18 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -554,7 +554,7 @@ class SQLBaseStore(object): def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. + return a single row, returning multiple columns from it. Args: table : string giving the table name diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index a66ff7c1e0..6ebc372498 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore): where_clause='url_cache IS NOT NULL', ) - def get_default_thumbnails(self, top_level_type, sub_type): - return [] - def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index beea3102fc..ec02e73bc2 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -15,6 +15,9 @@ from twisted.internet import defer +from synapse.storage.roommember import ProfileInfo +from synapse.api.errors import StoreError + from ._base import SQLBaseStore @@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore): desc="create_profile", ) + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self._simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + defer.returnValue(ProfileInfo(None, None)) + return + else: + raise + + defer.returnValue( + ProfileInfo( + avatar_url=profile['avatar_url'], + display_name=profile['displayname'], + ) + ) + def get_profile_displayname(self, user_localpart): return self._simple_select_one_onecol( table="profiles", diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 5dc5b9582a..c9bff408ef 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore): ) if isinstance(self.database_engine, PostgresEngine): - # We weight the loclpart most highly, then display name and finally + # We weight the localpart most highly, then display name and finally # server name if new_entry: sql = """ @@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore): rows = yield self._execute("get_all_rooms", None, sql) defer.returnValue([room_id for room_id, in rows]) + @defer.inlineCallbacks + def get_all_local_users(self): + """Get all local users + """ + sql = """ + SELECT name FROM users + """ + rows = yield self._execute("get_all_local_users", None, sql) + defer.returnValue([name for name, in rows]) + def add_users_who_share_room(self, room_id, share_private, user_id_tuples): """Insert entries into the users_who_share_rooms table. The first user should be a local user. @@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore): ] } """ + + if self.hs.config.user_directory_search_all_users: + join_clause = "" + where_clause = "?<>''" # naughty hack to keep the same number of binds + else: + join_clause = """ + LEFT JOIN users_in_public_rooms AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + """ + where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)" + if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) @@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore): SELECT d.user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_public_rooms AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND vector @@ to_tsquery('english', ?) ORDER BY (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) @@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore): display_name IS NULL, avatar_url IS NULL LIMIT ? - """ + """ % (join_clause, where_clause) args = (user_id, full_query, exact_query, prefix_query, limit + 1,) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) @@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore): SELECT d.user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_public_rooms AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? - """ + """ % (join_clause, where_clause) args = (user_id, search_query, limit + 1) else: # This should be unreachable. @@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term): # Pull out the individual words, discarding any non-word characters. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - return " & ".join("(%s* | %s)" % (result, result,) for result in results) + return " & ".join("(%s* OR %s)" % (result, result,) for result in results) def _parse_query_postgres(search_term): |