diff options
Diffstat (limited to 'synapse')
60 files changed, 1397 insertions, 643 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index e04208961f..8c3d7a210a 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.25.0" +__version__ = "0.25.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 72858cca1f..ac0a3655a5 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -270,7 +270,11 @@ class Auth(object): rights (str): The operation being performed; the access token must allow this. Returns: - dict : dict that includes the user and the ID of their access token. + Deferred[dict]: dict that includes: + `user` (UserID) + `is_guest` (bool) + `token_id` (int|None): access token id. May be None if guest + `device_id` (str|None): device corresponding to access token Raises: AuthError if no user by that token exists or the token is invalid. """ diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d0dfa959dc..79b35b3e7c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,22 @@ class RegistrationError(SynapseError): pass +class InteractiveAuthIncompleteError(Exception): + """An error raised when UI auth is not yet complete + + (This indicates we should return a 401 with 'result' as the body) + + Attributes: + result (dict): the server response to the request, which should be + passed back to the client + """ + def __init__(self, result): + super(InteractiveAuthIncompleteError, self).__init__( + "Interactive auth not yet complete", + ) + self.result = result + + class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9477737759..e4318cdfc3 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -25,7 +25,9 @@ except Exception: from daemonize import Daemonize from synapse.util import PreserveLoggingContext from synapse.util.rlimit import change_resource_limit -from twisted.internet import reactor +from twisted.internet import error, reactor + +logger = logging.getLogger(__name__) def start_worker_reactor(appname, config): @@ -120,3 +122,57 @@ def quit_with_error(error_string): sys.stderr.write(" %s\n" % (line.rstrip(),)) sys.stderr.write("*" * line_length + '\n') sys.exit(1) + + +def listen_tcp(bind_addresses, port, factory, backlog=50): + """ + Create a TCP socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenTCP( + port, + factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): + """ + Create an SSL socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenSSL( + port, + factory, + context_factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def check_bind_error(e, address, bind_addresses): + """ + This method checks an exception occurred while binding on 0.0.0.0. + If :: is specified in the bind addresses a warning is shown. + The exception is still raised otherwise. + + Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS + because :: binds on both IPv4 and IPv6 (as per RFC 3493). + When binding on 0.0.0.0 after :: this can safely be ignored. + + Args: + e (Exception): Exception that was caught. + address (str): Address on which binding was attempted. + bind_addresses (list): Addresses on which the service listens. + """ + if address == '0.0.0.0' and '::' in bind_addresses: + logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]') + else: + raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index ba2657bbad..7d0c2879ae 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -79,17 +79,16 @@ class AppserviceServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse appservice now listening on port %d", port) @@ -98,18 +97,15 @@ class AppserviceServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 129cfa901f..dc3f6efd43 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -103,17 +103,16 @@ class ClientReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse client reader now listening on port %d", port) @@ -122,18 +121,16 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 40cebe6f4a..a072291e1f 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -92,17 +92,16 @@ class FederationReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse federation reader now listening on port %d", port) @@ -111,18 +110,15 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 389e3909d1..09e9488f06 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -106,17 +106,16 @@ class FederationSenderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse federation_sender now listening on port %d", port) @@ -125,18 +124,15 @@ class FederationSenderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index abc7ef5725..ae531c0aa4 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -157,17 +157,16 @@ class FrontendProxyServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse client reader now listening on port %d", port) @@ -176,18 +175,15 @@ class FrontendProxyServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 9e26146338..92ab3b311b 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -25,7 +25,7 @@ from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \ LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \ STATIC_PREFIX, WEB_CLIENT_PREFIX from synapse.app import _base -from synapse.app._base import quit_with_error +from synapse.app._base import quit_with_error, listen_ssl, listen_tcp from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory @@ -43,7 +43,6 @@ from synapse.rest import ClientRestResource from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.server import HomeServer from synapse.storage import are_all_users_on_domain from synapse.storage.engines import IncorrectDatabaseSetup, create_engine @@ -131,30 +130,29 @@ class SynapseHomeServer(HomeServer): root_resource = create_resource_tree(resources, root_resource) if tls: - for address in bind_addresses: - reactor.listenSSL( - port, - SynapseSite( - "synapse.access.https.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - self.tls_server_context_factory, - interface=address - ) + listen_ssl( + bind_addresses, + port, + SynapseSite( + "synapse.access.https.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + self.tls_server_context_factory, + ) + else: - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse now listening on port %d", port) def _configure_named_resource(self, name, compress=False): @@ -195,14 +193,19 @@ class SynapseHomeServer(HomeServer): }) if name in ["media", "federation", "client"]: - media_repo = MediaRepositoryResource(self) - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) + if self.get_config().enable_media_repo: + media_repo = self.get_media_repository_resource() + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + elif name == "media": + raise ConfigError( + "'media' resource conflicts with enable_media_repo=False", + ) if name in ["keys", "federation"]: resources.update({ @@ -225,18 +228,15 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) elif listener["type"] == "replication": bind_addresses = listener["bind_addresses"] for address in bind_addresses: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 36c18bdbcb..eab1597aaa 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -35,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.media_repository import MediaRepositoryStore @@ -89,7 +88,7 @@ class MediaRepositoryServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) elif name == "media": - media_repo = MediaRepositoryResource(self) + media_repo = self.get_media_repository_resource() resources.update({ MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, @@ -100,17 +99,16 @@ class MediaRepositoryServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse media repository now listening on port %d", port) @@ -119,18 +117,15 @@ class MediaRepositoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -151,6 +146,13 @@ def start(config_options): assert config.worker_app == "synapse.app.media_repository" + if config.enable_media_repo: + _base.quit_with_error( + "enable_media_repo must be disabled in the main synapse process\n" + "before the media repo can be run in a separate worker.\n" + "Please add ``enable_media_repo: false`` to the main config\n" + ) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index db9a4d16f4..7fbbb0b0e1 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -114,17 +114,16 @@ class PusherServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse pusher now listening on port %d", port) @@ -133,18 +132,15 @@ class PusherServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 576ac6fb7e..0abba3016e 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -288,17 +288,16 @@ class SynchrotronServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse synchrotron now listening on port %d", port) @@ -307,18 +306,15 @@ class SynchrotronServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -340,11 +336,10 @@ class SyncReplicationHandler(ReplicationClientHandler): self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() + # NB this is a SynchrotronPresence, not a normal PresenceHandler self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() - self.presence_handler.sync_callback = self.send_user_sync - def on_rdata(self, stream_name, token, rows): super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index be661a70c7..a48c4a2ae6 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -131,17 +131,16 @@ class UserDirectoryServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse user_dir now listening on port %d", port) @@ -150,18 +149,15 @@ class UserDirectoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b989007314..d5a7a5ce2f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.api.constants import EventTypes from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.types import GroupID, get_domain_from_id from twisted.internet import defer @@ -81,12 +82,13 @@ class ApplicationService(object): # values. NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] - def __init__(self, token, url=None, namespaces=None, hs_token=None, + def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, sender=None, id=None, protocols=None, rate_limited=True): self.token = token self.url = url self.hs_token = hs_token self.sender = sender + self.server_name = hostname self.namespaces = self._check_namespaces(namespaces) self.id = id @@ -125,6 +127,24 @@ class ApplicationService(object): raise ValueError( "Expected bool for 'exclusive' in ns '%s'" % ns ) + group_id = regex_obj.get("group_id") + if group_id: + if not isinstance(group_id, str): + raise ValueError( + "Expected string for 'group_id' in ns '%s'" % ns + ) + try: + GroupID.from_string(group_id) + except Exception: + raise ValueError( + "Expected valid group ID for 'group_id' in ns '%s'" % ns + ) + + if get_domain_from_id(group_id) != self.server_name: + raise ValueError( + "Expected 'group_id' to be this host in ns '%s'" % ns + ) + regex = regex_obj.get("regex") if isinstance(regex, basestring): regex_obj["regex"] = re.compile(regex) # Pre-compile regex @@ -251,6 +271,21 @@ class ApplicationService(object): if regex_obj["exclusive"] ] + def get_groups_for_user(self, user_id): + """Get the groups that this user is associated with by this AS + + Args: + user_id (str): The ID of the user. + + Returns: + iterable[str]: an iterable that yields group_id strings. + """ + return ( + regex_obj["group_id"] + for regex_obj in self.namespaces[ApplicationService.NS_USERS] + if "group_id" in regex_obj and regex_obj["regex"].match(user_id) + ) + def is_rate_limited(self): return self.rate_limited diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 82c50b8240..aba0aec6e8 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -154,6 +154,7 @@ def _load_appservice(hostname, as_info, config_filename): ) return ApplicationService( token=as_info["as_token"], + hostname=hostname, url=as_info["url"], namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 05e242aef6..bf19cfee29 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -36,6 +36,7 @@ from .workers import WorkerConfig from .push import PushConfig from .spam_checker import SpamCheckerConfig from .groups import GroupsConfig +from .user_directory import UserDirectoryConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig, - SpamCheckerConfig, GroupsConfig,): + SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,): pass diff --git a/synapse/config/server.py b/synapse/config/server.py index 4d9193536d..436dd8a6fe 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -41,6 +41,12 @@ class ServerConfig(Config): # false only if we are updating the user directory in a worker self.update_user_directory = config.get("update_user_directory", True) + # whether to enable the media repository endpoints. This should be set + # to false if the media repository is running as a separate endpoint; + # doing so ensures that we will not run cache cleanup jobs on the + # master, potentially causing inconsistency. + self.enable_media_repo = config.get("enable_media_repo", True) + self.filter_timeline_limit = config.get("filter_timeline_limit", -1) # Whether we should block invites sent to users on this server @@ -214,13 +220,12 @@ class ServerConfig(Config): port: %(bind_port)s # Local addresses to listen on. - # This will listen on all IPv4 addresses by default. + # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6 + # addresses by default. For most other OSes, this will only listen + # on IPv6. bind_addresses: + - '::' - '0.0.0.0' - # Uncomment to listen on all IPv6 interfaces - # N.B: On at least Linux this will also listen on all IPv4 - # addresses, so you will need to comment out the line above. - # - '::' # This is a 'http' listener, allows us to specify 'resources'. type: http @@ -258,7 +263,7 @@ class ServerConfig(Config): # For when matrix traffic passes through loadbalancer that unwraps TLS. - port: %(unsecure_port)s tls: false - bind_addresses: ['0.0.0.0'] + bind_addresses: ['::', '0.0.0.0'] type: http x_forwarded: false @@ -272,7 +277,7 @@ class ServerConfig(Config): # Turn on the twisted ssh manhole service on localhost on the given # port. # - port: 9000 - # bind_address: 127.0.0.1 + # bind_addresses: ['::1', '127.0.0.1'] # type: manhole """ % locals() diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py new file mode 100644 index 0000000000..38e8947843 --- /dev/null +++ b/synapse/config/user_directory.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class UserDirectoryConfig(Config): + """User Directory Configuration + Configuration for the behaviour of the /user_directory API + """ + + def read_config(self, config): + self.user_directory_search_all_users = False + user_directory_config = config.get("user_directory", None) + if user_directory_config: + self.user_directory_search_all_users = ( + user_directory_config.get("search_all_users", False) + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # User Directory configuration + # + # 'search_all_users' defines whether to search all users visible to your HS + # when searching the user directory, rather than limiting to users visible + # in public rooms. Defaults to false. If you set it True, you'll have to run + # UPDATE user_directory_stream_pos SET stream_id = NULL; + # on your database to tell it to rebuild the user_directory search indexes. + # + #user_directory: + # search_all_users: false + """ diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 0d0e7b5286..aaa3efaca3 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -32,15 +32,22 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event, hash_algorithm) logger.debug("Expecting hash: %s", encode_base64(expected_hash)) - if name not in event.hashes: + + # some malformed events lack a 'hashes'. Protect against it being missing + # or a weird type by basically treating it the same as an unhashed event. + hashes = event.get("hashes") + if not isinstance(hashes, dict): + raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED) + + if name not in hashes: raise SynapseError( 400, "Algorithm %s not in hashes %s" % ( - name, list(event.hashes), + name, list(hashes), ), Codes.UNAUTHORIZED, ) - message_hash_base64 = event.hashes[name] + message_hash_base64 = hashes[name] try: message_hash_bytes = decode_base64(message_hash_base64) except Exception: diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 7a3c9cbb70..3e7809b04f 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -20,7 +20,7 @@ from .persistence import TransactionActions from .units import Transaction, Edu from synapse.api.errors import HttpResponseException -from synapse.util import logcontext +from synapse.util import logcontext, PreserveLoggingContext from synapse.util.async import run_on_reactor from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.metrics import measure_func @@ -146,7 +146,6 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - @defer.inlineCallbacks def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -156,6 +155,13 @@ class TransactionQueue(object): if self._is_processing: return + # fire off a processing loop in the background. It's likely it will + # outlast the current request, so run it in the sentinel logcontext. + with PreserveLoggingContext(): + self._process_event_queue_loop() + + @defer.inlineCallbacks + def _process_event_queue_loop(self): try: self._is_processing = True while True: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7a0ba6ef35..573c9db8a1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -17,7 +17,10 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType -from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError +from synapse.api.errors import ( + AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError, + SynapseError, +) from synapse.module_api import ModuleApi from synapse.types import UserID from synapse.util.async import run_on_reactor @@ -46,7 +49,6 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) self.checkers = { - LoginType.PASSWORD: self._check_password_auth, LoginType.RECAPTCHA: self._check_recaptcha, LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.MSISDN: self._check_msisdn, @@ -75,15 +77,76 @@ class AuthHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - login_types = set() + # we keep this as a list despite the O(N^2) implication so that we can + # keep PASSWORD first and avoid confusing clients which pick the first + # type in the list. (NB that the spec doesn't require us to do so and + # clients which favour types that they don't understand over those that + # they do are technically broken) + login_types = [] if self._password_enabled: - login_types.add(LoginType.PASSWORD) + login_types.append(LoginType.PASSWORD) for provider in self.password_providers: if hasattr(provider, "get_supported_login_types"): - login_types.update( - provider.get_supported_login_types().keys() - ) - self._supported_login_types = frozenset(login_types) + for t in provider.get_supported_login_types().keys(): + if t not in login_types: + login_types.append(t) + self._supported_login_types = login_types + + @defer.inlineCallbacks + def validate_user_via_ui_auth(self, requester, request_body, clientip): + """ + Checks that the user is who they claim to be, via a UI auth. + + This is used for things like device deletion and password reset where + the user already has a valid access token, but we want to double-check + that it isn't stolen by re-authenticating them. + + Args: + requester (Requester): The user, as given by the access token + + request_body (dict): The body of the request sent by the client + + clientip (str): The IP address of the client. + + Returns: + defer.Deferred[dict]: the parameters for this request (which may + have been given only in a previous call). + + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + any of the permitted login flows + + AuthError if the client has completed a login flow, and it gives + a different user to `requester` + """ + + # build a list of supported flows + flows = [ + [login_type] for login_type in self._supported_login_types + ] + + result, params, _ = yield self.check_auth( + flows, request_body, clientip, + ) + + # find the completed login type + for login_type in self._supported_login_types: + if login_type not in result: + continue + + user_id = result[login_type] + break + else: + # this can't happen + raise Exception( + "check_auth returned True but no successful login type", + ) + + # check that the UI auth matched the access token + if user_id != requester.user.to_string(): + raise AuthError(403, "Invalid auth") + + defer.returnValue(params) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -95,26 +158,36 @@ class AuthHandler(BaseHandler): session with a map, which maps each auth-type (str) to the relevant identity authenticated by that auth-type (mostly str, but for captcha, bool). + If no auth flows have been completed successfully, raises an + InteractiveAuthIncompleteError. To handle this, you can use + synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + decorator. + Args: flows (list): A list of login flows. Each flow is an ordered list of strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. + Returns: - A tuple of (authed, dict, dict, session_id) where authed is true if - the client has successfully completed an auth flow. If it is true - the first dict contains the authenticated credentials of each stage. + defer.Deferred[dict, dict, str]: a deferred tuple of + (creds, params, session_id). - If authed is false, the first dictionary is the server response to - the login request and should be passed back to the client. + 'creds' contains the authenticated credentials of each stage. - In either case, the second dict contains the parameters for this - request (which may have been given only in a previous call). + 'params' contains the parameters for this request (which may + have been given only in a previous call). - session_id is the ID of this session, either passed in by the client - or assigned by the call to check_auth + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call + + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + all the stages in any of the permitted flows. """ authdict = None @@ -142,11 +215,8 @@ class AuthHandler(BaseHandler): clientdict = session['clientdict'] if not authdict: - defer.returnValue( - ( - False, self._auth_dict_for_flows(flows, session), - clientdict, session['id'] - ) + raise InteractiveAuthIncompleteError( + self._auth_dict_for_flows(flows, session), ) if 'creds' not in session: @@ -157,14 +227,12 @@ class AuthHandler(BaseHandler): errordict = {} if 'type' in authdict: login_type = authdict['type'] - if login_type not in self.checkers: - raise LoginError(400, "", Codes.UNRECOGNIZED) try: - result = yield self.checkers[login_type](authdict, clientip) + result = yield self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) - except LoginError, e: + except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new # validation token (thus sending a new email) each time it @@ -173,7 +241,7 @@ class AuthHandler(BaseHandler): # # Grandfather in the old behaviour for now to avoid # breaking old riot deployments. - raise e + raise # this step failed. Merge the error dict into the response # so that the client can have another go. @@ -190,12 +258,14 @@ class AuthHandler(BaseHandler): "Auth completed with creds: %r. Client dict has keys: %r", creds, clientdict.keys() ) - defer.returnValue((True, creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() ret.update(errordict) - defer.returnValue((False, ret, clientdict, session['id'])) + raise InteractiveAuthIncompleteError( + ret, + ) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -268,17 +338,35 @@ class AuthHandler(BaseHandler): return sess.setdefault('serverdict', {}).get(key, default) @defer.inlineCallbacks - def _check_password_auth(self, authdict, _): - if "user" not in authdict or "password" not in authdict: - raise LoginError(400, "", Codes.MISSING_PARAM) + def _check_auth_dict(self, authdict, clientip): + """Attempt to validate the auth dict provided by a client - user_id = authdict["user"] - password = authdict["password"] + Args: + authdict (object): auth dict provided by the client + clientip (str): IP address of the client - (canonical_id, callback) = yield self.validate_login(user_id, { - "type": LoginType.PASSWORD, - "password": password, - }) + Returns: + Deferred: result of the stage verification. + + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + login_type = authdict['type'] + checker = self.checkers.get(login_type) + if checker is not None: + res = yield checker(authdict, clientip) + defer.returnValue(res) + + # build a v1-login-style dict out of the authdict and fall back to the + # v1 code + user_id = authdict.get("user") + + if user_id is None: + raise SynapseError(400, "", Codes.MISSING_PARAM) + + (canonical_id, callback) = yield self.validate_login(user_id, authdict) defer.returnValue(canonical_id) @defer.inlineCallbacks @@ -551,7 +639,7 @@ class AuthHandler(BaseHandler): qualified_user_id, password, ) if is_valid: - defer.returnValue(qualified_user_id) + defer.returnValue((qualified_user_id, None)) if (not hasattr(provider, "get_supported_login_types") or not hasattr(provider, "check_auth")): @@ -650,41 +738,6 @@ class AuthHandler(BaseHandler): raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def set_password(self, user_id, newpassword, requester=None): - password_hash = self.hash(newpassword) - - except_access_token_id = requester.access_token_id if requester else None - - try: - yield self.store.user_set_password_hash(user_id, password_hash) - except StoreError as e: - if e.code == 404: - raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) - raise e - yield self.delete_access_tokens_for_user( - user_id, except_token_id=except_access_token_id, - ) - yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_id - ) - - @defer.inlineCallbacks - def deactivate_account(self, user_id): - """Deactivate a user's account - - Args: - user_id (str): ID of user to be deactivated - - Returns: - Deferred - """ - # FIXME: Theoretically there is a race here wherein user resets - # password using threepid. - yield self.delete_access_tokens_for_user(user_id) - yield self.store.user_delete_threepids(user_id) - yield self.store.user_set_password_hash(user_id, None) - - @defer.inlineCallbacks def delete_access_token(self, access_token): """Invalidate a single access token @@ -706,6 +759,12 @@ class AuthHandler(BaseHandler): access_token=access_token, ) + # delete pushers associated with this access token + if user_info["token_id"] is not None: + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + str(user_info["user"]), (user_info["token_id"], ) + ) + @defer.inlineCallbacks def delete_access_tokens_for_user(self, user_id, except_token_id=None, device_id=None): @@ -728,13 +787,18 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - for token, device_id in tokens_and_devices: + for token, token_id, device_id in tokens_and_devices: yield provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token, ) + # delete pushers associated with the access tokens + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + user_id, (token_id for _, token_id, _ in tokens_and_devices), + ) + @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): # 'Canonicalise' email addresses down to lower case. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py new file mode 100644 index 0000000000..b1d3814909 --- /dev/null +++ b/synapse/handlers/deactivate_account.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeactivateAccountHandler(BaseHandler): + """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): + super(DeactivateAccountHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def deactivate_account(self, user_id): + """Deactivate a user's account + + Args: + user_id (str): ID of user to be deactivated + + Returns: + Deferred + """ + # FIXME: Theoretically there is a race here wherein user resets + # password using threepid. + + # first delete any devices belonging to the user, which will also + # delete corresponding access tokens. + yield self._device_handler.delete_all_devices_for_user(user_id) + # then delete any remaining access tokens which weren't associated with + # a device. + yield self._auth_handler.delete_access_tokens_for_user(user_id) + + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 579d8477ba..2152efc692 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -171,12 +171,30 @@ class DeviceHandler(BaseHandler): yield self.notify_device_update(user_id, [device_id]) @defer.inlineCallbacks + def delete_all_devices_for_user(self, user_id, except_device_id=None): + """Delete all of the user's devices + + Args: + user_id (str): + except_device_id (str|None): optional device id which should not + be deleted + + Returns: + defer.Deferred: + """ + device_map = yield self.store.get_devices_by_user(user_id) + device_ids = device_map.keys() + if except_device_id is not None: + device_ids = [d for d in device_ids if d != except_device_id] + yield self.delete_devices(user_id, device_ids) + + @defer.inlineCallbacks def delete_devices(self, user_id, device_ids): """ Delete several devices Args: user_id (str): - device_ids (str): The list of device IDs to delete + device_ids (List[str]): The list of device IDs to delete Returns: defer.Deferred: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index da00aeb0f4..7e5d3f148d 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -375,6 +375,12 @@ class GroupsLocalHandler(object): def get_publicised_groups_for_user(self, user_id): if self.hs.is_mine_id(user_id): result = yield self.store.get_publicised_groups_for_user(user_id) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + result.extend(app_service.get_groups_for_user(user_id)) + defer.returnValue({"groups": result}) else: result = yield self.transport_client.get_publicised_groups_for_user( @@ -415,4 +421,9 @@ class GroupsLocalHandler(object): uid ) + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + results[uid].extend(app_service.get_groups_for_user(uid)) + defer.returnValue({"users": results}) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fa96ea69cd..cb158ba962 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1199,7 +1199,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): ) changed = True else: - # We expect to be poked occaisonally by the other side. + # We expect to be poked occasionally by the other side. # This is to protect against forgetful/buggy servers, so that # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 5e5b1952dd..9800e24453 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) + self.user_directory_handler = hs.get_user_directory_handler() + self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS) @defer.inlineCallbacks @@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks @@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f6e7e58563..4bc6ef51fe 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ class RegistrationHandler(BaseHandler): self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() + self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler): ), admin=admin, ) + + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(localpart) + yield self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + else: # autogen a sequential user ID attempts = 0 diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 496f1fc39b..d1cc87a016 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -205,12 +205,12 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) - content = {} - is_direct = config.get("is_direct", None) - if is_direct: - content["is_direct"] = is_direct - for invitee in invite_list: + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + yield room_member_handler.update_membership( requester, UserID.from_string(invitee), diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 970fec0666..7e6467cd1d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -189,6 +189,10 @@ class RoomMemberHandler(BaseHandler): content_specified = bool(content) if content is None: content = {} + else: + # We do a copy here as we potentially change some keys + # later on. + content = dict(content) effective_membership_state = action if action in ["kick", "unban"]: diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py new file mode 100644 index 0000000000..44414e1dc1 --- /dev/null +++ b/synapse/handlers/set_password.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import Codes, StoreError, SynapseError +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + + +class SetPasswordHandler(BaseHandler): + """Handler which deals with changing user account passwords""" + def __init__(self, hs): + super(SetPasswordHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def set_password(self, user_id, newpassword, requester=None): + password_hash = self._auth_handler.hash(newpassword) + + except_device_id = requester.device_id if requester else None + except_access_token_id = requester.access_token_id if requester else None + + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e + + # we want to log out all of the user's other sessions. First delete + # all his other devices. + yield self._device_handler.delete_all_devices_for_user( + user_id, except_device_id=except_device_id, + ) + + # and now delete any access tokens which weren't associated with + # devices (or were associated with this device). + yield self._auth_handler.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id, + ) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index b5be5d9623..714f0195c8 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.storage.roommember import ProfileInfo from synapse.util.metrics import Measure from synapse.util.async import sleep +from synapse.types import get_localpart_from_id logger = logging.getLogger(__name__) -class UserDirectoyHandler(object): +class UserDirectoryHandler(object): """Handles querying of and keeping updated the user_directory. N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY @@ -41,9 +42,10 @@ class UserDirectoyHandler(object): one public room. """ - INITIAL_SLEEP_MS = 50 - INITIAL_SLEEP_COUNT = 100 - INITIAL_BATCH_SIZE = 100 + INITIAL_ROOM_SLEEP_MS = 50 + INITIAL_ROOM_SLEEP_COUNT = 100 + INITIAL_ROOM_BATCH_SIZE = 100 + INITIAL_USER_SLEEP_MS = 10 def __init__(self, hs): self.store = hs.get_datastore() @@ -53,6 +55,7 @@ class UserDirectoyHandler(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory + self.search_all_users = hs.config.user_directory_search_all_users # When start up for the first time we need to populate the user_directory. # This is a set of user_id's we've inserted already @@ -111,6 +114,15 @@ class UserDirectoyHandler(object): self._is_processing = False @defer.inlineCallbacks + def handle_local_profile_change(self, user_id, profile): + """Called to update index of our local user profiles when they change + irrespective of any rooms the user may be in. + """ + yield self.store.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url, None, + ) + + @defer.inlineCallbacks def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: @@ -148,16 +160,30 @@ class UserDirectoyHandler(object): room_ids = yield self.store.get_all_rooms() logger.info("Doing initial update of user directory. %d rooms", len(room_ids)) - num_processed_rooms = 1 + num_processed_rooms = 0 for room_id in room_ids: - logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids)) + logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids)) yield self._handle_initial_room(room_id) num_processed_rooms += 1 - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) logger.info("Processed all rooms.") + if self.search_all_users: + num_processed_users = 0 + user_ids = yield self.store.get_all_local_users() + logger.info("Doing initial update of user directory. %d users", len(user_ids)) + for user_id in user_ids: + # We add profiles for all users even if they don't match the + # include pattern, just in case we want to change it in future + logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids)) + yield self._handle_local_user(user_id) + num_processed_users += 1 + yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.) + + logger.info("Processed all users") + self.initially_handled_users = None self.initially_handled_users_in_public = None self.initially_handled_users_share = None @@ -201,8 +227,8 @@ class UserDirectoyHandler(object): to_update = set() count = 0 for user_id in user_ids: - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) if not self.is_mine_id(user_id): count += 1 @@ -216,8 +242,8 @@ class UserDirectoyHandler(object): if user_id == other_user_id: continue - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) count += 1 user_set = (user_id, other_user_id) @@ -237,13 +263,13 @@ class UserDirectoyHandler(object): else: self.initially_handled_users_share_private_room.add(user_set) - if len(to_insert) > self.INITIAL_BATCH_SIZE: + if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.add_users_who_share_room( room_id, not is_public, to_insert, ) to_insert.clear() - if len(to_update) > self.INITIAL_BATCH_SIZE: + if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.update_users_who_share_room( room_id, not is_public, to_update, ) @@ -385,14 +411,28 @@ class UserDirectoyHandler(object): yield self._handle_remove_user(room_id, user_id) @defer.inlineCallbacks + def _handle_local_user(self, user_id): + """Adds a new local roomless user into the user_directory_search table. + Used to populate up the user index when we have an + user_directory_search_all_users specified. + """ + logger.debug("Adding new local user to dir, %r", user_id) + + profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) + + row = yield self.store.get_user_in_directory(user_id) + if not row: + yield self.store.add_profiles_to_user_dir(None, {user_id: profile}) + + @defer.inlineCallbacks def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: - room_id (str): room_id that user joined or started being public that + room_id (str): room_id that user joined or started being public user_id (str) """ - logger.debug("Adding user to dir, %r", user_id) + logger.debug("Adding new user to dir, %r", user_id) row = yield self.store.get_user_in_directory(user_id) if not row: @@ -407,7 +447,7 @@ class UserDirectoyHandler(object): if not row: yield self.store.add_users_to_public_room(room_id, [user_id]) else: - logger.debug("Not adding user to public dir, %r", user_id) + logger.debug("Not adding new user to public dir, %r", user_id) # Now we update users who share rooms with users. We do this by getting # all the current users in the room and seeing which aren't already diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index a97532162f..e2b99ef3bd 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -362,8 +362,10 @@ def _get_hosts_for_srv_record(dns_client, host): return res # no logcontexts here, so we can safely fire these off and gatherResults - d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb) - d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb) + d1 = dns_client.lookupAddress(host).addCallbacks( + cb, eb, errbackArgs=("A", )) + d2 = dns_client.lookupIPV6Address(host).addCallbacks( + cb, eb, errbackArgs=("AAAA", )) results = yield defer.DeferredList( [d1, d2], consumeErrors=True) diff --git a/synapse/http/server.py b/synapse/http/server.py index 3ca1c9947c..25466cd292 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -28,6 +28,7 @@ from canonicaljson import ( ) from twisted.internet import defer +from twisted.python import failure from twisted.web import server, resource from twisted.web.server import NOT_DONE_YET from twisted.web.util import redirectTo @@ -131,12 +132,17 @@ def wrap_request_handler(request_handler, include_metrics=False): version_string=self.version_string, ) except Exception: - logger.exception( - "Failed handle request %s.%s on %r: %r", + # failure.Failure() fishes the original Failure out + # of our stack, and thus gives us a sensible stack + # trace. + f = failure.Failure() + logger.error( + "Failed handle request %s.%s on %r: %r: %s", request_handler.__module__, request_handler.__name__, self, - request + request, + f.getTraceback().rstrip(), ) respond_with_json( request, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index dc680ddf43..097c844d31 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.types import UserID @@ -81,6 +82,7 @@ class ModuleApi(object): reg = self.hs.get_handlers().registration_handler return reg.register(localpart=localpart) + @defer.inlineCallbacks def invalidate_access_token(self, access_token): """Invalidate an access token for a user @@ -94,8 +96,16 @@ class ModuleApi(object): Raises: synapse.api.errors.AuthError: the access token is invalid """ - - return self._auth_handler.delete_access_token(access_token) + # see if the access token corresponds to a device + user_info = yield self._auth.get_user_by_access_token(access_token) + device_id = user_info.get("device_id") + user_id = user_info["user"].to_string() + if device_id: + # delete the device, which will also delete its access tokens + yield self.hs.get_device_handler().delete_device(user_id, device_id) + else: + # no associated device. Just delete the access token. + yield self._auth_handler.delete_access_token(access_token) def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection diff --git a/synapse/notifier.py b/synapse/notifier.py index 626da778cd..ef042681bc 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -255,9 +255,7 @@ class Notifier(object): ) if self.federation_sender: - preserve_fn(self.federation_sender.notify_new_events)( - room_stream_id - ) + self.federation_sender.notify_new_events(room_stream_id) if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -297,8 +295,7 @@ class Notifier(object): def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" - with PreserveLoggingContext(): - self.notify_replication() + self.notify_replication() @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, @@ -516,8 +513,14 @@ class Notifier(object): self.replication_deferred = ObservableDeferred(defer.Deferred()) deferred.callback(None) - for cb in self.replication_callbacks: - preserve_fn(cb)() + # the callbacks may well outlast the current request, so we run + # them in the sentinel logcontext. + # + # (ideally it would be up to the callbacks to know if they were + # starting off background processes and drop the logcontext + # accordingly, but that requires more changes) + for cb in self.replication_callbacks: + cb() @defer.inlineCallbacks def wait_for_replication(self, callback, timeout): diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 34cb108dcb..134e89b371 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -103,19 +103,25 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id, except_access_token_id=None): - all = yield self.store.get_all_pushers() - logger.info( - "Removing all pushers for user %s except access tokens id %r", - user_id, except_access_token_id - ) - for p in all: - if p['user_name'] == user_id and p['access_token'] != except_access_token_id: + def remove_pushers_by_access_token(self, user_id, access_tokens): + """Remove the pushers for a given user corresponding to a set of + access_tokens. + + Args: + user_id (str): user to remove pushers for + access_tokens (Iterable[int]): access token *ids* to remove pushers + for + """ + tokens = set(access_tokens) + for p in (yield self.store.get_pushers_by_user_id(user_id)): + if p['access_token'] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] ) - yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher( + p['app_id'], p['pushkey'], p['user_name'], + ) @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 1d03e79b85..786c3fe864 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -216,11 +216,12 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") + @defer.inlineCallbacks def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - self.presence_handler.update_external_syncs_row( + yield self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms, ) @@ -244,11 +245,12 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") + @defer.inlineCallbacks def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): """The client saw a user request """ user_ip_cache_counter.inc() - self.store.insert_client_ip( + yield self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen, ) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 1197158fdc..5022808ea9 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): - self._auth_handler = hs.get_auth_handler() super(DeactivateAccountRestServlet, self).__init__(hs) + self._deactivate_account_handler = hs.get_deactivate_account_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): @@ -149,7 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - yield self._auth_handler.deactivate_account(target_user_id) + yield self._deactivate_account_handler.deactivate_account(target_user_id) defer.returnValue((200, {})) @@ -309,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): super(ResetPasswordRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() + self._set_password_handler = hs.get_set_password_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): @@ -330,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): logger.info("new_password: %r", new_password) - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( target_user_id, new_password, requester ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 6add754782..ca49955935 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.api.auth import get_access_token_from_request +from synapse.api.errors import AuthError from .base import ClientV1RestServlet, client_path_patterns @@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutRestServlet, self).__init__(hs) + self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @defer.inlineCallbacks def on_POST(self, request): - access_token = get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) + try: + requester = yield self.auth.get_user_by_req(request) + except AuthError: + # this implies the access token has already been deleted. + pass + else: + if requester.device_id is None: + # the acccess token wasn't associated with a device. + # Just delete the access token + access_token = get_access_token_from_request(request) + yield self._auth_handler.delete_access_token(access_token) + else: + yield self._device_handler.delete_device( + requester.user.to_string(), requester.device_id) + defer.returnValue((200, {})) @@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet): super(LogoutAllRestServlet, self).__init__(hs) self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() + + # first delete all of the user's devices + yield self._device_handler.delete_all_devices_for_user(user_id) + + # .. and then delete any access tokens which weren't associated with + # devices. yield self._auth_handler.delete_access_tokens_for_user(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 1f5bc24cc3..77434937ff 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -15,12 +15,13 @@ """This module contains base REST classes for constructing client v1 servlets. """ - -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import logging import re -import logging +from twisted.internet import defer +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX logger = logging.getLogger(__name__) @@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): filter_json['room']['timeline']["limit"] = min( filter_json['room']['timeline']['limit'], filter_timeline_limit) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns a deferred (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. + + Normal usage is: + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + # ... + yield self.auth_handler.check_auth + """ + def wrapped(*args, **kwargs): + res = defer.maybeDeferred(orig, *args, **kwargs) + res.addErrback(_catch_incomplete_interactive_auth) + return res + return wrapped + + +def _catch_incomplete_interactive_auth(f): + """helper for interactive_auth_handler + + Catches InteractiveAuthIncompleteErrors and turns them into 401 responses + + Args: + f (failure.Failure): + """ + f.trap(InteractiveAuthIncompleteError) + return 401, f.value.result diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 726e0a2826..385a3ad2ec 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -19,14 +19,14 @@ from twisted.internet import defer from synapse.api.auth import has_access_token from synapse.api.constants import LoginType -from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_request, parse_json_object_from_request, ) from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -98,56 +98,61 @@ class PasswordRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() + self._set_password_handler = hs.get_set_password_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - [LoginType.EMAIL_IDENTITY], - [LoginType.MSISDN], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + # there are two possibilities here. Either the user does not have an + # access token, and needs to do a password reset; or they have one and + # need to validate their identity. + # + # In the first case, we offer a couple of means of identifying + # themselves (email and msisdn, though it's unclear if msisdn actually + # works). + # + # In the second case, we require a password to confirm their identity. - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in + if has_access_token(request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - elif LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() - # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + params = yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id + user_id = requester.user.to_string() else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = None + result, params, _ = yield self.auth_handler.check_auth( + [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], + body, self.hs.get_ip_from_request(request), + ) + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if 'medium' not in threepid or 'address' not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addreses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + threepid['address'] = threepid['address'].lower() + # if using email, we must know about the email they're authing with! + threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid['medium'], threepid['address'] + ) + if not threepid_user_id: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) if 'new_password' not in params: raise SynapseError(400, "", Codes.MISSING_PARAM) new_password = params['new_password'] - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( user_id, new_password, requester ) @@ -161,52 +166,32 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/deactivate$") def __init__(self, hs): + super(DeactivateAccountRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - super(DeactivateAccountRestServlet, self).__init__() + self._deactivate_account_handler = hs.get_deactivate_account_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - # if the caller provides an access token, it ought to be valid. - requester = None - if has_access_token(request): - requester = yield self.auth.get_user_by_req( - request, - ) # type: synapse.types.Requester + requester = yield self.auth.get_user_by_req(request) # allow ASes to dectivate their own users - if requester and requester.app_service: - yield self.auth_handler.deactivate_account( + if requester.app_service: + yield self._deactivate_account_handler.deactivate_account( requester.user.to_string() ) defer.returnValue((200, {})) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - if LoginType.PASSWORD in result: - user_id = result[LoginType.PASSWORD] - # if using password, they should also be logged in - if requester is None: - raise SynapseError( - 400, - "Deactivate account requires an access_token", - errcode=Codes.MISSING_TOKEN - ) - if requester.user.to_string() != user_id: - raise LoginError(400, "", Codes.UNKNOWN) - else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) - - yield self.auth_handler.deactivate_account(user_id) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), + ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 5321e5abbb..35d58b367a 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -17,9 +17,9 @@ import logging from twisted.internet import defer -from synapse.api import constants, errors +from synapse.api import errors from synapse.http import servlet -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + try: body = servlet.parse_json_object_from_request(request) except errors.SynapseError as e: @@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet): 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM ) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_devices( requester.user.to_string(), body['devices'], @@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet): ) defer.returnValue((200, device)) + @interactive_auth_handler @defer.inlineCallbacks def on_DELETE(self, request, device_id): requester = yield self.auth.get_user_by_req(request) @@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet): else: raise - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - # check that the UI auth matched the access token - user_id = result[constants.LoginType.PASSWORD] - if user_id != requester.user.to_string(): - raise errors.AuthError(403, "Invalid auth") + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - yield self.device_handler.delete_device(user_id, device_id) + yield self.device_handler.delete_device( + requester.user.to_string(), device_id, + ) defer.returnValue((200, {})) @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 089ec71c81..f762dbfa9a 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -38,7 +38,7 @@ class GroupServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() group_description = yield self.groups_handler.get_group_profile( @@ -74,7 +74,7 @@ class GroupSummaryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() get_group_summary = yield self.groups_handler.get_group_summary( @@ -148,7 +148,7 @@ class GroupCategoryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_category( @@ -200,7 +200,7 @@ class GroupCategoriesServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_categories( @@ -225,7 +225,7 @@ class GroupRoleServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_role( @@ -277,7 +277,7 @@ class GroupRolesServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_roles( @@ -348,7 +348,7 @@ class GroupRoomServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) @@ -369,7 +369,7 @@ class GroupUsersServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) @@ -672,7 +672,7 @@ class PublicisedGroupsForUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) result = yield self.groups_handler.get_publicised_groups_for_user( user_id @@ -697,7 +697,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) user_ids = content["user_ids"] @@ -724,7 +724,7 @@ class GroupsForUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() result = yield self.groups_handler.get_joined_groups(requester_user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 9e2f7308ce..e9d88a8895 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -27,7 +27,7 @@ from synapse.http.servlet import ( ) from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler import logging import hmac @@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() @@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet): [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], ]) - authed, auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) - if not authed: - defer.returnValue((401, auth_result)) - return - if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 723f7043f4..40d2e664eb 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -25,7 +25,8 @@ from synapse.util.stringutils import random_string from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient from synapse.http.server import ( - request_handler, respond_with_json_bytes + request_handler, respond_with_json_bytes, + respond_with_json, ) from synapse.util.async import ObservableDeferred from synapse.util.stringutils import is_ascii @@ -78,6 +79,9 @@ class PreviewUrlResource(Resource): self._expire_url_cache_data, 10 * 1000 ) + def render_OPTIONS(self, request): + return respond_with_json(request, 200, {}, send_cors=True) + def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET @@ -348,11 +352,16 @@ class PreviewUrlResource(Resource): def _expire_url_cache_data(self): """Clean up expired url cache content, media and thumbnails. """ - # TODO: Delete from backup media store now = self.clock.time_msec() + logger.info("Running url preview cache expiry") + + if not (yield self.store.has_completed_background_updates()): + logger.info("Still running DB updates; skipping expiry") + return + # First we delete expired url cache entries media_ids = yield self.store.get_expired_url_cache(now) @@ -426,8 +435,7 @@ class PreviewUrlResource(Resource): yield self.store.delete_url_cache_media(removed_media) - if removed_media: - logger.info("Deleted %d media from url cache", len(removed_media)) + logger.info("Deleted %d media from url cache", len(removed_media)) def decode_and_calc_og(body, media_uri, request_encoding=None): diff --git a/synapse/server.py b/synapse/server.py index 10e3e9a4f1..99693071b6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -39,18 +39,20 @@ from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGeneartor +from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.room_list import RoomListHandler +from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.read_marker import ReadMarkerHandler -from synapse.handlers.user_directory import UserDirectoyHandler +from synapse.handlers.user_directory import UserDirectoryHandler from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.profile import ProfileHandler from synapse.groups.groups_server import GroupsServerHandler @@ -60,7 +62,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool -from synapse.rest.media.v1.media_repository import MediaRepository +from synapse.rest.media.v1.media_repository import ( + MediaRepository, + MediaRepositoryResource, +) from synapse.state import StateHandler from synapse.storage import DataStore from synapse.streams.events import EventSources @@ -90,17 +95,12 @@ class HomeServer(object): """ DEPENDENCIES = [ - 'config', - 'clock', 'http_client', 'db_pool', - 'persistence_service', 'replication_layer', - 'datastore', 'handlers', 'v1auth', 'auth', - 'rest_servlet_factory', 'state_handler', 'presence_handler', 'sync_handler', @@ -117,19 +117,10 @@ class HomeServer(object): 'application_service_handler', 'device_message_handler', 'profile_handler', + 'deactivate_account_handler', + 'set_password_handler', 'notifier', - 'distributor', - 'client_resource', - 'resource_for_federation', - 'resource_for_static_content', - 'resource_for_web_client', - 'resource_for_content_repo', - 'resource_for_server_key', - 'resource_for_server_key_v2', - 'resource_for_media_repository', - 'resource_for_metrics', 'event_sources', - 'ratelimiter', 'keyring', 'pusherpool', 'event_builder_factory', @@ -137,6 +128,7 @@ class HomeServer(object): 'http_client_context_factory', 'simple_http_client', 'media_repository', + 'media_repository_resource', 'federation_transport_client', 'federation_sender', 'receipts_handler', @@ -183,6 +175,21 @@ class HomeServer(object): def is_mine_id(self, string): return string.split(":", 1)[1] == self.hostname + def get_clock(self): + return self.clock + + def get_datastore(self): + return self.datastore + + def get_config(self): + return self.config + + def get_distributor(self): + return self.distributor + + def get_ratelimiter(self): + return self.ratelimiter + def build_replication_layer(self): return initialize_http_replication(self) @@ -265,6 +272,12 @@ class HomeServer(object): def build_profile_handler(self): return ProfileHandler(self) + def build_deactivate_account_handler(self): + return DeactivateAccountHandler(self) + + def build_set_password_handler(self): + return SetPasswordHandler(self) + def build_event_sources(self): return EventSources(self) @@ -294,6 +307,11 @@ class HomeServer(object): **self.db_config.get("args", {}) ) + def build_media_repository_resource(self): + # build the media repo resource. This indirects through the HomeServer + # to ensure that we only have a single instance of + return MediaRepositoryResource(self) + def build_media_repository(self): return MediaRepository(self) @@ -321,7 +339,7 @@ class HomeServer(object): return ActionGenerator(self) def build_user_directory_handler(self): - return UserDirectoyHandler(self) + return UserDirectoryHandler(self) def build_groups_local_handler(self): return GroupsLocalHandler(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index e8c0386b7f..41416ef252 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -3,10 +3,14 @@ import synapse.federation.transaction_queue import synapse.federation.transport.client import synapse.handlers import synapse.handlers.auth +import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys -import synapse.storage +import synapse.handlers.set_password +import synapse.rest.media.v1.media_repository import synapse.state +import synapse.storage + class HomeServer(object): def get_auth(self) -> synapse.api.auth.Auth: @@ -30,8 +34,20 @@ class HomeServer(object): def get_state_handler(self) -> synapse.state.StateHandler: pass + def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: + pass + + def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler: + pass + def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue: pass def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: pass + + def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: + pass + + def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository: + pass diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7ebd4f189d..b971f0cb18 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -469,23 +469,53 @@ class SQLBaseStore(object): txn.executemany(sql, vals) + @defer.inlineCallbacks def _simple_upsert(self, table, keyvalues, values, insertion_values={}, desc="_simple_upsert", lock=True): """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + Args: table (str): The table to upsert into keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values - insertion_values (dict): key/values to use when inserting + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. Returns: Deferred(bool): True if a new entry was created, False if an existing one was updated. """ - return self.runInteraction( - desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock - ) + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self._simple_upsert_txn, table, keyvalues, values, insertion_values, + lock=lock + ) + defer.returnValue(result) + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warn( + "IntegrityError when upserting into %s; retrying: %s", + table, e + ) def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, lock=True): @@ -493,7 +523,7 @@ class SQLBaseStore(object): if lock: self.database_engine.lock_table(txn, table) - # Try to update + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), @@ -502,28 +532,29 @@ class SQLBaseStore(object): sqlargs = values.values() + keyvalues.values() txn.execute(sql, sqlargs) - if txn.rowcount == 0: - # We didn't update and rows so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) + if txn.rowcount > 0: + # successfully updated at least one row. + return False - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) - ) - txn.execute(sql, allvalues.values()) + # We didn't update any rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) - return True - else: - return False + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues) + ) + txn.execute(sql, allvalues.values()) + # successfully inserted + return True def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. + return a single row, returning multiple columns from it. Args: table : string giving the table name @@ -576,20 +607,18 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) - else: - where = "" - sql = ( - "SELECT %(retcol)s FROM %(table)s %(where)s" + "SELECT %(retcol)s FROM %(table)s" ) % { "retcol": retcol, "table": table, - "where": where, } - txn.execute(sql, keyvalues.values()) + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + txn.execute(sql, keyvalues.values()) + else: + txn.execute(sql) return [r[0] for r in txn] @@ -600,7 +629,7 @@ class SQLBaseStore(object): Args: table (str): table name - keyvalues (dict): column names and values to select the rows with + keyvalues (dict|None): column names and values to select the rows with retcol (str): column whos value we wish to retrieve. Returns: diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index c8a1eb016b..56a0bde549 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -222,9 +222,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so _simple_upsert will + # retry if there is a conflict. + yield self._simple_upsert( + desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, @@ -234,19 +237,20 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } - ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, - user_id, next_id, + }, + lock=False, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - self._update_max_stream_id(txn, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_room_account_data", add_account_data_txn, next_id - ) + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @@ -263,9 +267,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so _simple_upsert will retry if + # there is a conflict. + yield self._simple_upsert( + desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, @@ -274,40 +281,46 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } + }, + lock=False, ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed( user_id, next_id, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_global_account_data_by_type_for_user.invalidate, + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id,) ) - self._update_max_stream_id(txn, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_user_account_data", add_account_data_txn, next_id - ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) - def _update_max_stream_id(self, txn, next_id): + def _update_max_stream_id(self, next_id): """Update the max stream_id Args: - txn: The database cursor next_id(int): The the revision to advance to. """ - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + return self.runInteraction( + "update_account_data_max_stream_id", + _update, ) - txn.execute(update_max_id_sql, (next_id, next_id)) @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 6f235ac051..11a1b942f1 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -85,6 +85,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} + self._all_done = False @defer.inlineCallbacks def start_doing_background_updates(self): @@ -106,9 +107,41 @@ class BackgroundUpdateStore(SQLBaseStore): "No more background updates to do." " Unscheduling background update task." ) + self._all_done = True defer.returnValue(None) @defer.inlineCallbacks + def has_completed_background_updates(self): + """Check if all the background updates have completed + + Returns: + Deferred[bool]: True if all background updates have completed + """ + # if we've previously determined that there is nothing left to do, that + # is easy + if self._all_done: + defer.returnValue(True) + + # obviously, if we have things in our queue, we're not done. + if self._background_update_queue: + defer.returnValue(False) + + # otherwise, check if there are updates to be run. This is important, + # as we may be running on a worker which doesn't perform the bg updates + # itself, but still wants to wait for them to happen. + updates = yield self._simple_select_onecol( + "background_updates", + keyvalues=None, + retcol="1", + desc="check_background_updates", + ) + if not updates: + self._all_done = True + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks def do_next_background_update(self, desired_duration_ms): """Does some amount of work on the next queued background update @@ -269,7 +302,7 @@ class BackgroundUpdateStore(SQLBaseStore): # Sqlite doesn't support concurrent creation of indexes. # # We don't use partial indices on SQLite as it wasn't introduced - # until 3.8, and wheezy has 3.7 + # until 3.8, and wheezy and CentOS 7 have 3.7 # # We assume that sqlite doesn't give us invalid indices; however # we may still end up with the index existing but the diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 52e5cdad70..a66ff7c1e0 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -12,13 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.background_updates import BackgroundUpdateStore -from ._base import SQLBaseStore - -class MediaRepositoryStore(SQLBaseStore): +class MediaRepositoryStore(BackgroundUpdateStore): """Persistence for attachments and avatars""" + def __init__(self, db_conn, hs): + super(MediaRepositoryStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + update_name='local_media_repository_url_idx', + index_name='local_media_repository_url_idx', + table='local_media_repository', + columns=['created_ts'], + where_clause='url_cache IS NOT NULL', + ) + def get_default_thumbnails(self, top_level_type, sub_type): return [] diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index beea3102fc..ec02e73bc2 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -15,6 +15,9 @@ from twisted.internet import defer +from synapse.storage.roommember import ProfileInfo +from synapse.api.errors import StoreError + from ._base import SQLBaseStore @@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore): desc="create_profile", ) + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self._simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + defer.returnValue(ProfileInfo(None, None)) + return + else: + raise + + defer.returnValue( + ProfileInfo( + avatar_url=profile['avatar_url'], + display_name=profile['displayname'], + ) + ) + def get_profile_displayname(self, user_localpart): return self._simple_select_one_onecol( table="profiles", diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 34d2f82b7f..3d8b4d5d5b 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -204,34 +204,35 @@ class PusherStore(SQLBaseStore): pushkey, pushkey_ts, lang, data, last_stream_ordering, profile_tag=""): with self._pushers_id_gen.get_next() as stream_id: - def f(txn): - newly_inserted = self._simple_upsert_txn( - txn, - "pushers", - { - "app_id": app_id, - "pushkey": pushkey, - "user_name": user_id, - }, - { - "access_token": access_token, - "kind": kind, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "ts": pushkey_ts, - "lang": lang, - "data": encode_canonical_json(data), - "last_stream_ordering": last_stream_ordering, - "profile_tag": profile_tag, - "id": stream_id, - }, - ) - if newly_inserted: - # get_if_user_has_pusher only cares if the user has - # at least *one* pusher. - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + # no need to lock because `pushers` has a unique key on + # (app_id, pushkey, user_name) so _simple_upsert will retry + newly_inserted = yield self._simple_upsert( + table="pushers", + keyvalues={ + "app_id": app_id, + "pushkey": pushkey, + "user_name": user_id, + }, + values={ + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": encode_canonical_json(data), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + desc="add_pusher", + lock=False, + ) - yield self.runInteraction("add_pusher", f) + if newly_inserted: + # get_if_user_has_pusher only cares if the user has + # at least *one* pusher. + self.get_if_user_has_pusher.invalidate(user_id,) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): @@ -243,11 +244,19 @@ class PusherStore(SQLBaseStore): "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} ) - self._simple_upsert_txn( + + # it's possible for us to end up with duplicate rows for + # (app_id, pushkey, user_id) at different stream_ids, but that + # doesn't really matter. + self._simple_insert_txn( txn, - "deleted_pushers", - {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, - {"stream_id": stream_id}, + table="deleted_pushers", + values={ + "stream_id": stream_id, + "app_id": app_id, + "pushkey": pushkey, + "user_id": user_id, + }, ) with self._pushers_id_gen.get_next() as stream_id: @@ -310,9 +319,12 @@ class PusherStore(SQLBaseStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so _simple_upsert will retry yield self._simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, - desc="set_throttle_params" + desc="set_throttle_params", + lock=False, ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 8b9544c209..3aa810981f 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): If None, tokens associated with any device (or no device) will be deleted Returns: - defer.Deferred[list[str, str|None]]: a list of the deleted tokens - and device IDs + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens """ def f(txn): keyvalues = { @@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values.append(except_token_id) txn.execute( - "SELECT token, device_id FROM access_tokens WHERE %s" % where_clause, + "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, values ) - tokens_and_devices = [(r[0], r[1]) for r in txn] + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - for token, _ in tokens_and_devices: + for token, _, _ in tokens_and_devices: self._invalidate_cache_and_stream( txn, self.get_user_by_access_token, (token,) ) diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql index e2b775f038..b12f9b2ebf 100644 --- a/synapse/storage/schema/delta/44/expire_url_cache.sql +++ b/synapse/storage/schema/delta/44/expire_url_cache.sql @@ -13,7 +13,10 @@ * limitations under the License. */ -CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; +-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was +-- removed and replaced with 46/local_media_repository_url_idx.sql. +-- +-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; -- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support -- indices on expressions until 3.9. diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql new file mode 100644 index 0000000000..bb307889c1 --- /dev/null +++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop the unique constraint on deleted_pushers so that we can just insert +-- into it rather than upserting. + +CREATE TABLE deleted_pushers2 ( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL +); + +INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) + SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; + +DROP TABLE deleted_pushers; +ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; + +-- create the index after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); + diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql new file mode 100644 index 0000000000..bbfc7f5d1a --- /dev/null +++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will recreate the +-- local_media_repository_url_idx index. +-- +-- We do this as a bg update not because it is a particularly onerous +-- operation, but because we'd like it to be a partial index if possible, and +-- the background_index_update code will understand whether we are on +-- postgres or sqlite and behave accordingly. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 5dc5b9582a..c9bff408ef 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore): ) if isinstance(self.database_engine, PostgresEngine): - # We weight the loclpart most highly, then display name and finally + # We weight the localpart most highly, then display name and finally # server name if new_entry: sql = """ @@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore): rows = yield self._execute("get_all_rooms", None, sql) defer.returnValue([room_id for room_id, in rows]) + @defer.inlineCallbacks + def get_all_local_users(self): + """Get all local users + """ + sql = """ + SELECT name FROM users + """ + rows = yield self._execute("get_all_local_users", None, sql) + defer.returnValue([name for name, in rows]) + def add_users_who_share_room(self, room_id, share_private, user_id_tuples): """Insert entries into the users_who_share_rooms table. The first user should be a local user. @@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore): ] } """ + + if self.hs.config.user_directory_search_all_users: + join_clause = "" + where_clause = "?<>''" # naughty hack to keep the same number of binds + else: + join_clause = """ + LEFT JOIN users_in_public_rooms AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + """ + where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)" + if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) @@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore): SELECT d.user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_public_rooms AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND vector @@ to_tsquery('english', ?) ORDER BY (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) @@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore): display_name IS NULL, avatar_url IS NULL LIMIT ? - """ + """ % (join_clause, where_clause) args = (user_id, full_query, exact_query, prefix_query, limit + 1,) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) @@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore): SELECT d.user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_public_rooms AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? - """ + """ % (join_clause, where_clause) args = (user_id, search_query, limit + 1) else: # This should be unreachable. @@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term): # Pull out the individual words, discarding any non-word characters. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - return " & ".join("(%s* | %s)" % (result, result,) for result in results) + return " & ".join("(%s* OR %s)" % (result, result,) for result in results) def _parse_query_postgres(search_term): |