diff options
50 files changed, 1377 insertions, 571 deletions
diff --git a/README.rst b/README.rst index 9da8c7f7a8..6f146b63b1 100644 --- a/README.rst +++ b/README.rst @@ -823,7 +823,9 @@ spidering 'internal' URLs on your network. At the very least we recommend that your loopback and RFC1918 IP addresses are blacklisted. This also requires the optional lxml and netaddr python dependencies to be -installed. +installed. This in turn requires the libxml2 library to be available - on +Debian/Ubuntu this means ``apt-get install libxml2-dev``, or equivalent for +your OS. Password reset diff --git a/docs/code_style.rst b/docs/code_style.rst index 8d73d17beb..9c52cb3182 100644 --- a/docs/code_style.rst +++ b/docs/code_style.rst @@ -1,52 +1,119 @@ -Basically, PEP8 +- Everything should comply with PEP8. Code should pass + ``pep8 --max-line-length=100`` without any warnings. -- NEVER tabs. 4 spaces to indent. -- Max line width: 79 chars (with flexibility to overflow by a "few chars" if +- **Indenting**: + + - NEVER tabs. 4 spaces to indent. + + - follow PEP8; either hanging indent or multiline-visual indent depending + on the size and shape of the arguments and what makes more sense to the + author. In other words, both this:: + + print("I am a fish %s" % "moo") + + and this:: + + print("I am a fish %s" % + "moo") + + and this:: + + print( + "I am a fish %s" % + "moo", + ) + + ...are valid, although given each one takes up 2x more vertical space than + the previous, it's up to the author's discretion as to which layout makes + most sense for their function invocation. (e.g. if they want to add + comments per-argument, or put expressions in the arguments, or group + related arguments together, or want to deliberately extend or preserve + vertical/horizontal space) + +- **Line length**: + + Max line length is 79 chars (with flexibility to overflow by a "few chars" if the overflowing content is not semantically significant and avoids an explosion of vertical whitespace). -- Use camel case for class and type names -- Use underscores for functions and variables. -- Use double quotes. -- Use parentheses instead of '\\' for line continuation where ever possible - (which is pretty much everywhere) -- There should be max a single new line between: + + Use parentheses instead of ``\`` for line continuation where ever possible + (which is pretty much everywhere). + +- **Naming**: + + - Use camel case for class and type names + - Use underscores for functions and variables. + +- Use double quotes ``"foo"`` rather than single quotes ``'foo'``. + +- **Blank lines**: + + - There should be max a single new line between: + - statements - functions in a class -- There should be two new lines between: + + - There should be two new lines between: + - definitions in a module (e.g., between different classes) -- There should be spaces where spaces should be and not where there shouldn't be: - - a single space after a comma - - a single space before and after for '=' when used as assignment - - no spaces before and after for '=' for default values and keyword arguments. -- Indenting must follow PEP8; either hanging indent or multiline-visual indent - depending on the size and shape of the arguments and what makes more sense to - the author. In other words, both this:: - - print("I am a fish %s" % "moo") - - and this:: - - print("I am a fish %s" % - "moo") - - and this:: - - print( - "I am a fish %s" % - "moo" - ) - - ...are valid, although given each one takes up 2x more vertical space than - the previous, it's up to the author's discretion as to which layout makes most - sense for their function invocation. (e.g. if they want to add comments - per-argument, or put expressions in the arguments, or group related arguments - together, or want to deliberately extend or preserve vertical/horizontal - space) - -Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_. -This is so that we can generate documentation with -`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the -`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_ -in the sphinx documentation. - -Code should pass pep8 --max-line-length=100 without any warnings. + +- **Whitespace**: + + There should be spaces where spaces should be and not where there shouldn't + be: + + - a single space after a comma + - a single space before and after for '=' when used as assignment + - no spaces before and after for '=' for default values and keyword arguments. + +- **Comments**: should follow the `google code style + <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_. + This is so that we can generate documentation with `sphinx + <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the + `examples + <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_ + in the sphinx documentation. + +- **Imports**: + + - Prefer to import classes and functions than packages or modules. + + Example:: + + from synapse.types import UserID + ... + user_id = UserID(local, server) + + is preferred over:: + + from synapse import types + ... + user_id = types.UserID(local, server) + + (or any other variant). + + This goes against the advice in the Google style guide, but it means that + errors in the name are caught early (at import time). + + - Multiple imports from the same package can be combined onto one line:: + + from synapse.types import GroupID, RoomID, UserID + + An effort should be made to keep the individual imports in alphabetical + order. + + If the list becomes long, wrap it with parentheses and split it over + multiple lines. + + - As per `PEP-8 <https://www.python.org/dev/peps/pep-0008/#imports>`_, + imports should be grouped in the following order, with a blank line between + each group: + + 1. standard library imports + 2. related third party imports + 3. local application/library specific imports + + - Imports within each group should be sorted alphabetically by module name. + + - Avoid wildcard imports (``from synapse.types import *``) and relative + imports (``from .types import UserID``). diff --git a/docs/password_auth_providers.rst b/docs/password_auth_providers.rst new file mode 100644 index 0000000000..d8a7b61cdc --- /dev/null +++ b/docs/password_auth_providers.rst @@ -0,0 +1,99 @@ +Password auth provider modules +============================== + +Password auth providers offer a way for server administrators to integrate +their Synapse installation with an existing authentication system. + +A password auth provider is a Python class which is dynamically loaded into +Synapse, and provides a number of methods by which it can integrate with the +authentication system. + +This document serves as a reference for those looking to implement their own +password auth providers. + +Required methods +---------------- + +Password auth provider classes must provide the following methods: + +*class* ``SomeProvider.parse_config``\(*config*) + + This method is passed the ``config`` object for this module from the + homeserver configuration file. + + It should perform any appropriate sanity checks on the provided + configuration, and return an object which is then passed into ``__init__``. + +*class* ``SomeProvider``\(*config*, *account_handler*) + + The constructor is passed the config object returned by ``parse_config``, + and a ``synapse.module_api.ModuleApi`` object which allows the + password provider to check if accounts exist and/or create new ones. + +Optional methods +---------------- + +Password auth provider classes may optionally provide the following methods. + +*class* ``SomeProvider.get_db_schema_files``\() + + This method, if implemented, should return an Iterable of ``(name, + stream)`` pairs of database schema files. Each file is applied in turn at + initialisation, and a record is then made in the database so that it is + not re-applied on the next start. + +``someprovider.get_supported_login_types``\() + + This method, if implemented, should return a ``dict`` mapping from a login + type identifier (such as ``m.login.password``) to an iterable giving the + fields which must be provided by the user in the submission to the + ``/login`` api. These fields are passed in the ``login_dict`` dictionary + to ``check_auth``. + + For example, if a password auth provider wants to implement a custom login + type of ``com.example.custom_login``, where the client is expected to pass + the fields ``secret1`` and ``secret2``, the provider should implement this + method and return the following dict:: + + {"com.example.custom_login": ("secret1", "secret2")} + +``someprovider.check_auth``\(*username*, *login_type*, *login_dict*) + + This method is the one that does the real work. If implemented, it will be + called for each login attempt where the login type matches one of the keys + returned by ``get_supported_login_types``. + + It is passed the (possibly UNqualified) ``user`` provided by the client, + the login type, and a dictionary of login secrets passed by the client. + + The method should return a Twisted ``Deferred`` object, which resolves to + the canonical ``@localpart:domain`` user id if authentication is successful, + and ``None`` if not. + + Alternatively, the ``Deferred`` can resolve to a ``(str, func)`` tuple, in + which case the second field is a callback which will be called with the + result from the ``/login`` call (including ``access_token``, ``device_id``, + etc.) + +``someprovider.check_password``\(*user_id*, *password*) + + This method provides a simpler interface than ``get_supported_login_types`` + and ``check_auth`` for password auth providers that just want to provide a + mechanism for validating ``m.login.password`` logins. + + Iif implemented, it will be called to check logins with an + ``m.login.password`` login type. It is passed a qualified + ``@localpart:domain`` user id, and the password provided by the user. + + The method should return a Twisted ``Deferred`` object, which resolves to + ``True`` if authentication is successful, and ``False`` if not. + +``someprovider.on_logged_out``\(*user_id*, *device_id*, *access_token*) + + This method, if implemented, is called when a user logs out. It is passed + the qualified user ID, the ID of the deactivated device (if any: access + tokens are occasionally created without an associated device ID), and the + (now deactivated) access token. + + It may return a Twisted ``Deferred`` object; the logout request will wait + for the deferred to complete but the result is ignored. diff --git a/docs/url_previews.rst b/docs/url_previews.md index 634d9d907f..665554e165 100644 --- a/docs/url_previews.rst +++ b/docs/url_previews.md @@ -56,6 +56,7 @@ As a first cut, let's do #2 and have the receiver hit the API to calculate its o API --- +``` GET /_matrix/media/r0/preview_url?url=http://wherever.com 200 OK { @@ -66,6 +67,7 @@ GET /_matrix/media/r0/preview_url?url=http://wherever.com "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”" "og:site_name" : "Twitter" } +``` * Downloads the URL * If HTML, just stores it in RAM and parses it for OG meta tags diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index dc7fe940e8..3a8972efc3 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -42,6 +42,7 @@ BOOLEAN_COLUMNS = { "public_room_list_stream": ["visibility"], "device_lists_outbound_pokes": ["sent"], "users_who_share_rooms": ["share_private"], + "groups": ["is_public"], } @@ -112,6 +113,7 @@ class Store(object): _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] + _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"] def runInteraction(self, desc, func, *args, **kwargs): def r(conn): @@ -318,7 +320,7 @@ class Porter(object): backward_chunk = min(row[0] for row in brows) - 1 rows = frows + brows - self._convert_rows(table, headers, rows) + rows = self._convert_rows(table, headers, rows) def insert(txn): self.postgres_store.insert_many_txn( @@ -554,17 +556,29 @@ class Porter(object): i for i, h in enumerate(headers) if h in bool_col_names ] + class BadValueException(Exception): + pass + def conv(j, col): if j in bool_cols: return bool(col) + elif isinstance(col, basestring) and "\0" in col: + logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) + raise BadValueException(); return col + outrows = [] for i, row in enumerate(rows): - rows[i] = tuple( - conv(j, col) - for j, col in enumerate(row) - if j > 0 - ) + try: + outrows.append(tuple( + conv(j, col) + for j, col in enumerate(row) + if j > 0 + )) + except BadValueException: + pass + + return outrows @defer.inlineCallbacks def _setup_sent_transactions(self): @@ -592,7 +606,7 @@ class Porter(object): "select", r, ) - self._convert_rows("sent_transactions", headers, rows) + rows = self._convert_rows("sent_transactions", headers, rows) inserted_rows = len(rows) if inserted_rows: diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index bee4c47498..abc7ef5725 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy") class KeyUploadServlet(RestServlet): - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet): if body: # They're actually trying to upload something, proxy to main synapse. + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + headers = { + "Authorization": auth_headers, + } result = yield self.http_client.post_json_get_json( self.main_uri + request.uri, body, + headers=headers, ) defer.returnValue((200, result)) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 3adf72e141..9e26146338 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -30,6 +30,8 @@ from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.federation.transport.server import TransportLayerServer +from synapse.module_api import ModuleApi +from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite from synapse.metrics import register_memory_metrics @@ -49,6 +51,7 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole +from synapse.util.module_loader import load_module from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string from twisted.application import service @@ -107,52 +110,18 @@ class SynapseHomeServer(HomeServer): resources = {} for res in listener_config["resources"]: for name in res["names"]: - if name == "client": - client_resource = ClientRestResource(self) - if res["compress"]: - client_resource = gz_wrap(client_resource) - - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - }) - - if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) - - if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) - - if name in ["media", "federation", "client"]: - media_repo = MediaRepositoryResource(self) - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) - - if name in ["keys", "federation"]: - resources.update({ - SERVER_KEY_PREFIX: LocalKey(self), - SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), - }) - - if name == "webclient": - resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + resources.update(self._configure_named_resource( + name, res.get("compress", False), + )) - if name == "metrics" and self.get_config().enable_metrics: - resources[METRICS_PREFIX] = MetricsResource(self) + additional_resources = listener_config.get("additional_resources", {}) + logger.debug("Configuring additional resources: %r", + additional_resources) + module_api = ModuleApi(self, self.get_auth_handler()) + for path, resmodule in additional_resources.items(): + handler_cls, config = load_module(resmodule) + handler = handler_cls(config, module_api) + resources[path] = AdditionalResource(self, handler.handle_request) if WEB_CLIENT_PREFIX in resources: root_resource = RootRedirect(WEB_CLIENT_PREFIX) @@ -188,6 +157,67 @@ class SynapseHomeServer(HomeServer): ) logger.info("Synapse now listening on port %d", port) + def _configure_named_resource(self, name, compress=False): + """Build a resource map for a named resource + + Args: + name (str): named resource: one of "client", "federation", etc + compress (bool): whether to enable gzip compression for this + resource + + Returns: + dict[str, Resource]: map from path to HTTP resource + """ + resources = {} + if name == "client": + client_resource = ClientRestResource(self) + if compress: + client_resource = gz_wrap(client_resource) + + resources.update({ + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + }) + + if name == "federation": + resources.update({ + FEDERATION_PREFIX: TransportLayerServer(self), + }) + + if name in ["static", "client"]: + resources.update({ + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ), + }) + + if name in ["media", "federation", "client"]: + media_repo = MediaRepositoryResource(self) + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + + if name in ["keys", "federation"]: + resources.update({ + SERVER_KEY_PREFIX: LocalKey(self), + SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), + }) + + if name == "webclient": + resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + + if name == "metrics" and self.get_config().enable_metrics: + resources[METRICS_PREFIX] = MetricsResource(self) + + return resources + def start_listening(self): config = self.get_config() diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6893610e71..40c433d7ae 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.caches.response_cache import ResponseCache from synapse.types import ThirdPartyInstanceID @@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(None) key = (service.id, protocol) - return self.protocol_meta_cache.get(key) or ( - self.protocol_meta_cache.set(key, _get()) - ) + result = self.protocol_meta_cache.get(key) + if not result: + result = self.protocol_meta_cache.set( + key, preserve_fn(_get)() + ) + return make_deferred_yieldable(result) @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 938f6f25f8..8109e5f95e 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -41,7 +41,7 @@ class CasConfig(Config): #cas_config: # enabled: true # server_url: "https://cas-server.com" - # service_url: "https://homesever.domain.com:8448" + # service_url: "https://homeserver.domain.com:8448" # #required_attributes: # # name: value """ diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 2dbeafa9dd..a1d6e4d4f7 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False): "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" " - %(message)s" ) - if log_config is None: + if log_config is None: level = logging.INFO level_for_storage = logging.INFO if config.verbosity: @@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False): logger.info("Opened new log file due to SIGHUP") else: handler = logging.StreamHandler() + + def sighup(signum, stack): + pass + handler.setFormatter(formatter) handler.addFilter(LoggingContextFilter(request="")) diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 90824cab7f..e9828fac17 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -13,41 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +from ._base import Config from synapse.util.module_loader import load_module +LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' + class PasswordAuthProviderConfig(Config): def read_config(self, config): self.password_providers = [] - - provider_config = None + providers = [] # We want to be backwards compatible with the old `ldap_config` # param. ldap_config = config.get("ldap_config", {}) - self.ldap_enabled = ldap_config.get("enabled", False) - if self.ldap_enabled: - from ldap_auth_provider import LdapAuthProvider - parsed_config = LdapAuthProvider.parse_config(ldap_config) - self.password_providers.append((LdapAuthProvider, parsed_config)) + if ldap_config.get("enabled", False): + providers.append[{ + 'module': LDAP_PROVIDER, + 'config': ldap_config, + }] - providers = config.get("password_providers", []) + providers.extend(config.get("password_providers", [])) for provider in providers: + mod_name = provider['module'] + # This is for backwards compat when the ldap auth provider resided # in this package. - if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": - from ldap_auth_provider import LdapAuthProvider - provider_class = LdapAuthProvider - try: - provider_config = provider_class.parse_config(provider["config"]) - except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) - else: - (provider_class, provider_config) = load_module(provider) + if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": + mod_name = LDAP_PROVIDER + + (provider_class, provider_config) = load_module({ + "module": mod_name, + "config": provider['config'], + }) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/server.py b/synapse/config/server.py index b66993dab9..4d9193536d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -247,6 +247,13 @@ class ServerConfig(Config): - names: [federation] # Federation APIs compress: false + # optional list of additional endpoints which can be loaded via + # dynamic modules + # additional_resources: + # "/_matrix/my/custom/endpoint": + # module: my_module.CustomRequestHandler + # config: {} + # Unsecure HTTP listener, # For when matrix traffic passes through loadbalancer that unwraps TLS. - port: %(unsecure_port)s diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 247f18f454..4748f71c2f 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -109,6 +109,12 @@ class TlsConfig(Config): # key. It may be necessary to publish the fingerprints of a new # certificate and wait until the "valid_until_ts" of the previous key # responses have passed before deploying it. + # + # You can calculate a fingerprint from a given TLS listener via: + # openssl s_client -connect $host:$port < /dev/null 2> /dev/null | + # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '=' + # or by checking matrix.org/federationtester/api/report?server_name=$host + # tls_fingerprints: [] # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] """ % locals() diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e15228e70b..a2327f24b6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -18,6 +18,7 @@ from .federation_base import FederationBase from .units import Transaction, Edu from synapse.util import async +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logutils import log_function from synapse.util.caches.response_cache import ResponseCache from synapse.events import FrozenEvent @@ -253,12 +254,13 @@ class FederationServer(FederationBase): result = self._state_resp_cache.get((room_id, event_id)) if not result: with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.set( + d = self._state_resp_cache.set( (room_id, event_id), - self._on_context_state_request_compute(room_id, event_id) + preserve_fn(self._on_context_state_request_compute)(room_id, event_id) ) + resp = yield make_deferred_yieldable(d) else: - resp = yield result + resp = yield make_deferred_yieldable(result) defer.returnValue((200, resp)) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index d25ae1b282..ed41dfc7ee 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -531,9 +531,9 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def add_room_to_group(self, destination, group_id, requester_user_id, room_id, - content): - """Add a room to a group + def update_room_group_association(self, destination, group_id, requester_user_id, + room_id, content): + """Add or update an association between room and group """ path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) @@ -545,7 +545,8 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): + def delete_room_group_association(self, destination, group_id, requester_user_id, + room_id): """Remove a room from a group """ path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8f3c14c303..ded6d4edc9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -684,7 +684,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.add_room_to_group( + new_content = yield self.handler.update_room_group_association( group_id, requester_user_id, room_id, content ) @@ -696,7 +696,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.remove_room_from_group( + new_content = yield self.handler.delete_room_group_association( group_id, requester_user_id, room_id, ) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index b751cf5e43..1fb709e6c3 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -13,6 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Attestations ensure that users and groups can't lie about their memberships. + +When a user joins a group the HS and GS swap attestations, which allow them +both to independently prove to third parties their membership.These +attestations have a validity period so need to be periodically renewed. + +If a user leaves (or gets kicked out of) a group, either side can still use +their attestation to "prove" their membership, until the attestation expires. +Therefore attestations shouldn't be relied on to prove membership in important +cases, but can for less important situtations, e.g. showing a users membership +of groups on their profile, showing flairs, etc.abs + +An attestsation is a signed blob of json that looks like: + + { + "user_id": "@foo:a.example.com", + "group_id": "+bar:b.example.com", + "valid_until_ms": 1507994728530, + "signatures":{"matrix.org":{"ed25519:auto":"..."}} + } +""" + +import logging +import random + from twisted.internet import defer from synapse.api.errors import SynapseError @@ -22,9 +47,17 @@ from synapse.util.logcontext import preserve_fn from signedjson.sign import sign_json +logger = logging.getLogger(__name__) + + # Default validity duration for new attestations we create DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000 +# We add some jitter to the validity duration of attestations so that if we +# add lots of users at once we don't need to renew them all at once. +# The jitter is a multiplier picked randomly between the first and second number +DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) + # Start trying to update our attestations when they come this close to expiring UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 @@ -73,10 +106,14 @@ class GroupAttestationSigning(object): """Create an attestation for the group_id and user_id with default validity length. """ + validity_period = DEFAULT_ATTESTATION_LENGTH_MS + validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) + valid_until_ms = int(self.clock.time_msec() + validity_period) + return sign_json({ "group_id": group_id, "user_id": user_id, - "valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS, + "valid_until_ms": valid_until_ms, }, self.server_name, self.signing_key) @@ -128,12 +165,19 @@ class GroupAttestionRenewer(object): @defer.inlineCallbacks def _renew_attestation(group_id, user_id): - attestation = self.attestations.create_attestation(group_id, user_id) - - if self.is_mine_id(group_id): + if not self.is_mine_id(group_id): + destination = get_domain_from_id(group_id) + elif not self.is_mine_id(user_id): destination = get_domain_from_id(user_id) else: - destination = get_domain_from_id(group_id) + logger.warn( + "Incorrectly trying to do attestations for user: %r in %r", + user_id, group_id, + ) + yield self.store.remove_attestation_renewal(group_id, user_id) + return + + attestation = self.attestations.create_attestation(group_id, user_id) yield self.transport_client.renew_group_attestation( destination, group_id, user_id, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 23beb3187e..11199dd215 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -49,7 +49,8 @@ class GroupsServerHandler(object): hs.get_groups_attestation_renewer() @defer.inlineCallbacks - def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None): + def check_group_is_ours(self, group_id, requester_user_id, + and_exists=False, and_is_admin=None): """Check that the group is ours, and optionally if it exists. If group does exist then return group. @@ -67,6 +68,10 @@ class GroupsServerHandler(object): if and_exists and not group: raise SynapseError(404, "Unknown group") + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + if group and not is_user_in_group and not group["is_public"]: + raise SynapseError(404, "Unknown group") + if and_is_admin: is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin) if not is_admin: @@ -84,7 +89,7 @@ class GroupsServerHandler(object): A user/room may appear in multiple roles/categories. """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -153,10 +158,16 @@ class GroupsServerHandler(object): }) @defer.inlineCallbacks - def update_group_summary_room(self, group_id, user_id, room_id, category_id, content): + def update_group_summary_room(self, group_id, requester_user_id, + room_id, category_id, content): """Add/update a room to the group summary """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) RoomID.from_string(room_id) # Ensure valid room id @@ -175,10 +186,16 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_summary_room(self, group_id, user_id, room_id, category_id): + def delete_group_summary_room(self, group_id, requester_user_id, + room_id, category_id): """Remove a room from the summary """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) yield self.store.remove_room_from_summary( group_id=group_id, @@ -189,10 +206,10 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def get_group_categories(self, group_id, user_id): + def get_group_categories(self, group_id, requester_user_id): """Get all categories in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) categories = yield self.store.get_group_categories( group_id=group_id, @@ -200,10 +217,10 @@ class GroupsServerHandler(object): defer.returnValue({"categories": categories}) @defer.inlineCallbacks - def get_group_category(self, group_id, user_id, category_id): + def get_group_category(self, group_id, requester_user_id, category_id): """Get a specific category in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_category( group_id=group_id, @@ -213,10 +230,15 @@ class GroupsServerHandler(object): defer.returnValue(res) @defer.inlineCallbacks - def update_group_category(self, group_id, user_id, category_id, content): + def update_group_category(self, group_id, requester_user_id, category_id, content): """Add/Update a group category """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) is_public = _parse_visibility_from_contents(content) profile = content.get("profile") @@ -231,10 +253,15 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_category(self, group_id, user_id, category_id): + def delete_group_category(self, group_id, requester_user_id, category_id): """Delete a group category """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id + ) yield self.store.remove_group_category( group_id=group_id, @@ -244,10 +271,10 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def get_group_roles(self, group_id, user_id): + def get_group_roles(self, group_id, requester_user_id): """Get all roles in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) roles = yield self.store.get_group_roles( group_id=group_id, @@ -255,10 +282,10 @@ class GroupsServerHandler(object): defer.returnValue({"roles": roles}) @defer.inlineCallbacks - def get_group_role(self, group_id, user_id, role_id): + def get_group_role(self, group_id, requester_user_id, role_id): """Get a specific role in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_role( group_id=group_id, @@ -267,10 +294,15 @@ class GroupsServerHandler(object): defer.returnValue(res) @defer.inlineCallbacks - def update_group_role(self, group_id, user_id, role_id, content): + def update_group_role(self, group_id, requester_user_id, role_id, content): """Add/update a role in a group """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) is_public = _parse_visibility_from_contents(content) @@ -286,10 +318,15 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_role(self, group_id, user_id, role_id): + def delete_group_role(self, group_id, requester_user_id, role_id): """Remove role from group """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) yield self.store.remove_group_role( group_id=group_id, @@ -304,7 +341,7 @@ class GroupsServerHandler(object): """Add/update a users entry in the group summary """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, ) order = content.get("order", None) @@ -326,7 +363,7 @@ class GroupsServerHandler(object): """Remove a user from the group summary """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, ) yield self.store.remove_user_from_summary( @@ -342,7 +379,7 @@ class GroupsServerHandler(object): """Get the group profile as seen by requester_user_id """ - yield self.check_group_is_ours(group_id) + yield self.check_group_is_ours(group_id, requester_user_id) group_description = yield self.store.get_group(group_id) @@ -356,7 +393,7 @@ class GroupsServerHandler(object): """Update the group profile """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, ) profile = {} @@ -377,7 +414,7 @@ class GroupsServerHandler(object): The ordering is arbitrary at the moment """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -389,14 +426,15 @@ class GroupsServerHandler(object): for user_result in user_results: g_user_id = user_result["user_id"] is_public = user_result["is_public"] + is_privileged = user_result["is_admin"] entry = {"user_id": g_user_id} profile = yield self.profile_handler.get_profile_from_cache(g_user_id) entry.update(profile) - if not is_public: - entry["is_public"] = False + entry["is_public"] = bool(is_public) + entry["is_privileged"] = bool(is_privileged) if not self.is_mine_id(g_user_id): attestation = yield self.store.get_remote_attestation(group_id, g_user_id) @@ -425,7 +463,7 @@ class GroupsServerHandler(object): The ordering is arbitrary at the moment """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -459,7 +497,7 @@ class GroupsServerHandler(object): This returns rooms in order of decreasing number of joined users """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -470,7 +508,6 @@ class GroupsServerHandler(object): chunk = [] for room_result in room_results: room_id = room_result["room_id"] - is_public = room_result["is_public"] joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( @@ -481,8 +518,7 @@ class GroupsServerHandler(object): if not entry: continue - if not is_public: - entry["is_public"] = False + entry["is_public"] = bool(room_result["is_public"]) chunk.append(entry) @@ -494,30 +530,33 @@ class GroupsServerHandler(object): }) @defer.inlineCallbacks - def add_room_to_group(self, group_id, requester_user_id, room_id, content): - """Add room to group + def update_room_group_association(self, group_id, requester_user_id, room_id, + content): + """Add or update an association between room and group """ RoomID.from_string(room_id) # Ensure valid room id yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) - yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) + yield self.store.update_room_group_association( + group_id, room_id, is_public=is_public + ) defer.returnValue({}) @defer.inlineCallbacks - def remove_room_from_group(self, group_id, requester_user_id, room_id): + def delete_room_group_association(self, group_id, requester_user_id, room_id): """Remove room from group """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_room_from_group(group_id, room_id) + yield self.store.delete_room_group_association(group_id, room_id) defer.returnValue({}) @@ -527,7 +566,7 @@ class GroupsServerHandler(object): """ group = yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) # TODO: Check if user knocked @@ -596,35 +635,40 @@ class GroupsServerHandler(object): raise SynapseError(502, "Unknown state returned by HS") @defer.inlineCallbacks - def accept_invite(self, group_id, user_id, content): + def accept_invite(self, group_id, requester_user_id, content): """User tries to accept an invite to the group. This is different from them asking to join, and so should error if no invite exists (and they're not a member of the group) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - if not self.store.is_user_invited_to_local_group(group_id, user_id): + is_invited = yield self.store.is_user_invited_to_local_group( + group_id, requester_user_id, + ) + if not is_invited: raise SynapseError(403, "User not invited to group") - if not self.hs.is_mine_id(user_id): + if not self.hs.is_mine_id(requester_user_id): + local_attestation = self.attestations.create_attestation( + group_id, requester_user_id, + ) remote_attestation = content["attestation"] yield self.attestations.verify_attestation( remote_attestation, - user_id=user_id, + user_id=requester_user_id, group_id=group_id, ) else: + local_attestation = None remote_attestation = None - local_attestation = self.attestations.create_attestation(group_id, user_id) - is_public = _parse_visibility_from_contents(content) yield self.store.add_user_to_group( - group_id, user_id, + group_id, requester_user_id, is_admin=False, is_public=is_public, local_attestation=local_attestation, @@ -637,31 +681,31 @@ class GroupsServerHandler(object): }) @defer.inlineCallbacks - def knock(self, group_id, user_id, content): + def knock(self, group_id, requester_user_id, content): """A user requests becoming a member of the group """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() @defer.inlineCallbacks - def accept_knock(self, group_id, user_id, content): + def accept_knock(self, group_id, requester_user_id, content): """Accept a users knock to the room. Errors if the user hasn't knocked, rather than inviting them. """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() @defer.inlineCallbacks def remove_user_from_group(self, group_id, user_id, requester_user_id, content): - """Remove a user from the group; either a user is leaving or and admin - kicked htem. + """Remove a user from the group; either a user is leaving or an admin + kicked them. """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_kick = False if requester_user_id != user_id: @@ -692,8 +736,8 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def create_group(self, group_id, user_id, content): - group = yield self.check_group_is_ours(group_id) + def create_group(self, group_id, requester_user_id, content): + group = yield self.check_group_is_ours(group_id, requester_user_id) logger.info("Attempting to create group with ID: %r", group_id) @@ -703,11 +747,11 @@ class GroupsServerHandler(object): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) + is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) if not is_admin: if not self.hs.config.enable_group_creation: raise SynapseError( - 403, "Only server admin can create group on this server", + 403, "Only a server admin can create groups on this server", ) localpart = group_id_obj.localpart if not localpart.startswith(self.hs.config.group_creation_prefix): @@ -727,38 +771,41 @@ class GroupsServerHandler(object): yield self.store.create_group( group_id, - user_id, + requester_user_id, name=name, avatar_url=avatar_url, short_description=short_description, long_description=long_description, ) - if not self.hs.is_mine_id(user_id): + if not self.hs.is_mine_id(requester_user_id): remote_attestation = content["attestation"] yield self.attestations.verify_attestation( remote_attestation, - user_id=user_id, + user_id=requester_user_id, group_id=group_id, ) - local_attestation = self.attestations.create_attestation(group_id, user_id) + local_attestation = self.attestations.create_attestation( + group_id, + requester_user_id, + ) else: local_attestation = None remote_attestation = None yield self.store.add_user_to_group( - group_id, user_id, + group_id, requester_user_id, is_admin=True, is_public=True, # TODO local_attestation=local_attestation, remote_attestation=remote_attestation, ) - if not self.hs.is_mine_id(user_id): + if not self.hs.is_mine_id(requester_user_id): yield self.store.add_remote_profile_cache( - user_id, + requester_user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 9cef9d184b..7a0ba6ef35 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,13 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType -from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError +from synapse.module_api import ModuleApi +from synapse.types import UserID from synapse.util.async import run_on_reactor from synapse.util.caches.expiringcache import ExpiringCache @@ -63,10 +63,7 @@ class AuthHandler(BaseHandler): reset_expiry_on_get=True, ) - account_handler = _AccountHandler( - hs, check_user_exists=self.check_user_exists - ) - + account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) for module, config in hs.config.password_providers @@ -75,14 +72,24 @@ class AuthHandler(BaseHandler): logger.info("Extra password_providers: %r", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? - self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + self._password_enabled = hs.config.password_enabled + + login_types = set() + if self._password_enabled: + login_types.add(LoginType.PASSWORD) + for provider in self.password_providers: + if hasattr(provider, "get_supported_login_types"): + login_types.update( + provider.get_supported_login_types().keys() + ) + self._supported_login_types = frozenset(login_types) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration - protocol and handles the login flow. + protocol and handles the User-Interactive Auth flow. As a side effect, this function fills in the 'creds' key on the user's session with a map, which maps each auth-type (str) to the relevant @@ -260,16 +267,19 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) + @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) user_id = authdict["user"] password = authdict["password"] - if not user_id.startswith('@'): - user_id = UserID(user_id, self.hs.hostname).to_string() - return self._check_password(user_id, password) + (canonical_id, callback) = yield self.validate_login(user_id, { + "type": LoginType.PASSWORD, + "password": password, + }) + defer.returnValue(canonical_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -398,26 +408,8 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - def validate_password_login(self, user_id, password): - """ - Authenticates the user with their username and password. - - Used only by the v1 login API. - - Args: - user_id (str): complete @user:id - password (str): Password - Returns: - defer.Deferred: (str) canonical user id - Raises: - StoreError if there was a problem accessing the database - LoginError if there was an authentication problem. - """ - return self._check_password(user_id, password) - @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None): """ Creates a new access token for the user with the given user ID. @@ -431,13 +423,10 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - initial_display_name (str): display name to associate with the - device if it needs re-registering Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. - LoginError if there was an authentication problem. """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) @@ -447,9 +436,11 @@ class AuthHandler(BaseHandler): # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + try: + yield self.store.get_device(user_id, device_id) + except StoreError: + yield self.store.delete_access_token(access_token) + raise StoreError(400, "Login raced against device deletion") defer.returnValue(access_token) @@ -501,29 +492,115 @@ class AuthHandler(BaseHandler): ) defer.returnValue(result) + def get_supported_login_types(self): + """Get a the login types supported for the /login API + + By default this is just 'm.login.password' (unless password_enabled is + False in the config file), but password auth providers can provide + other login types. + + Returns: + Iterable[str]: login types + """ + return self._supported_login_types + @defer.inlineCallbacks - def _check_password(self, user_id, password): - """Authenticate a user against the LDAP and local databases. + def validate_login(self, username, login_submission): + """Authenticates the user for the /login API - user_id is checked case insensitively against the local database, but - will throw if there are multiple inexact matches. + Also used by the user-interactive auth flow to validate + m.login.password auth types. Args: - user_id (str): complete @user:id + username (str): username supplied by the user + login_submission (dict): the whole of the login submission + (including 'type' and other relevant fields) Returns: - (str) the canonical_user_id + Deferred[str, func]: canonical user id, and optional callback + to be called once the access token and device id are issued Raises: - LoginError if login fails + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. """ + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID( + username, self.hs.hostname + ).to_string() + + login_type = login_submission.get("type") + known_login_type = False + + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not password: + raise SynapseError(400, "Missing parameter: password") + for provider in self.password_providers: - is_valid = yield provider.check_password(user_id, password) - if is_valid: - defer.returnValue(user_id) + if (hasattr(provider, "check_password") + and login_type == LoginType.PASSWORD): + known_login_type = True + is_valid = yield provider.check_password( + qualified_user_id, password, + ) + if is_valid: + defer.returnValue(qualified_user_id) + + if (not hasattr(provider, "get_supported_login_types") + or not hasattr(provider, "check_auth")): + # this password provider doesn't understand custom login types + continue + + supported_login_types = provider.get_supported_login_types() + if login_type not in supported_login_types: + # this password provider doesn't understand this login type + continue + + known_login_type = True + login_fields = supported_login_types[login_type] + + missing_fields = [] + login_dict = {} + for f in login_fields: + if f not in login_submission: + missing_fields.append(f) + else: + login_dict[f] = login_submission[f] + if missing_fields: + raise SynapseError( + 400, "Missing parameters for login type %s: %s" % ( + login_type, + missing_fields, + ), + ) + + result = yield provider.check_auth( + username, login_type, login_dict, + ) + if result: + if isinstance(result, str): + result = (result, None) + defer.returnValue(result) + + if login_type == LoginType.PASSWORD: + known_login_type = True + + canonical_user_id = yield self._check_local_password( + qualified_user_id, password, + ) - canonical_user_id = yield self._check_local_password(user_id, password) + if canonical_user_id: + defer.returnValue((canonical_user_id, None)) - if canonical_user_id: - defer.returnValue(canonical_user_id) + if not known_login_type: + raise SynapseError(400, "Unknown login type %s" % login_type) # unknown username or invalid password. We raise a 403 here, but note # that if we're doing user-interactive login, it turns all LoginErrors @@ -584,14 +661,81 @@ class AuthHandler(BaseHandler): if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise e - yield self.store.user_delete_access_tokens( - user_id, except_access_token_id + yield self.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id, ) yield self.hs.get_pusherpool().remove_pushers_by_user( user_id, except_access_token_id ) @defer.inlineCallbacks + def deactivate_account(self, user_id): + """Deactivate a user's account + + Args: + user_id (str): ID of user to be deactivated + + Returns: + Deferred + """ + # FIXME: Theoretically there is a race here wherein user resets + # password using threepid. + yield self.delete_access_tokens_for_user(user_id) + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) + + @defer.inlineCallbacks + def delete_access_token(self, access_token): + """Invalidate a single access token + + Args: + access_token (str): access token to be deleted + + Returns: + Deferred + """ + user_info = yield self.auth.get_user_by_access_token(access_token) + yield self.store.delete_access_token(access_token) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + yield provider.on_logged_out( + user_id=str(user_info["user"]), + device_id=user_info["device_id"], + access_token=access_token, + ) + + @defer.inlineCallbacks + def delete_access_tokens_for_user(self, user_id, except_token_id=None, + device_id=None): + """Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str|None): access_token ID which should *not* be + deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + Returns: + Deferred + """ + tokens_and_devices = yield self.store.user_delete_access_tokens( + user_id, except_token_id=except_token_id, device_id=device_id, + ) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + for token, device_id in tokens_and_devices: + yield provider.on_logged_out( + user_id=user_id, + device_id=device_id, + access_token=token, + ) + + @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): # 'Canonicalise' email addresses down to lower case. # We've now moving towards the Home Server being the entity that @@ -696,30 +840,3 @@ class MacaroonGeneartor(object): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - - -class _AccountHandler(object): - """A proxy object that gets passed to password auth providers so they - can register new users etc if necessary. - """ - def __init__(self, hs, check_user_exists): - self.hs = hs - - self._check_user_exists = check_user_exists - - def check_user_exists(self, user_id): - """Check if user exissts. - - Returns: - Deferred(bool) - """ - return self._check_user_exists(user_id) - - def register(self, localpart): - """Registers a new user with given localpart - - Returns: - Deferred: a 2-tuple of (user_id, access_token) - """ - reg = self.hs.get_handlers().registration_handler - return reg.register(localpart=localpart) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index dac4b3f4e0..579d8477ba 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self._auth_handler = hs.get_auth_handler() self.federation_sender = hs.get_federation_sender() self.federation = hs.get_replication_layer() @@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler): else: raise - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( @@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 6699d0888f..dabc2a3fbb 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -70,8 +70,8 @@ class GroupsLocalHandler(object): get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") - add_room_to_group = _create_rerouter("add_room_to_group") - remove_room_from_group = _create_rerouter("remove_room_from_group") + update_room_group_association = _create_rerouter("update_room_group_association") + delete_room_group_association = _create_rerouter("delete_room_group_association") update_group_summary_room = _create_rerouter("update_group_summary_room") delete_group_summary_room = _create_rerouter("delete_group_summary_room") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 49dc33c147..f6e7e58563 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.captcha_client = CaptchaServerHttpClient(hs) @@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_localpart=user.localpart, ) else: - yield self.store.user_delete_access_tokens(user_id=user_id) + yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 41e1781df7..2cf34e51cb 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -20,6 +20,7 @@ from ._base import BaseHandler from synapse.api.constants import ( EventTypes, JoinRules, ) +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.async import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache @@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler): if search_filter: # We explicitly don't bother caching searches or requests for # appservice specific lists. + logger.info("Bypassing cache as search request.") return self._get_public_room_list( limit, since_token, search_filter, network_tuple=network_tuple, ) @@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler): key = (limit, since_token, network_tuple) result = self.response_cache.get(key) if not result: + logger.info("No cached result, calculating one.") result = self.response_cache.set( key, - self._get_public_room_list( + preserve_fn(self._get_public_room_list)( limit, since_token, network_tuple=network_tuple ) ) - return result + else: + logger.info("Using cached deferred result.") + return make_deferred_yieldable(result) @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 219529936f..b12988f3c9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user @@ -184,11 +184,11 @@ class SyncHandler(object): if not result: result = self.response_cache.set( sync_config.request_key, - self._wait_for_sync_for_user( + preserve_fn(self._wait_for_sync_for_user)( sync_config, since_token, timeout, full_state ) ) - return result + return make_deferred_yieldable(result) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py new file mode 100644 index 0000000000..343e932cb1 --- /dev/null +++ b/synapse/http/additional_resource.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import wrap_request_handler +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + + +class AdditionalResource(Resource): + """Resource wrapper for additional_resources + + If the user has configured additional_resources, we need to wrap the + handler class with a Resource so that we can map it into the resource tree. + + This class is also where we wrap the request handler with logging, metrics, + and exception handling. + """ + def __init__(self, hs, handler): + """Initialise AdditionalResource + + The ``handler`` should return a deferred which completes when it has + done handling the request. It should write a response with + ``request.write()``, and call ``request.finish()``. + + Args: + hs (synapse.server.HomeServer): homeserver + handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): + function to be called to handle the request. + """ + Resource.__init__(self) + self._handler = handler + + # these are required by the request_handler wrapper + self.version_string = hs.version_string + self.clock = hs.get_clock() + + def render(self, request): + self._async_render(request) + return NOT_DONE_YET + + @wrap_request_handler + def _async_render(self, request): + return self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index 9eba046bbf..4abb479ae3 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import ( CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import make_deferred_yieldable from synapse.util import logcontext import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -114,43 +114,73 @@ class SimpleHttpClient(object): raise e @defer.inlineCallbacks - def post_urlencoded_get_json(self, uri, args={}): + def post_urlencoded_get_json(self, uri, args={}, headers=None): + """ + Args: + uri (str): + args (dict[str, str|List[str]]): query params + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ + # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + actual_headers = { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json_get_json(self, uri, post_json): + def post_json_get_json(self, uri, post_json, headers=None): + """ + + Args: + uri (str): + post_json (object): + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ json_str = encode_canonical_json(post_json) logger.debug("HTTP POST %s -> %s", json_str, uri) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/json"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -160,7 +190,7 @@ class SimpleHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def get_json(self, uri, args={}): + def get_json(self, uri, args={}, headers=None): """ Gets some json from the given URI. Args: @@ -169,6 +199,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -177,13 +209,13 @@ class SimpleHttpClient(object): error message. """ try: - body = yield self.get_raw(uri, args) + body = yield self.get_raw(uri, args, headers=headers) defer.returnValue(json.loads(body)) except CodeMessageException as e: raise self._exceptionFromFailedRequest(e.code, e.msg) @defer.inlineCallbacks - def put_json(self, uri, json_body, args={}): + def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. Args: @@ -193,6 +225,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -205,17 +239,21 @@ class SimpleHttpClient(object): json_str = encode_canonical_json(json_body) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "PUT", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - "Content-Type": ["application/json"] - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -226,7 +264,7 @@ class SimpleHttpClient(object): raise CodeMessageException(response.code, body) @defer.inlineCallbacks - def get_raw(self, uri, args={}): + def get_raw(self, uri, args={}, headers=None): """ Gets raw text from the given URI. Args: @@ -235,6 +273,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body at text. @@ -246,15 +286,19 @@ class SimpleHttpClient(object): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(body) @@ -274,27 +318,33 @@ class SimpleHttpClient(object): # The two should be factored out. @defer.inlineCallbacks - def get_file(self, url, output_stream, max_size=None): + def get_file(self, url, output_stream, max_size=None, headers=None): """GETs a file from a given URL Args: url (str): The URL to GET output_stream (file): File to write the response body to. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: A (int,dict,string,int) tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. """ + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", url.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - headers = dict(response.headers.getAllRawHeaders()) + resp_headers = dict(response.headers.getAllRawHeaders()) - if 'Content-Length' in headers and headers['Content-Length'] > max_size: + if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size: logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -315,10 +365,9 @@ class SimpleHttpClient(object): # straight back in again try: - length = yield preserve_context_over_fn( - _readBodyToFile, - response, output_stream, max_size - ) + length = yield make_deferred_yieldable(_readBodyToFile( + response, output_stream, max_size, + )) except Exception as e: logger.exception("Failed to download body") raise SynapseError( @@ -327,7 +376,9 @@ class SimpleHttpClient(object): Codes.UNKNOWN, ) - defer.returnValue((length, headers, response.request.absoluteURI, response.code)) + defer.returnValue( + (length, resp_headers, response.request.absoluteURI, response.code), + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. @@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): ) try: - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(body) except PartialDownloadError as e: # twisted dislikes google's response, no content length. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py new file mode 100644 index 0000000000..dc680ddf43 --- /dev/null +++ b/synapse/module_api/__init__.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.types import UserID + + +class ModuleApi(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, auth_handler): + self.hs = hs + + self._store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = auth_handler + + def get_user_by_req(self, req, allow_guest=False): + """Check the access_token provided for a request + + Args: + req (twisted.web.server.Request): Incoming HTTP request + allow_guest (bool): True if guest users should be allowed. If this + is False, and the access token is for a guest user, an + AuthError will be thrown + Returns: + twisted.internet.defer.Deferred[synapse.types.Requester]: + the requester for this request + Raises: + synapse.api.errors.AuthError: if no user by that token exists, + or the token is invalid. + """ + return self._auth.get_user_by_req(req, allow_guest) + + def get_qualified_user_id(self, username): + """Qualify a user id, if necessary + + Takes a user id provided by the user and adds the @ and :domain to + qualify it, if necessary + + Args: + username (str): provided user id + + Returns: + str: qualified @user:id + """ + if username.startswith('@'): + return username + return UserID(username, self.hs.hostname).to_string() + + def check_user_exists(self, user_id): + """Check if user exists. + + Args: + user_id (str): Complete @user:id + + Returns: + Deferred[str|None]: Canonical (case-corrected) user_id, or None + if the user is not registered. + """ + return self._auth_handler.check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) + + def invalidate_access_token(self, access_token): + """Invalidate an access token for a user + + Args: + access_token(str): access token + + Returns: + twisted.internet.defer.Deferred - resolves once the access token + has been removed. + + Raises: + synapse.api.errors.AuthError: the access token is invalid + """ + + return self._auth_handler.delete_access_token(access_token) + + def run_db_interaction(self, desc, func, *args, **kwargs): + """Run a function with a database connection + + Args: + desc (str): description for the transaction, for metrics etc + func (func): function to be run. Passed a database cursor object + as well as *args and **kwargs + *args: positional args to be passed to func + **kwargs: named args to be passed to func + + Returns: + Deferred[object]: result of func + """ + return self._store.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 465b25033d..1197158fdc 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): - self.store = hs.get_datastore() + self._auth_handler = hs.get_auth_handler() super(DeactivateAccountRestServlet, self).__init__(hs) @defer.inlineCallbacks @@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. - yield self.store.user_delete_access_tokens(target_user_id) - yield self.store.user_delete_threepids(target_user_id) - yield self.store.user_set_password_hash(target_user_id, None) - + yield self._auth_handler.deactivate_account(target_user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 9536e8ade6..5669ecb724 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier): class LoginRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login$") - PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" @@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url - self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret @@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet): # fall back to the fallback API if they don't understand one of the # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - if self.password_enabled: - flows.append({"type": LoginRestServlet.PASS_TYPE}) + + flows.extend(( + {"type": t} for t in self.auth_handler.get_supported_login_types() + )) return (200, {"flows": flows}) @@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet): def on_POST(self, request): login_submission = parse_json_object_from_request(request) try: - if login_submission["type"] == LoginRestServlet.PASS_TYPE: - if not self.password_enabled: - raise SynapseError(400, "Password login has been disabled.") - - result = yield self.do_password_login(login_submission) - defer.returnValue(result) - elif self.saml2_enabled and (login_submission["type"] == - LoginRestServlet.SAML2_TYPE): + if self.saml2_enabled and (login_submission["type"] == + LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: relay_state = "&RelayState=" + urllib.quote( @@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet): result = yield self.do_token_login(login_submission) defer.returnValue(result) else: - raise SynapseError(400, "Bad login type.") + result = yield self._do_other_login(login_submission) + defer.returnValue(result) except KeyError: raise SynapseError(400, "Missing JSON keys.") @defer.inlineCallbacks - def do_password_login(self, login_submission): - if "password" not in login_submission: - raise SynapseError(400, "Missing parameter: password") + def _do_other_login(self, login_submission): + """Handle non-token/saml/jwt logins + Args: + login_submission: + + Returns: + (int, object): HTTP code/response + """ + # Log the request we got, but only certain fields to minimise the chance of + # logging someone's password (even if they accidentally put it in the wrong + # field) + logger.info( + "Got login request with identifier: %r, medium: %r, address: %r, user: %r", + login_submission.get('identifier'), + login_submission.get('medium'), + login_submission.get('address'), + login_submission.get('user'), + ) login_submission_legacy_convert(login_submission) if "identifier" not in login_submission: @@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet): if "user" not in identifier: raise SynapseError(400, "User identifier is missing 'user' key") - user_id = identifier["user"] - - if not user_id.startswith('@'): - user_id = UserID( - user_id, self.hs.hostname - ).to_string() - auth_handler = self.auth_handler - user_id = yield auth_handler.validate_password_login( - user_id=user_id, - password=login_submission["password"], + canonical_user_id, callback = yield auth_handler.validate_login( + identifier["user"], + login_submission, + ) + + device_id = yield self._register_device( + canonical_user_id, login_submission, ) - device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name"), + canonical_user_id, device_id, ) + result = { - "user_id": user_id, # may have changed + "user_id": canonical_user_id, "access_token": access_token, "home_server": self.hs.hostname, "device_id": device_id, } + if callback is not None: + yield callback(result) + defer.returnValue((200, result)) @defer.inlineCallbacks @@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet): device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed @@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token = yield auth_handler.get_access_token_for_user_id( registered_user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 1358d0acab..6add754782 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutRestServlet, self).__init__(hs) - self.store = hs.get_datastore() + self._auth_handler = hs.get_auth_handler() def on_OPTIONS(self, request): return (200, {}) @@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): access_token = get_access_token_from_request(request) - yield self.store.delete_access_token(access_token) + yield self._auth_handler.delete_access_token(access_token) defer.returnValue((200, {})) @@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutAllRestServlet, self).__init__(hs) - self.store = hs.get_datastore() self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() def on_OPTIONS(self, request): return (200, {}) @@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.store.user_delete_access_tokens(user_id) + yield self._auth_handler.delete_access_tokens_for_user(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 4990b22b9f..3062e04c59 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -13,22 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from twisted.internet import defer +from synapse.api.auth import has_access_token from synapse.api.constants import LoginType -from synapse.api.errors import LoginError, SynapseError, Codes +from synapse.api.errors import Codes, LoginError, SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, assert_params_in_request + RestServlet, assert_params_in_request, + parse_json_object_from_request, ) from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn - from ._base import client_v2_patterns -import logging - - logger = logging.getLogger(__name__) @@ -163,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet): def __init__(self, hs): self.hs = hs - self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() super(DeactivateAccountRestServlet, self).__init__() @@ -172,6 +170,20 @@ class DeactivateAccountRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) + # if the caller provides an access token, it ought to be valid. + requester = None + if has_access_token(request): + requester = yield self.auth.get_user_by_req( + request, + ) # type: synapse.types.Requester + + # allow ASes to dectivate their own users + if requester and requester.app_service: + yield self.auth_handler.deactivate_account( + requester.user.to_string() + ) + defer.returnValue((200, {})) + authed, result, params, _ = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], ], body, self.hs.get_ip_from_request(request)) @@ -179,25 +191,22 @@ class DeactivateAccountRestServlet(RestServlet): if not authed: defer.returnValue((401, result)) - user_id = None - requester = None - if LoginType.PASSWORD in result: + user_id = result[LoginType.PASSWORD] # if using password, they should also be logged in - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: + if requester is None: + raise SynapseError( + 400, + "Deactivate account requires an access_token", + errcode=Codes.MISSING_TOKEN + ) + if requester.user.to_string() != user_id: raise LoginError(400, "", Codes.UNKNOWN) else: logger.error("Auth succeeded but no known type!", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. - yield self.store.user_delete_access_tokens(user_id) - yield self.store.user_delete_threepids(user_id) - yield self.store.user_set_password_hash(user_id, None) - + yield self.auth_handler.deactivate_account(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index b57ba95d24..5321e5abbb 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class DevicesRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) def __init__(self, hs): """ @@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ - PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False) def __init__(self, hs): super(DeleteDevicesRestServlet, self).__init__() @@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): class DeviceRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", - releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) def __init__(self, hs): """ @@ -118,6 +117,8 @@ class DeviceRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + try: body = servlet.parse_json_object_from_request(request) @@ -136,11 +137,12 @@ class DeviceRestServlet(servlet.RestServlet): if not authed: defer.returnValue((401, result)) - requester = yield self.auth.get_user_by_req(request) - yield self.device_handler.delete_device( - requester.user.to_string(), - device_id, - ) + # check that the UI auth matched the access token + user_id = result[constants.LoginType.PASSWORD] + if user_id != requester.user.to_string(): + raise errors.AuthError(403, "Invalid auth") + + yield self.device_handler.delete_device(user_id, device_id) defer.returnValue((200, {})) @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 100f47ca9e..792608cd48 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -39,20 +39,23 @@ class GroupServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - group_description = yield self.groups_handler.get_group_profile(group_id, user_id) + group_description = yield self.groups_handler.get_group_profile( + group_id, + requester_user_id, + ) defer.returnValue((200, group_description)) @defer.inlineCallbacks def on_POST(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) yield self.groups_handler.update_group_profile( - group_id, user_id, content, + group_id, requester_user_id, content, ) defer.returnValue((200, {})) @@ -72,9 +75,12 @@ class GroupSummaryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id) + get_group_summary = yield self.groups_handler.get_group_summary( + group_id, + requester_user_id, + ) defer.returnValue((200, get_group_summary)) @@ -101,11 +107,11 @@ class GroupSummaryRoomsCatServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, category_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_room( - group_id, user_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -116,10 +122,10 @@ class GroupSummaryRoomsCatServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_room( - group_id, user_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id, ) @@ -143,10 +149,10 @@ class GroupCategoryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id, category_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_category( - group_id, user_id, + group_id, requester_user_id, category_id=category_id, ) @@ -155,11 +161,11 @@ class GroupCategoryServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, category_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_category( - group_id, user_id, + group_id, requester_user_id, category_id=category_id, content=content, ) @@ -169,10 +175,10 @@ class GroupCategoryServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_category( - group_id, user_id, + group_id, requester_user_id, category_id=category_id, ) @@ -195,10 +201,10 @@ class GroupCategoriesServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_categories( - group_id, user_id, + group_id, requester_user_id, ) defer.returnValue((200, category)) @@ -220,10 +226,10 @@ class GroupRoleServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id, role_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_role( - group_id, user_id, + group_id, requester_user_id, role_id=role_id, ) @@ -232,11 +238,11 @@ class GroupRoleServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, role_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_role( - group_id, user_id, + group_id, requester_user_id, role_id=role_id, content=content, ) @@ -246,10 +252,10 @@ class GroupRoleServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, role_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_role( - group_id, user_id, + group_id, requester_user_id, role_id=role_id, ) @@ -272,10 +278,10 @@ class GroupRolesServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_roles( - group_id, user_id, + group_id, requester_user_id, ) defer.returnValue((200, category)) @@ -343,9 +349,9 @@ class GroupRoomServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group(group_id, user_id) + result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) defer.returnValue((200, result)) @@ -364,9 +370,9 @@ class GroupUsersServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group(group_id, user_id) + result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) defer.returnValue((200, result)) @@ -385,9 +391,12 @@ class GroupInvitedUsersServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id) + result = yield self.groups_handler.get_invited_users_in_group( + group_id, + requester_user_id, + ) defer.returnValue((200, result)) @@ -407,14 +416,18 @@ class GroupCreateServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() # TODO: Create group on remote server content = parse_json_object_from_request(request) localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() - result = yield self.groups_handler.create_group(group_id, user_id, content) + result = yield self.groups_handler.create_group( + group_id, + requester_user_id, + content, + ) defer.returnValue((200, result)) @@ -435,11 +448,11 @@ class GroupAdminRoomsServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.add_room_to_group( - group_id, user_id, room_id, content, + result = yield self.groups_handler.update_room_group_association( + group_id, requester_user_id, room_id, content, ) defer.returnValue((200, result)) @@ -447,10 +460,10 @@ class GroupAdminRoomsServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.remove_room_from_group( - group_id, user_id, room_id, + result = yield self.groups_handler.delete_room_group_association( + group_id, requester_user_id, room_id, ) defer.returnValue((200, result)) @@ -685,9 +698,9 @@ class GroupsForUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_joined_groups(user_id) + result = yield self.groups_handler.get_joined_groups(requester_user_id) defer.returnValue((200, result)) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 943e87e7fd..3cc87ea63f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERNS = client_v2_patterns( - "/keys/query$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/query$") def __init__(self, hs): """ @@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ - PATTERNS = client_v2_patterns( - "/keys/changes$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/changes$") def __init__(self, hs): """ @@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERNS = client_v2_patterns( - "/keys/claim$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/claim$") def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index fd2a3d69d4..ec170109fe 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): - PATTERNS = client_v2_patterns("/notifications$", releases=()) + PATTERNS = client_v2_patterns("/notifications$") def __init__(self, hs): super(NotificationsServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d9a8cdbbb5..eebd071e59 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -557,25 +557,28 @@ class RegisterRestServlet(RestServlet): Args: (str) user_id: full canonical @user:id (object) params: registration parameters, from which we pull - device_id and initial_device_name + device_id, initial_device_name and inhibit_login Returns: defer.Deferred: (object) dictionary for response from /register """ - device_id = yield self._register_device(user_id, params) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + } + if not params.get("inhibit_login", False): + device_id = yield self._register_device(user_id, params) - access_token = ( - yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - initial_display_name=params.get("initial_device_display_name") + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + ) ) - ) - defer.returnValue({ - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - }) + result.update({ + "access_token": access_token, + "device_id": device_id, + }) + defer.returnValue(result) def _register_device(self, user_id, params): """Register a device for a user. diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d607bd2970..90bdb1db15 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns( "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", - releases=[], v2_alpha=False + v2_alpha=False ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 6fceb23e26..6773b9ba60 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocols") def __init__(self, hs): super(ThirdPartyProtocolsServlet, self).__init__() @@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): super(ThirdPartyProtocolServlet, self).__init__() @@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() @@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 80114fca0d..7907a9d17a 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -520,7 +520,14 @@ def _calc_og(tree, media_uri): from lxml import etree TAGS_TO_REMOVE = ( - "header", "nav", "aside", "footer", "script", "style", etree.Comment + "header", + "nav", + "aside", + "footer", + "script", + "noscript", + "style", + etree.Comment ) # Split all the text nodes into paragraphs (by splitting on new diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index 9e63db5c6c..6b261dcc0f 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -35,7 +35,9 @@ class GroupServerStore(SQLBaseStore): keyvalues={ "group_id": group_id, }, - retcols=("name", "short_description", "long_description", "avatar_url",), + retcols=( + "name", "short_description", "long_description", "avatar_url", "is_public" + ), allow_none=True, desc="is_user_in_group", ) @@ -52,7 +54,7 @@ class GroupServerStore(SQLBaseStore): return self._simple_select_list( table="group_users", keyvalues=keyvalues, - retcols=("user_id", "is_public",), + retcols=("user_id", "is_public", "is_admin",), desc="get_users_in_group", ) @@ -844,19 +846,25 @@ class GroupServerStore(SQLBaseStore): ) return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn) - def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + def update_room_group_association(self, group_id, room_id, is_public): + return self._simple_upsert( table="group_rooms", - values={ + keyvalues={ "group_id": group_id, "room_id": room_id, + }, + values={ "is_public": is_public, }, - desc="add_room_to_group", + insertion_values={ + "group_id": group_id, + "room_id": room_id, + }, + desc="update_room_group_association", ) - def remove_room_from_group(self, group_id, room_id): - def _remove_room_from_group_txn(txn): + def delete_room_group_association(self, group_id, room_id): + def _delete_room_group_association_txn(txn): self._simple_delete_txn( txn, table="group_rooms", @@ -875,7 +883,7 @@ class GroupServerStore(SQLBaseStore): }, ) return self.runInteraction( - "remove_room_from_group", _remove_room_from_group_txn, + "delete_room_group_association", _delete_room_group_association_txn, ) def get_publicised_groups_for_user(self, user_id): @@ -1026,6 +1034,7 @@ class GroupServerStore(SQLBaseStore): "avatar_url": avatar_url, "short_description": short_description, "long_description": long_description, + "is_public": True, }, desc="create_group", ) @@ -1086,6 +1095,24 @@ class GroupServerStore(SQLBaseStore): desc="update_remote_attestion", ) + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self._simple_delete( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + desc="remove_attestation_renewal", + ) + @defer.inlineCallbacks def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 817c2185c8..d1691bbac2 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 45 +SCHEMA_VERSION = 46 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config): If `config` is None then prepare_database will assert that no upgrade is necessary, *or* will create a fresh database if the database is empty. + + Args: + db_conn: + database_engine: + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already """ try: cur = db_conn.cursor() @@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config): else: _setup_new_database(cur, database_engine) + # check if any of our configured dynamic modules want a database + if config is not None: + _apply_module_schemas(cur, database_engine, config) + cur.close() db_conn.commit() except Exception: @@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) +def _apply_module_schemas(txn, database_engine, config): + """Apply the module schemas for the dynamic modules, if any + + Args: + cur: database cursor + database_engine: synapse database engine class + config (synapse.config.homeserver.HomeServerConfig): + application config + """ + for (mod, _config) in config.password_providers: + if not hasattr(mod, 'get_db_schema_files'): + continue + modname = ".".join((mod.__module__, mod.__name__)) + _apply_module_schema_files( + txn, database_engine, modname, mod.get_db_schema_files(), + ) + + +def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): + """Apply the module schemas for a single module + + Args: + cur: database cursor + database_engine: synapse database engine class + modname (str): fully qualified name of the module + names_and_streams (Iterable[(str, file)]): the names and streams of + schemas to be applied + """ + cur.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_module_schemas WHERE module_name = ?" + ), + (modname,) + ) + applied_deltas = set(d for d, in cur) + for (name, stream) in names_and_streams: + if name in applied_deltas: + continue + + root_name, ext = os.path.splitext(name) + if ext != '.sql': + raise PrepareDatabaseException( + "only .sql files are currently supported for module schemas", + ) + + logger.info("applying schema %s for %s", name, modname) + for statement in get_statements(stream): + cur.execute(statement) + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_module_schemas (module_name, file)" + " VALUES (?,?)", + ), + (modname, name) + ) + + def get_statements(f): statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 20acd58fcf..9c4f61da76 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -36,12 +36,15 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id"], ) - self.register_background_index_update( - "refresh_tokens_device_index", - index_name="refresh_tokens_device_id", - table="refresh_tokens", - columns=["user_id", "device_id"], - ) + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. + @defer.inlineCallbacks + def noop_update(progress, batch_size): + yield self._end_background_update("refresh_tokens_device_index") + defer.returnValue(1) + self.register_background_update_handler( + "refresh_tokens_device_index", noop_update) @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): @@ -177,9 +180,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) if create_profile_with_localpart: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames txn.execute( - "INSERT INTO profiles(user_id) VALUES (?)", - (create_profile_with_localpart,) + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (create_profile_with_localpart, create_profile_with_localpart) ) self._invalidate_cache_and_stream( @@ -238,10 +243,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None, - delete_refresh_tokens=False): + device_id=None): """ - Invalidate access/refresh tokens belonging to a user + Invalidate access tokens belonging to a user Args: user_id (str): ID of user the tokens belong to @@ -250,10 +254,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): device_id (str|None): ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - delete_refresh_tokens (bool): True to delete refresh tokens as - well as access tokens. Returns: - defer.Deferred: + defer.Deferred[list[str, str|None]]: a list of the deleted tokens + and device IDs """ def f(txn): keyvalues = { @@ -262,13 +265,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - if delete_refresh_tokens: - self._simple_delete_txn( - txn, - table="refresh_tokens", - keyvalues=keyvalues, - ) - items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] @@ -277,14 +273,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values.append(except_token_id) txn.execute( - "SELECT token FROM access_tokens WHERE %s" % where_clause, + "SELECT token, device_id FROM access_tokens WHERE %s" % where_clause, values ) - rows = self.cursor_to_dict(txn) + tokens_and_devices = [(r[0], r[1]) for r in txn] - for row in rows: + for token, _ in tokens_and_devices: self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (row["token"],) + txn, self.get_user_by_access_token, (token,) ) txn.execute( @@ -292,6 +288,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values ) + return tokens_and_devices + yield self.runInteraction( "user_delete_access_tokens", f, ) diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql deleted file mode 100644 index 34db0cf12b..0000000000 --- a/synapse/storage/schema/delta/23/refresh_tokens.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS refresh_tokens( - id INTEGER PRIMARY KEY, - token TEXT NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (token) -); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql deleted file mode 100644 index bb225dafbf..0000000000 --- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('refresh_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql index 290bd6da86..68c48a89a9 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device.sql +++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,4 +13,5 @@ * limitations under the License. */ -ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..e754b554f8 --- /dev/null +++ b/synapse/storage/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index a7ade69986..42e5cb6df5 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas( file TEXT NOT NULL, UNIQUE(version, file) ); + +-- a list of schema files we have loaded on behalf of dynamic modules +CREATE TABLE IF NOT EXISTS applied_module_schemas( + module_name TEXT NOT NULL, + file TEXT NOT NULL, + UNIQUE(module_name, file) +); diff --git a/synapse/util/async.py b/synapse/util/async.py index 1a884e96ee..e786fb38a9 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -278,8 +278,13 @@ class Limiter(object): if entry[0] >= self.max_count: new_defer = defer.Deferred() entry[1].append(new_defer) + + logger.info("Waiting to acquire limiter lock for key %r", key) with PreserveLoggingContext(): yield new_defer + logger.info("Acquired limiter lock for key %r", key) + else: + logger.info("Acquired uncontended limiter lock for key %r", key) entry[0] += 1 @@ -288,16 +293,21 @@ class Limiter(object): try: yield finally: + logger.info("Releasing limiter lock for key %r", key) + # We've finished executing so check if there are any things # blocked waiting to execute and start one of them entry[0] -= 1 - try: - entry[1].pop(0).callback(None) - except IndexError: - # If nothing else is executing for this key then remove it - # from the map - if entry[0] == 0: - self.key_to_defer.pop(key, None) + + if entry[1]: + next_def = entry[1].pop(0) + + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] defer.returnValue(_ctx_manager()) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 316ecdb32d..7c7b164ee6 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -86,7 +86,8 @@ class RegistrationStoreTestCase(unittest.TestCase): # now delete some yield self.store.user_delete_access_tokens( - self.user_id, device_id=self.device_id, delete_refresh_tokens=True) + self.user_id, device_id=self.device_id, + ) # check they were deleted user = yield self.store.get_user_by_access_token(self.tokens[1]) @@ -97,8 +98,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertEqual(self.user_id, user["name"]) # now delete the rest - yield self.store.user_delete_access_tokens( - self.user_id, delete_refresh_tokens=True) + yield self.store.user_delete_access_tokens(self.user_id) user = yield self.store.get_user_by_access_token(self.tokens[0]) self.assertIsNone(user, diff --git a/tests/utils.py b/tests/utils.py index d2ebce4b2e..ed8a7360f5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object): ) self.config = Mock() + self.config.password_providers = [] self.config.database_config = {"name": "sqlite3"} def prepare(self): |