diff options
Diffstat (limited to 'synapse')
41 files changed, 1117 insertions, 510 deletions
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index bee4c47498..abc7ef5725 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy") class KeyUploadServlet(RestServlet): - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet): if body: # They're actually trying to upload something, proxy to main synapse. + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + headers = { + "Authorization": auth_headers, + } result = yield self.http_client.post_json_get_json( self.main_uri + request.uri, body, + headers=headers, ) defer.returnValue((200, result)) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 3adf72e141..9e26146338 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -30,6 +30,8 @@ from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.federation.transport.server import TransportLayerServer +from synapse.module_api import ModuleApi +from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite from synapse.metrics import register_memory_metrics @@ -49,6 +51,7 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole +from synapse.util.module_loader import load_module from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string from twisted.application import service @@ -107,52 +110,18 @@ class SynapseHomeServer(HomeServer): resources = {} for res in listener_config["resources"]: for name in res["names"]: - if name == "client": - client_resource = ClientRestResource(self) - if res["compress"]: - client_resource = gz_wrap(client_resource) - - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - }) - - if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) - - if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) - - if name in ["media", "federation", "client"]: - media_repo = MediaRepositoryResource(self) - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) - - if name in ["keys", "federation"]: - resources.update({ - SERVER_KEY_PREFIX: LocalKey(self), - SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), - }) - - if name == "webclient": - resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + resources.update(self._configure_named_resource( + name, res.get("compress", False), + )) - if name == "metrics" and self.get_config().enable_metrics: - resources[METRICS_PREFIX] = MetricsResource(self) + additional_resources = listener_config.get("additional_resources", {}) + logger.debug("Configuring additional resources: %r", + additional_resources) + module_api = ModuleApi(self, self.get_auth_handler()) + for path, resmodule in additional_resources.items(): + handler_cls, config = load_module(resmodule) + handler = handler_cls(config, module_api) + resources[path] = AdditionalResource(self, handler.handle_request) if WEB_CLIENT_PREFIX in resources: root_resource = RootRedirect(WEB_CLIENT_PREFIX) @@ -188,6 +157,67 @@ class SynapseHomeServer(HomeServer): ) logger.info("Synapse now listening on port %d", port) + def _configure_named_resource(self, name, compress=False): + """Build a resource map for a named resource + + Args: + name (str): named resource: one of "client", "federation", etc + compress (bool): whether to enable gzip compression for this + resource + + Returns: + dict[str, Resource]: map from path to HTTP resource + """ + resources = {} + if name == "client": + client_resource = ClientRestResource(self) + if compress: + client_resource = gz_wrap(client_resource) + + resources.update({ + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + }) + + if name == "federation": + resources.update({ + FEDERATION_PREFIX: TransportLayerServer(self), + }) + + if name in ["static", "client"]: + resources.update({ + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ), + }) + + if name in ["media", "federation", "client"]: + media_repo = MediaRepositoryResource(self) + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + + if name in ["keys", "federation"]: + resources.update({ + SERVER_KEY_PREFIX: LocalKey(self), + SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), + }) + + if name == "webclient": + resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + + if name == "metrics" and self.get_config().enable_metrics: + resources[METRICS_PREFIX] = MetricsResource(self) + + return resources + def start_listening(self): config = self.get_config() diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6893610e71..40c433d7ae 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.caches.response_cache import ResponseCache from synapse.types import ThirdPartyInstanceID @@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(None) key = (service.id, protocol) - return self.protocol_meta_cache.get(key) or ( - self.protocol_meta_cache.set(key, _get()) - ) + result = self.protocol_meta_cache.get(key) + if not result: + result = self.protocol_meta_cache.set( + key, preserve_fn(_get)() + ) + return make_deferred_yieldable(result) @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 938f6f25f8..8109e5f95e 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -41,7 +41,7 @@ class CasConfig(Config): #cas_config: # enabled: true # server_url: "https://cas-server.com" - # service_url: "https://homesever.domain.com:8448" + # service_url: "https://homeserver.domain.com:8448" # #required_attributes: # # name: value """ diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 2dbeafa9dd..a1d6e4d4f7 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False): "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" " - %(message)s" ) - if log_config is None: + if log_config is None: level = logging.INFO level_for_storage = logging.INFO if config.verbosity: @@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False): logger.info("Opened new log file due to SIGHUP") else: handler = logging.StreamHandler() + + def sighup(signum, stack): + pass + handler.setFormatter(formatter) handler.addFilter(LoggingContextFilter(request="")) diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 90824cab7f..e9828fac17 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -13,41 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +from ._base import Config from synapse.util.module_loader import load_module +LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' + class PasswordAuthProviderConfig(Config): def read_config(self, config): self.password_providers = [] - - provider_config = None + providers = [] # We want to be backwards compatible with the old `ldap_config` # param. ldap_config = config.get("ldap_config", {}) - self.ldap_enabled = ldap_config.get("enabled", False) - if self.ldap_enabled: - from ldap_auth_provider import LdapAuthProvider - parsed_config = LdapAuthProvider.parse_config(ldap_config) - self.password_providers.append((LdapAuthProvider, parsed_config)) + if ldap_config.get("enabled", False): + providers.append[{ + 'module': LDAP_PROVIDER, + 'config': ldap_config, + }] - providers = config.get("password_providers", []) + providers.extend(config.get("password_providers", [])) for provider in providers: + mod_name = provider['module'] + # This is for backwards compat when the ldap auth provider resided # in this package. - if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": - from ldap_auth_provider import LdapAuthProvider - provider_class = LdapAuthProvider - try: - provider_config = provider_class.parse_config(provider["config"]) - except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) - else: - (provider_class, provider_config) = load_module(provider) + if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": + mod_name = LDAP_PROVIDER + + (provider_class, provider_config) = load_module({ + "module": mod_name, + "config": provider['config'], + }) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/server.py b/synapse/config/server.py index b66993dab9..4d9193536d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -247,6 +247,13 @@ class ServerConfig(Config): - names: [federation] # Federation APIs compress: false + # optional list of additional endpoints which can be loaded via + # dynamic modules + # additional_resources: + # "/_matrix/my/custom/endpoint": + # module: my_module.CustomRequestHandler + # config: {} + # Unsecure HTTP listener, # For when matrix traffic passes through loadbalancer that unwraps TLS. - port: %(unsecure_port)s diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e15228e70b..a2327f24b6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -18,6 +18,7 @@ from .federation_base import FederationBase from .units import Transaction, Edu from synapse.util import async +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logutils import log_function from synapse.util.caches.response_cache import ResponseCache from synapse.events import FrozenEvent @@ -253,12 +254,13 @@ class FederationServer(FederationBase): result = self._state_resp_cache.get((room_id, event_id)) if not result: with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.set( + d = self._state_resp_cache.set( (room_id, event_id), - self._on_context_state_request_compute(room_id, event_id) + preserve_fn(self._on_context_state_request_compute)(room_id, event_id) ) + resp = yield make_deferred_yieldable(d) else: - resp = yield result + resp = yield make_deferred_yieldable(result) defer.returnValue((200, resp)) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index d25ae1b282..ed41dfc7ee 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -531,9 +531,9 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def add_room_to_group(self, destination, group_id, requester_user_id, room_id, - content): - """Add a room to a group + def update_room_group_association(self, destination, group_id, requester_user_id, + room_id, content): + """Add or update an association between room and group """ path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) @@ -545,7 +545,8 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): + def delete_room_group_association(self, destination, group_id, requester_user_id, + room_id): """Remove a room from a group """ path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8f3c14c303..ded6d4edc9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -684,7 +684,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.add_room_to_group( + new_content = yield self.handler.update_room_group_association( group_id, requester_user_id, room_id, content ) @@ -696,7 +696,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.remove_room_from_group( + new_content = yield self.handler.delete_room_group_association( group_id, requester_user_id, room_id, ) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index b751cf5e43..1fb709e6c3 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -13,6 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Attestations ensure that users and groups can't lie about their memberships. + +When a user joins a group the HS and GS swap attestations, which allow them +both to independently prove to third parties their membership.These +attestations have a validity period so need to be periodically renewed. + +If a user leaves (or gets kicked out of) a group, either side can still use +their attestation to "prove" their membership, until the attestation expires. +Therefore attestations shouldn't be relied on to prove membership in important +cases, but can for less important situtations, e.g. showing a users membership +of groups on their profile, showing flairs, etc.abs + +An attestsation is a signed blob of json that looks like: + + { + "user_id": "@foo:a.example.com", + "group_id": "+bar:b.example.com", + "valid_until_ms": 1507994728530, + "signatures":{"matrix.org":{"ed25519:auto":"..."}} + } +""" + +import logging +import random + from twisted.internet import defer from synapse.api.errors import SynapseError @@ -22,9 +47,17 @@ from synapse.util.logcontext import preserve_fn from signedjson.sign import sign_json +logger = logging.getLogger(__name__) + + # Default validity duration for new attestations we create DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000 +# We add some jitter to the validity duration of attestations so that if we +# add lots of users at once we don't need to renew them all at once. +# The jitter is a multiplier picked randomly between the first and second number +DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) + # Start trying to update our attestations when they come this close to expiring UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 @@ -73,10 +106,14 @@ class GroupAttestationSigning(object): """Create an attestation for the group_id and user_id with default validity length. """ + validity_period = DEFAULT_ATTESTATION_LENGTH_MS + validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) + valid_until_ms = int(self.clock.time_msec() + validity_period) + return sign_json({ "group_id": group_id, "user_id": user_id, - "valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS, + "valid_until_ms": valid_until_ms, }, self.server_name, self.signing_key) @@ -128,12 +165,19 @@ class GroupAttestionRenewer(object): @defer.inlineCallbacks def _renew_attestation(group_id, user_id): - attestation = self.attestations.create_attestation(group_id, user_id) - - if self.is_mine_id(group_id): + if not self.is_mine_id(group_id): + destination = get_domain_from_id(group_id) + elif not self.is_mine_id(user_id): destination = get_domain_from_id(user_id) else: - destination = get_domain_from_id(group_id) + logger.warn( + "Incorrectly trying to do attestations for user: %r in %r", + user_id, group_id, + ) + yield self.store.remove_attestation_renewal(group_id, user_id) + return + + attestation = self.attestations.create_attestation(group_id, user_id) yield self.transport_client.renew_group_attestation( destination, group_id, user_id, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 23beb3187e..addc70ce94 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -49,7 +49,8 @@ class GroupsServerHandler(object): hs.get_groups_attestation_renewer() @defer.inlineCallbacks - def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None): + def check_group_is_ours(self, group_id, requester_user_id, + and_exists=False, and_is_admin=None): """Check that the group is ours, and optionally if it exists. If group does exist then return group. @@ -67,6 +68,10 @@ class GroupsServerHandler(object): if and_exists and not group: raise SynapseError(404, "Unknown group") + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + if group and not is_user_in_group and not group["is_public"]: + raise SynapseError(404, "Unknown group") + if and_is_admin: is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin) if not is_admin: @@ -84,7 +89,7 @@ class GroupsServerHandler(object): A user/room may appear in multiple roles/categories. """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -153,10 +158,16 @@ class GroupsServerHandler(object): }) @defer.inlineCallbacks - def update_group_summary_room(self, group_id, user_id, room_id, category_id, content): + def update_group_summary_room(self, group_id, requester_user_id, + room_id, category_id, content): """Add/update a room to the group summary """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) RoomID.from_string(room_id) # Ensure valid room id @@ -175,10 +186,16 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_summary_room(self, group_id, user_id, room_id, category_id): + def delete_group_summary_room(self, group_id, requester_user_id, + room_id, category_id): """Remove a room from the summary """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) yield self.store.remove_room_from_summary( group_id=group_id, @@ -189,10 +206,10 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def get_group_categories(self, group_id, user_id): + def get_group_categories(self, group_id, requester_user_id): """Get all categories in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) categories = yield self.store.get_group_categories( group_id=group_id, @@ -200,10 +217,10 @@ class GroupsServerHandler(object): defer.returnValue({"categories": categories}) @defer.inlineCallbacks - def get_group_category(self, group_id, user_id, category_id): + def get_group_category(self, group_id, requester_user_id, category_id): """Get a specific category in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_category( group_id=group_id, @@ -213,10 +230,15 @@ class GroupsServerHandler(object): defer.returnValue(res) @defer.inlineCallbacks - def update_group_category(self, group_id, user_id, category_id, content): + def update_group_category(self, group_id, requester_user_id, category_id, content): """Add/Update a group category """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) is_public = _parse_visibility_from_contents(content) profile = content.get("profile") @@ -231,10 +253,15 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_category(self, group_id, user_id, category_id): + def delete_group_category(self, group_id, requester_user_id, category_id): """Delete a group category """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id + ) yield self.store.remove_group_category( group_id=group_id, @@ -244,10 +271,10 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def get_group_roles(self, group_id, user_id): + def get_group_roles(self, group_id, requester_user_id): """Get all roles in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) roles = yield self.store.get_group_roles( group_id=group_id, @@ -255,10 +282,10 @@ class GroupsServerHandler(object): defer.returnValue({"roles": roles}) @defer.inlineCallbacks - def get_group_role(self, group_id, user_id, role_id): + def get_group_role(self, group_id, requester_user_id, role_id): """Get a specific role in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_role( group_id=group_id, @@ -267,10 +294,15 @@ class GroupsServerHandler(object): defer.returnValue(res) @defer.inlineCallbacks - def update_group_role(self, group_id, user_id, role_id, content): + def update_group_role(self, group_id, requester_user_id, role_id, content): """Add/update a role in a group """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) is_public = _parse_visibility_from_contents(content) @@ -286,10 +318,15 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_role(self, group_id, user_id, role_id): + def delete_group_role(self, group_id, requester_user_id, role_id): """Remove role from group """ - yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id) + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) yield self.store.remove_group_role( group_id=group_id, @@ -304,7 +341,7 @@ class GroupsServerHandler(object): """Add/update a users entry in the group summary """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, ) order = content.get("order", None) @@ -326,7 +363,7 @@ class GroupsServerHandler(object): """Remove a user from the group summary """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, ) yield self.store.remove_user_from_summary( @@ -342,7 +379,7 @@ class GroupsServerHandler(object): """Get the group profile as seen by requester_user_id """ - yield self.check_group_is_ours(group_id) + yield self.check_group_is_ours(group_id, requester_user_id) group_description = yield self.store.get_group(group_id) @@ -356,7 +393,7 @@ class GroupsServerHandler(object): """Update the group profile """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, ) profile = {} @@ -377,7 +414,7 @@ class GroupsServerHandler(object): The ordering is arbitrary at the moment """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -425,7 +462,7 @@ class GroupsServerHandler(object): The ordering is arbitrary at the moment """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -459,7 +496,7 @@ class GroupsServerHandler(object): This returns rooms in order of decreasing number of joined users """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) @@ -470,7 +507,6 @@ class GroupsServerHandler(object): chunk = [] for room_result in room_results: room_id = room_result["room_id"] - is_public = room_result["is_public"] joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( @@ -481,8 +517,7 @@ class GroupsServerHandler(object): if not entry: continue - if not is_public: - entry["is_public"] = False + entry["is_public"] = bool(room_result["is_public"]) chunk.append(entry) @@ -494,30 +529,33 @@ class GroupsServerHandler(object): }) @defer.inlineCallbacks - def add_room_to_group(self, group_id, requester_user_id, room_id, content): - """Add room to group + def update_room_group_association(self, group_id, requester_user_id, room_id, + content): + """Add or update an association between room and group """ RoomID.from_string(room_id) # Ensure valid room id yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) - yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) + yield self.store.update_room_group_association( + group_id, room_id, is_public=is_public + ) defer.returnValue({}) @defer.inlineCallbacks - def remove_room_from_group(self, group_id, requester_user_id, room_id): + def delete_room_group_association(self, group_id, requester_user_id, room_id): """Remove room from group """ yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_room_from_group(group_id, room_id) + yield self.store.delete_room_group_association(group_id, room_id) defer.returnValue({}) @@ -527,7 +565,7 @@ class GroupsServerHandler(object): """ group = yield self.check_group_is_ours( - group_id, and_exists=True, and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) # TODO: Check if user knocked @@ -596,35 +634,40 @@ class GroupsServerHandler(object): raise SynapseError(502, "Unknown state returned by HS") @defer.inlineCallbacks - def accept_invite(self, group_id, user_id, content): + def accept_invite(self, group_id, requester_user_id, content): """User tries to accept an invite to the group. This is different from them asking to join, and so should error if no invite exists (and they're not a member of the group) """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - if not self.store.is_user_invited_to_local_group(group_id, user_id): + is_invited = yield self.store.is_user_invited_to_local_group( + group_id, requester_user_id, + ) + if not is_invited: raise SynapseError(403, "User not invited to group") - if not self.hs.is_mine_id(user_id): + if not self.hs.is_mine_id(requester_user_id): + local_attestation = self.attestations.create_attestation( + group_id, requester_user_id, + ) remote_attestation = content["attestation"] yield self.attestations.verify_attestation( remote_attestation, - user_id=user_id, + user_id=requester_user_id, group_id=group_id, ) else: + local_attestation = None remote_attestation = None - local_attestation = self.attestations.create_attestation(group_id, user_id) - is_public = _parse_visibility_from_contents(content) yield self.store.add_user_to_group( - group_id, user_id, + group_id, requester_user_id, is_admin=False, is_public=is_public, local_attestation=local_attestation, @@ -637,31 +680,31 @@ class GroupsServerHandler(object): }) @defer.inlineCallbacks - def knock(self, group_id, user_id, content): + def knock(self, group_id, requester_user_id, content): """A user requests becoming a member of the group """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() @defer.inlineCallbacks - def accept_knock(self, group_id, user_id, content): + def accept_knock(self, group_id, requester_user_id, content): """Accept a users knock to the room. Errors if the user hasn't knocked, rather than inviting them. """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() @defer.inlineCallbacks def remove_user_from_group(self, group_id, user_id, requester_user_id, content): - """Remove a user from the group; either a user is leaving or and admin - kicked htem. + """Remove a user from the group; either a user is leaving or an admin + kicked them. """ - yield self.check_group_is_ours(group_id, and_exists=True) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_kick = False if requester_user_id != user_id: @@ -692,8 +735,8 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def create_group(self, group_id, user_id, content): - group = yield self.check_group_is_ours(group_id) + def create_group(self, group_id, requester_user_id, content): + group = yield self.check_group_is_ours(group_id, requester_user_id) logger.info("Attempting to create group with ID: %r", group_id) @@ -703,11 +746,11 @@ class GroupsServerHandler(object): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) + is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) if not is_admin: if not self.hs.config.enable_group_creation: raise SynapseError( - 403, "Only server admin can create group on this server", + 403, "Only a server admin can create groups on this server", ) localpart = group_id_obj.localpart if not localpart.startswith(self.hs.config.group_creation_prefix): @@ -727,38 +770,41 @@ class GroupsServerHandler(object): yield self.store.create_group( group_id, - user_id, + requester_user_id, name=name, avatar_url=avatar_url, short_description=short_description, long_description=long_description, ) - if not self.hs.is_mine_id(user_id): + if not self.hs.is_mine_id(requester_user_id): remote_attestation = content["attestation"] yield self.attestations.verify_attestation( remote_attestation, - user_id=user_id, + user_id=requester_user_id, group_id=group_id, ) - local_attestation = self.attestations.create_attestation(group_id, user_id) + local_attestation = self.attestations.create_attestation( + group_id, + requester_user_id, + ) else: local_attestation = None remote_attestation = None yield self.store.add_user_to_group( - group_id, user_id, + group_id, requester_user_id, is_admin=True, is_public=True, # TODO local_attestation=local_attestation, remote_attestation=remote_attestation, ) - if not self.hs.is_mine_id(user_id): + if not self.hs.is_mine_id(requester_user_id): yield self.store.add_remote_profile_cache( - user_id, + requester_user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 9cef9d184b..7a0ba6ef35 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,13 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType -from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError +from synapse.module_api import ModuleApi +from synapse.types import UserID from synapse.util.async import run_on_reactor from synapse.util.caches.expiringcache import ExpiringCache @@ -63,10 +63,7 @@ class AuthHandler(BaseHandler): reset_expiry_on_get=True, ) - account_handler = _AccountHandler( - hs, check_user_exists=self.check_user_exists - ) - + account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) for module, config in hs.config.password_providers @@ -75,14 +72,24 @@ class AuthHandler(BaseHandler): logger.info("Extra password_providers: %r", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? - self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + self._password_enabled = hs.config.password_enabled + + login_types = set() + if self._password_enabled: + login_types.add(LoginType.PASSWORD) + for provider in self.password_providers: + if hasattr(provider, "get_supported_login_types"): + login_types.update( + provider.get_supported_login_types().keys() + ) + self._supported_login_types = frozenset(login_types) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration - protocol and handles the login flow. + protocol and handles the User-Interactive Auth flow. As a side effect, this function fills in the 'creds' key on the user's session with a map, which maps each auth-type (str) to the relevant @@ -260,16 +267,19 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) + @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) user_id = authdict["user"] password = authdict["password"] - if not user_id.startswith('@'): - user_id = UserID(user_id, self.hs.hostname).to_string() - return self._check_password(user_id, password) + (canonical_id, callback) = yield self.validate_login(user_id, { + "type": LoginType.PASSWORD, + "password": password, + }) + defer.returnValue(canonical_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -398,26 +408,8 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - def validate_password_login(self, user_id, password): - """ - Authenticates the user with their username and password. - - Used only by the v1 login API. - - Args: - user_id (str): complete @user:id - password (str): Password - Returns: - defer.Deferred: (str) canonical user id - Raises: - StoreError if there was a problem accessing the database - LoginError if there was an authentication problem. - """ - return self._check_password(user_id, password) - @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None): """ Creates a new access token for the user with the given user ID. @@ -431,13 +423,10 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - initial_display_name (str): display name to associate with the - device if it needs re-registering Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. - LoginError if there was an authentication problem. """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) @@ -447,9 +436,11 @@ class AuthHandler(BaseHandler): # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + try: + yield self.store.get_device(user_id, device_id) + except StoreError: + yield self.store.delete_access_token(access_token) + raise StoreError(400, "Login raced against device deletion") defer.returnValue(access_token) @@ -501,29 +492,115 @@ class AuthHandler(BaseHandler): ) defer.returnValue(result) + def get_supported_login_types(self): + """Get a the login types supported for the /login API + + By default this is just 'm.login.password' (unless password_enabled is + False in the config file), but password auth providers can provide + other login types. + + Returns: + Iterable[str]: login types + """ + return self._supported_login_types + @defer.inlineCallbacks - def _check_password(self, user_id, password): - """Authenticate a user against the LDAP and local databases. + def validate_login(self, username, login_submission): + """Authenticates the user for the /login API - user_id is checked case insensitively against the local database, but - will throw if there are multiple inexact matches. + Also used by the user-interactive auth flow to validate + m.login.password auth types. Args: - user_id (str): complete @user:id + username (str): username supplied by the user + login_submission (dict): the whole of the login submission + (including 'type' and other relevant fields) Returns: - (str) the canonical_user_id + Deferred[str, func]: canonical user id, and optional callback + to be called once the access token and device id are issued Raises: - LoginError if login fails + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. """ + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID( + username, self.hs.hostname + ).to_string() + + login_type = login_submission.get("type") + known_login_type = False + + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not password: + raise SynapseError(400, "Missing parameter: password") + for provider in self.password_providers: - is_valid = yield provider.check_password(user_id, password) - if is_valid: - defer.returnValue(user_id) + if (hasattr(provider, "check_password") + and login_type == LoginType.PASSWORD): + known_login_type = True + is_valid = yield provider.check_password( + qualified_user_id, password, + ) + if is_valid: + defer.returnValue(qualified_user_id) + + if (not hasattr(provider, "get_supported_login_types") + or not hasattr(provider, "check_auth")): + # this password provider doesn't understand custom login types + continue + + supported_login_types = provider.get_supported_login_types() + if login_type not in supported_login_types: + # this password provider doesn't understand this login type + continue + + known_login_type = True + login_fields = supported_login_types[login_type] + + missing_fields = [] + login_dict = {} + for f in login_fields: + if f not in login_submission: + missing_fields.append(f) + else: + login_dict[f] = login_submission[f] + if missing_fields: + raise SynapseError( + 400, "Missing parameters for login type %s: %s" % ( + login_type, + missing_fields, + ), + ) + + result = yield provider.check_auth( + username, login_type, login_dict, + ) + if result: + if isinstance(result, str): + result = (result, None) + defer.returnValue(result) + + if login_type == LoginType.PASSWORD: + known_login_type = True + + canonical_user_id = yield self._check_local_password( + qualified_user_id, password, + ) - canonical_user_id = yield self._check_local_password(user_id, password) + if canonical_user_id: + defer.returnValue((canonical_user_id, None)) - if canonical_user_id: - defer.returnValue(canonical_user_id) + if not known_login_type: + raise SynapseError(400, "Unknown login type %s" % login_type) # unknown username or invalid password. We raise a 403 here, but note # that if we're doing user-interactive login, it turns all LoginErrors @@ -584,14 +661,81 @@ class AuthHandler(BaseHandler): if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise e - yield self.store.user_delete_access_tokens( - user_id, except_access_token_id + yield self.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id, ) yield self.hs.get_pusherpool().remove_pushers_by_user( user_id, except_access_token_id ) @defer.inlineCallbacks + def deactivate_account(self, user_id): + """Deactivate a user's account + + Args: + user_id (str): ID of user to be deactivated + + Returns: + Deferred + """ + # FIXME: Theoretically there is a race here wherein user resets + # password using threepid. + yield self.delete_access_tokens_for_user(user_id) + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) + + @defer.inlineCallbacks + def delete_access_token(self, access_token): + """Invalidate a single access token + + Args: + access_token (str): access token to be deleted + + Returns: + Deferred + """ + user_info = yield self.auth.get_user_by_access_token(access_token) + yield self.store.delete_access_token(access_token) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + yield provider.on_logged_out( + user_id=str(user_info["user"]), + device_id=user_info["device_id"], + access_token=access_token, + ) + + @defer.inlineCallbacks + def delete_access_tokens_for_user(self, user_id, except_token_id=None, + device_id=None): + """Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str|None): access_token ID which should *not* be + deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + Returns: + Deferred + """ + tokens_and_devices = yield self.store.user_delete_access_tokens( + user_id, except_token_id=except_token_id, device_id=device_id, + ) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + for token, device_id in tokens_and_devices: + yield provider.on_logged_out( + user_id=user_id, + device_id=device_id, + access_token=token, + ) + + @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): # 'Canonicalise' email addresses down to lower case. # We've now moving towards the Home Server being the entity that @@ -696,30 +840,3 @@ class MacaroonGeneartor(object): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - - -class _AccountHandler(object): - """A proxy object that gets passed to password auth providers so they - can register new users etc if necessary. - """ - def __init__(self, hs, check_user_exists): - self.hs = hs - - self._check_user_exists = check_user_exists - - def check_user_exists(self, user_id): - """Check if user exissts. - - Returns: - Deferred(bool) - """ - return self._check_user_exists(user_id) - - def register(self, localpart): - """Registers a new user with given localpart - - Returns: - Deferred: a 2-tuple of (user_id, access_token) - """ - reg = self.hs.get_handlers().registration_handler - return reg.register(localpart=localpart) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index dac4b3f4e0..579d8477ba 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self._auth_handler = hs.get_auth_handler() self.federation_sender = hs.get_federation_sender() self.federation = hs.get_replication_layer() @@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler): else: raise - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( @@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 6699d0888f..dabc2a3fbb 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -70,8 +70,8 @@ class GroupsLocalHandler(object): get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") - add_room_to_group = _create_rerouter("add_room_to_group") - remove_room_from_group = _create_rerouter("remove_room_from_group") + update_room_group_association = _create_rerouter("update_room_group_association") + delete_room_group_association = _create_rerouter("delete_room_group_association") update_group_summary_room = _create_rerouter("update_group_summary_room") delete_group_summary_room = _create_rerouter("delete_group_summary_room") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 49dc33c147..f6e7e58563 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.captcha_client = CaptchaServerHttpClient(hs) @@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_localpart=user.localpart, ) else: - yield self.store.user_delete_access_tokens(user_id=user_id) + yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 41e1781df7..2cf34e51cb 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -20,6 +20,7 @@ from ._base import BaseHandler from synapse.api.constants import ( EventTypes, JoinRules, ) +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.async import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache @@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler): if search_filter: # We explicitly don't bother caching searches or requests for # appservice specific lists. + logger.info("Bypassing cache as search request.") return self._get_public_room_list( limit, since_token, search_filter, network_tuple=network_tuple, ) @@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler): key = (limit, since_token, network_tuple) result = self.response_cache.get(key) if not result: + logger.info("No cached result, calculating one.") result = self.response_cache.set( key, - self._get_public_room_list( + preserve_fn(self._get_public_room_list)( limit, since_token, network_tuple=network_tuple ) ) - return result + else: + logger.info("Using cached deferred result.") + return make_deferred_yieldable(result) @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 219529936f..b12988f3c9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user @@ -184,11 +184,11 @@ class SyncHandler(object): if not result: result = self.response_cache.set( sync_config.request_key, - self._wait_for_sync_for_user( + preserve_fn(self._wait_for_sync_for_user)( sync_config, since_token, timeout, full_state ) ) - return result + return make_deferred_yieldable(result) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py new file mode 100644 index 0000000000..343e932cb1 --- /dev/null +++ b/synapse/http/additional_resource.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import wrap_request_handler +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + + +class AdditionalResource(Resource): + """Resource wrapper for additional_resources + + If the user has configured additional_resources, we need to wrap the + handler class with a Resource so that we can map it into the resource tree. + + This class is also where we wrap the request handler with logging, metrics, + and exception handling. + """ + def __init__(self, hs, handler): + """Initialise AdditionalResource + + The ``handler`` should return a deferred which completes when it has + done handling the request. It should write a response with + ``request.write()``, and call ``request.finish()``. + + Args: + hs (synapse.server.HomeServer): homeserver + handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): + function to be called to handle the request. + """ + Resource.__init__(self) + self._handler = handler + + # these are required by the request_handler wrapper + self.version_string = hs.version_string + self.clock = hs.get_clock() + + def render(self, request): + self._async_render(request) + return NOT_DONE_YET + + @wrap_request_handler + def _async_render(self, request): + return self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index 9eba046bbf..4abb479ae3 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import ( CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import make_deferred_yieldable from synapse.util import logcontext import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -114,43 +114,73 @@ class SimpleHttpClient(object): raise e @defer.inlineCallbacks - def post_urlencoded_get_json(self, uri, args={}): + def post_urlencoded_get_json(self, uri, args={}, headers=None): + """ + Args: + uri (str): + args (dict[str, str|List[str]]): query params + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ + # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + actual_headers = { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json_get_json(self, uri, post_json): + def post_json_get_json(self, uri, post_json, headers=None): + """ + + Args: + uri (str): + post_json (object): + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ json_str = encode_canonical_json(post_json) logger.debug("HTTP POST %s -> %s", json_str, uri) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/json"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -160,7 +190,7 @@ class SimpleHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def get_json(self, uri, args={}): + def get_json(self, uri, args={}, headers=None): """ Gets some json from the given URI. Args: @@ -169,6 +199,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -177,13 +209,13 @@ class SimpleHttpClient(object): error message. """ try: - body = yield self.get_raw(uri, args) + body = yield self.get_raw(uri, args, headers=headers) defer.returnValue(json.loads(body)) except CodeMessageException as e: raise self._exceptionFromFailedRequest(e.code, e.msg) @defer.inlineCallbacks - def put_json(self, uri, json_body, args={}): + def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. Args: @@ -193,6 +225,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -205,17 +239,21 @@ class SimpleHttpClient(object): json_str = encode_canonical_json(json_body) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "PUT", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - "Content-Type": ["application/json"] - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -226,7 +264,7 @@ class SimpleHttpClient(object): raise CodeMessageException(response.code, body) @defer.inlineCallbacks - def get_raw(self, uri, args={}): + def get_raw(self, uri, args={}, headers=None): """ Gets raw text from the given URI. Args: @@ -235,6 +273,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body at text. @@ -246,15 +286,19 @@ class SimpleHttpClient(object): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(body) @@ -274,27 +318,33 @@ class SimpleHttpClient(object): # The two should be factored out. @defer.inlineCallbacks - def get_file(self, url, output_stream, max_size=None): + def get_file(self, url, output_stream, max_size=None, headers=None): """GETs a file from a given URL Args: url (str): The URL to GET output_stream (file): File to write the response body to. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: A (int,dict,string,int) tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. """ + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", url.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - headers = dict(response.headers.getAllRawHeaders()) + resp_headers = dict(response.headers.getAllRawHeaders()) - if 'Content-Length' in headers and headers['Content-Length'] > max_size: + if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size: logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -315,10 +365,9 @@ class SimpleHttpClient(object): # straight back in again try: - length = yield preserve_context_over_fn( - _readBodyToFile, - response, output_stream, max_size - ) + length = yield make_deferred_yieldable(_readBodyToFile( + response, output_stream, max_size, + )) except Exception as e: logger.exception("Failed to download body") raise SynapseError( @@ -327,7 +376,9 @@ class SimpleHttpClient(object): Codes.UNKNOWN, ) - defer.returnValue((length, headers, response.request.absoluteURI, response.code)) + defer.returnValue( + (length, resp_headers, response.request.absoluteURI, response.code), + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. @@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): ) try: - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(body) except PartialDownloadError as e: # twisted dislikes google's response, no content length. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py new file mode 100644 index 0000000000..dc680ddf43 --- /dev/null +++ b/synapse/module_api/__init__.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.types import UserID + + +class ModuleApi(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, auth_handler): + self.hs = hs + + self._store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = auth_handler + + def get_user_by_req(self, req, allow_guest=False): + """Check the access_token provided for a request + + Args: + req (twisted.web.server.Request): Incoming HTTP request + allow_guest (bool): True if guest users should be allowed. If this + is False, and the access token is for a guest user, an + AuthError will be thrown + Returns: + twisted.internet.defer.Deferred[synapse.types.Requester]: + the requester for this request + Raises: + synapse.api.errors.AuthError: if no user by that token exists, + or the token is invalid. + """ + return self._auth.get_user_by_req(req, allow_guest) + + def get_qualified_user_id(self, username): + """Qualify a user id, if necessary + + Takes a user id provided by the user and adds the @ and :domain to + qualify it, if necessary + + Args: + username (str): provided user id + + Returns: + str: qualified @user:id + """ + if username.startswith('@'): + return username + return UserID(username, self.hs.hostname).to_string() + + def check_user_exists(self, user_id): + """Check if user exists. + + Args: + user_id (str): Complete @user:id + + Returns: + Deferred[str|None]: Canonical (case-corrected) user_id, or None + if the user is not registered. + """ + return self._auth_handler.check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) + + def invalidate_access_token(self, access_token): + """Invalidate an access token for a user + + Args: + access_token(str): access token + + Returns: + twisted.internet.defer.Deferred - resolves once the access token + has been removed. + + Raises: + synapse.api.errors.AuthError: the access token is invalid + """ + + return self._auth_handler.delete_access_token(access_token) + + def run_db_interaction(self, desc, func, *args, **kwargs): + """Run a function with a database connection + + Args: + desc (str): description for the transaction, for metrics etc + func (func): function to be run. Passed a database cursor object + as well as *args and **kwargs + *args: positional args to be passed to func + **kwargs: named args to be passed to func + + Returns: + Deferred[object]: result of func + """ + return self._store.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 465b25033d..1197158fdc 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): - self.store = hs.get_datastore() + self._auth_handler = hs.get_auth_handler() super(DeactivateAccountRestServlet, self).__init__(hs) @defer.inlineCallbacks @@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. - yield self.store.user_delete_access_tokens(target_user_id) - yield self.store.user_delete_threepids(target_user_id) - yield self.store.user_set_password_hash(target_user_id, None) - + yield self._auth_handler.deactivate_account(target_user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 9536e8ade6..5669ecb724 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier): class LoginRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login$") - PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" @@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url - self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret @@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet): # fall back to the fallback API if they don't understand one of the # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - if self.password_enabled: - flows.append({"type": LoginRestServlet.PASS_TYPE}) + + flows.extend(( + {"type": t} for t in self.auth_handler.get_supported_login_types() + )) return (200, {"flows": flows}) @@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet): def on_POST(self, request): login_submission = parse_json_object_from_request(request) try: - if login_submission["type"] == LoginRestServlet.PASS_TYPE: - if not self.password_enabled: - raise SynapseError(400, "Password login has been disabled.") - - result = yield self.do_password_login(login_submission) - defer.returnValue(result) - elif self.saml2_enabled and (login_submission["type"] == - LoginRestServlet.SAML2_TYPE): + if self.saml2_enabled and (login_submission["type"] == + LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: relay_state = "&RelayState=" + urllib.quote( @@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet): result = yield self.do_token_login(login_submission) defer.returnValue(result) else: - raise SynapseError(400, "Bad login type.") + result = yield self._do_other_login(login_submission) + defer.returnValue(result) except KeyError: raise SynapseError(400, "Missing JSON keys.") @defer.inlineCallbacks - def do_password_login(self, login_submission): - if "password" not in login_submission: - raise SynapseError(400, "Missing parameter: password") + def _do_other_login(self, login_submission): + """Handle non-token/saml/jwt logins + Args: + login_submission: + + Returns: + (int, object): HTTP code/response + """ + # Log the request we got, but only certain fields to minimise the chance of + # logging someone's password (even if they accidentally put it in the wrong + # field) + logger.info( + "Got login request with identifier: %r, medium: %r, address: %r, user: %r", + login_submission.get('identifier'), + login_submission.get('medium'), + login_submission.get('address'), + login_submission.get('user'), + ) login_submission_legacy_convert(login_submission) if "identifier" not in login_submission: @@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet): if "user" not in identifier: raise SynapseError(400, "User identifier is missing 'user' key") - user_id = identifier["user"] - - if not user_id.startswith('@'): - user_id = UserID( - user_id, self.hs.hostname - ).to_string() - auth_handler = self.auth_handler - user_id = yield auth_handler.validate_password_login( - user_id=user_id, - password=login_submission["password"], + canonical_user_id, callback = yield auth_handler.validate_login( + identifier["user"], + login_submission, + ) + + device_id = yield self._register_device( + canonical_user_id, login_submission, ) - device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name"), + canonical_user_id, device_id, ) + result = { - "user_id": user_id, # may have changed + "user_id": canonical_user_id, "access_token": access_token, "home_server": self.hs.hostname, "device_id": device_id, } + if callback is not None: + yield callback(result) + defer.returnValue((200, result)) @defer.inlineCallbacks @@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet): device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed @@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token = yield auth_handler.get_access_token_for_user_id( registered_user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 1358d0acab..6add754782 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutRestServlet, self).__init__(hs) - self.store = hs.get_datastore() + self._auth_handler = hs.get_auth_handler() def on_OPTIONS(self, request): return (200, {}) @@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): access_token = get_access_token_from_request(request) - yield self.store.delete_access_token(access_token) + yield self._auth_handler.delete_access_token(access_token) defer.returnValue((200, {})) @@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutAllRestServlet, self).__init__(hs) - self.store = hs.get_datastore() self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() def on_OPTIONS(self, request): return (200, {}) @@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.store.user_delete_access_tokens(user_id) + yield self._auth_handler.delete_access_tokens_for_user(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 4990b22b9f..3062e04c59 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -13,22 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from twisted.internet import defer +from synapse.api.auth import has_access_token from synapse.api.constants import LoginType -from synapse.api.errors import LoginError, SynapseError, Codes +from synapse.api.errors import Codes, LoginError, SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, assert_params_in_request + RestServlet, assert_params_in_request, + parse_json_object_from_request, ) from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn - from ._base import client_v2_patterns -import logging - - logger = logging.getLogger(__name__) @@ -163,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet): def __init__(self, hs): self.hs = hs - self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() super(DeactivateAccountRestServlet, self).__init__() @@ -172,6 +170,20 @@ class DeactivateAccountRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) + # if the caller provides an access token, it ought to be valid. + requester = None + if has_access_token(request): + requester = yield self.auth.get_user_by_req( + request, + ) # type: synapse.types.Requester + + # allow ASes to dectivate their own users + if requester and requester.app_service: + yield self.auth_handler.deactivate_account( + requester.user.to_string() + ) + defer.returnValue((200, {})) + authed, result, params, _ = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], ], body, self.hs.get_ip_from_request(request)) @@ -179,25 +191,22 @@ class DeactivateAccountRestServlet(RestServlet): if not authed: defer.returnValue((401, result)) - user_id = None - requester = None - if LoginType.PASSWORD in result: + user_id = result[LoginType.PASSWORD] # if using password, they should also be logged in - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: + if requester is None: + raise SynapseError( + 400, + "Deactivate account requires an access_token", + errcode=Codes.MISSING_TOKEN + ) + if requester.user.to_string() != user_id: raise LoginError(400, "", Codes.UNKNOWN) else: logger.error("Auth succeeded but no known type!", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. - yield self.store.user_delete_access_tokens(user_id) - yield self.store.user_delete_threepids(user_id) - yield self.store.user_set_password_hash(user_id, None) - + yield self.auth_handler.deactivate_account(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index b57ba95d24..5321e5abbb 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class DevicesRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) def __init__(self, hs): """ @@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ - PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False) def __init__(self, hs): super(DeleteDevicesRestServlet, self).__init__() @@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): class DeviceRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", - releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) def __init__(self, hs): """ @@ -118,6 +117,8 @@ class DeviceRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + try: body = servlet.parse_json_object_from_request(request) @@ -136,11 +137,12 @@ class DeviceRestServlet(servlet.RestServlet): if not authed: defer.returnValue((401, result)) - requester = yield self.auth.get_user_by_req(request) - yield self.device_handler.delete_device( - requester.user.to_string(), - device_id, - ) + # check that the UI auth matched the access token + user_id = result[constants.LoginType.PASSWORD] + if user_id != requester.user.to_string(): + raise errors.AuthError(403, "Invalid auth") + + yield self.device_handler.delete_device(user_id, device_id) defer.returnValue((200, {})) @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 100f47ca9e..792608cd48 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -39,20 +39,23 @@ class GroupServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - group_description = yield self.groups_handler.get_group_profile(group_id, user_id) + group_description = yield self.groups_handler.get_group_profile( + group_id, + requester_user_id, + ) defer.returnValue((200, group_description)) @defer.inlineCallbacks def on_POST(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) yield self.groups_handler.update_group_profile( - group_id, user_id, content, + group_id, requester_user_id, content, ) defer.returnValue((200, {})) @@ -72,9 +75,12 @@ class GroupSummaryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id) + get_group_summary = yield self.groups_handler.get_group_summary( + group_id, + requester_user_id, + ) defer.returnValue((200, get_group_summary)) @@ -101,11 +107,11 @@ class GroupSummaryRoomsCatServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, category_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_room( - group_id, user_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -116,10 +122,10 @@ class GroupSummaryRoomsCatServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_room( - group_id, user_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id, ) @@ -143,10 +149,10 @@ class GroupCategoryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id, category_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_category( - group_id, user_id, + group_id, requester_user_id, category_id=category_id, ) @@ -155,11 +161,11 @@ class GroupCategoryServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, category_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_category( - group_id, user_id, + group_id, requester_user_id, category_id=category_id, content=content, ) @@ -169,10 +175,10 @@ class GroupCategoryServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_category( - group_id, user_id, + group_id, requester_user_id, category_id=category_id, ) @@ -195,10 +201,10 @@ class GroupCategoriesServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_categories( - group_id, user_id, + group_id, requester_user_id, ) defer.returnValue((200, category)) @@ -220,10 +226,10 @@ class GroupRoleServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id, role_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_role( - group_id, user_id, + group_id, requester_user_id, role_id=role_id, ) @@ -232,11 +238,11 @@ class GroupRoleServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, role_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_role( - group_id, user_id, + group_id, requester_user_id, role_id=role_id, content=content, ) @@ -246,10 +252,10 @@ class GroupRoleServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, role_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_role( - group_id, user_id, + group_id, requester_user_id, role_id=role_id, ) @@ -272,10 +278,10 @@ class GroupRolesServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_roles( - group_id, user_id, + group_id, requester_user_id, ) defer.returnValue((200, category)) @@ -343,9 +349,9 @@ class GroupRoomServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group(group_id, user_id) + result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) defer.returnValue((200, result)) @@ -364,9 +370,9 @@ class GroupUsersServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group(group_id, user_id) + result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) defer.returnValue((200, result)) @@ -385,9 +391,12 @@ class GroupInvitedUsersServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id) + result = yield self.groups_handler.get_invited_users_in_group( + group_id, + requester_user_id, + ) defer.returnValue((200, result)) @@ -407,14 +416,18 @@ class GroupCreateServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() # TODO: Create group on remote server content = parse_json_object_from_request(request) localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() - result = yield self.groups_handler.create_group(group_id, user_id, content) + result = yield self.groups_handler.create_group( + group_id, + requester_user_id, + content, + ) defer.returnValue((200, result)) @@ -435,11 +448,11 @@ class GroupAdminRoomsServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, group_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.add_room_to_group( - group_id, user_id, room_id, content, + result = yield self.groups_handler.update_room_group_association( + group_id, requester_user_id, room_id, content, ) defer.returnValue((200, result)) @@ -447,10 +460,10 @@ class GroupAdminRoomsServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, group_id, room_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.remove_room_from_group( - group_id, user_id, room_id, + result = yield self.groups_handler.delete_room_group_association( + group_id, requester_user_id, room_id, ) defer.returnValue((200, result)) @@ -685,9 +698,9 @@ class GroupsForUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_joined_groups(user_id) + result = yield self.groups_handler.get_joined_groups(requester_user_id) defer.returnValue((200, result)) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 943e87e7fd..3cc87ea63f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERNS = client_v2_patterns( - "/keys/query$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/query$") def __init__(self, hs): """ @@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ - PATTERNS = client_v2_patterns( - "/keys/changes$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/changes$") def __init__(self, hs): """ @@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERNS = client_v2_patterns( - "/keys/claim$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/claim$") def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index fd2a3d69d4..ec170109fe 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): - PATTERNS = client_v2_patterns("/notifications$", releases=()) + PATTERNS = client_v2_patterns("/notifications$") def __init__(self, hs): super(NotificationsServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d9a8cdbbb5..eebd071e59 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -557,25 +557,28 @@ class RegisterRestServlet(RestServlet): Args: (str) user_id: full canonical @user:id (object) params: registration parameters, from which we pull - device_id and initial_device_name + device_id, initial_device_name and inhibit_login Returns: defer.Deferred: (object) dictionary for response from /register """ - device_id = yield self._register_device(user_id, params) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + } + if not params.get("inhibit_login", False): + device_id = yield self._register_device(user_id, params) - access_token = ( - yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - initial_display_name=params.get("initial_device_display_name") + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + ) ) - ) - defer.returnValue({ - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - }) + result.update({ + "access_token": access_token, + "device_id": device_id, + }) + defer.returnValue(result) def _register_device(self, user_id, params): """Register a device for a user. diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d607bd2970..90bdb1db15 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns( "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", - releases=[], v2_alpha=False + v2_alpha=False ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 6fceb23e26..6773b9ba60 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocols") def __init__(self, hs): super(ThirdPartyProtocolsServlet, self).__init__() @@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): super(ThirdPartyProtocolServlet, self).__init__() @@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() @@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index 9e63db5c6c..f6924e1a32 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -35,7 +35,9 @@ class GroupServerStore(SQLBaseStore): keyvalues={ "group_id": group_id, }, - retcols=("name", "short_description", "long_description", "avatar_url",), + retcols=( + "name", "short_description", "long_description", "avatar_url", "is_public" + ), allow_none=True, desc="is_user_in_group", ) @@ -844,19 +846,25 @@ class GroupServerStore(SQLBaseStore): ) return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn) - def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + def update_room_group_association(self, group_id, room_id, is_public): + return self._simple_upsert( table="group_rooms", - values={ + keyvalues={ "group_id": group_id, "room_id": room_id, + }, + values={ "is_public": is_public, }, - desc="add_room_to_group", + insertion_values={ + "group_id": group_id, + "room_id": room_id, + }, + desc="update_room_group_association", ) - def remove_room_from_group(self, group_id, room_id): - def _remove_room_from_group_txn(txn): + def delete_room_group_association(self, group_id, room_id): + def _delete_room_group_association_txn(txn): self._simple_delete_txn( txn, table="group_rooms", @@ -875,7 +883,7 @@ class GroupServerStore(SQLBaseStore): }, ) return self.runInteraction( - "remove_room_from_group", _remove_room_from_group_txn, + "delete_room_group_association", _delete_room_group_association_txn, ) def get_publicised_groups_for_user(self, user_id): @@ -1026,6 +1034,7 @@ class GroupServerStore(SQLBaseStore): "avatar_url": avatar_url, "short_description": short_description, "long_description": long_description, + "is_public": True, }, desc="create_group", ) @@ -1086,6 +1095,24 @@ class GroupServerStore(SQLBaseStore): desc="update_remote_attestion", ) + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self._simple_delete( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + desc="remove_attestation_renewal", + ) + @defer.inlineCallbacks def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 817c2185c8..d1691bbac2 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 45 +SCHEMA_VERSION = 46 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config): If `config` is None then prepare_database will assert that no upgrade is necessary, *or* will create a fresh database if the database is empty. + + Args: + db_conn: + database_engine: + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already """ try: cur = db_conn.cursor() @@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config): else: _setup_new_database(cur, database_engine) + # check if any of our configured dynamic modules want a database + if config is not None: + _apply_module_schemas(cur, database_engine, config) + cur.close() db_conn.commit() except Exception: @@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) +def _apply_module_schemas(txn, database_engine, config): + """Apply the module schemas for the dynamic modules, if any + + Args: + cur: database cursor + database_engine: synapse database engine class + config (synapse.config.homeserver.HomeServerConfig): + application config + """ + for (mod, _config) in config.password_providers: + if not hasattr(mod, 'get_db_schema_files'): + continue + modname = ".".join((mod.__module__, mod.__name__)) + _apply_module_schema_files( + txn, database_engine, modname, mod.get_db_schema_files(), + ) + + +def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): + """Apply the module schemas for a single module + + Args: + cur: database cursor + database_engine: synapse database engine class + modname (str): fully qualified name of the module + names_and_streams (Iterable[(str, file)]): the names and streams of + schemas to be applied + """ + cur.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_module_schemas WHERE module_name = ?" + ), + (modname,) + ) + applied_deltas = set(d for d, in cur) + for (name, stream) in names_and_streams: + if name in applied_deltas: + continue + + root_name, ext = os.path.splitext(name) + if ext != '.sql': + raise PrepareDatabaseException( + "only .sql files are currently supported for module schemas", + ) + + logger.info("applying schema %s for %s", name, modname) + for statement in get_statements(stream): + cur.execute(statement) + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_module_schemas (module_name, file)" + " VALUES (?,?)", + ), + (modname, name) + ) + + def get_statements(f): statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 20acd58fcf..9c4f61da76 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -36,12 +36,15 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id"], ) - self.register_background_index_update( - "refresh_tokens_device_index", - index_name="refresh_tokens_device_id", - table="refresh_tokens", - columns=["user_id", "device_id"], - ) + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. + @defer.inlineCallbacks + def noop_update(progress, batch_size): + yield self._end_background_update("refresh_tokens_device_index") + defer.returnValue(1) + self.register_background_update_handler( + "refresh_tokens_device_index", noop_update) @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): @@ -177,9 +180,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) if create_profile_with_localpart: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames txn.execute( - "INSERT INTO profiles(user_id) VALUES (?)", - (create_profile_with_localpart,) + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (create_profile_with_localpart, create_profile_with_localpart) ) self._invalidate_cache_and_stream( @@ -238,10 +243,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None, - delete_refresh_tokens=False): + device_id=None): """ - Invalidate access/refresh tokens belonging to a user + Invalidate access tokens belonging to a user Args: user_id (str): ID of user the tokens belong to @@ -250,10 +254,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): device_id (str|None): ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - delete_refresh_tokens (bool): True to delete refresh tokens as - well as access tokens. Returns: - defer.Deferred: + defer.Deferred[list[str, str|None]]: a list of the deleted tokens + and device IDs """ def f(txn): keyvalues = { @@ -262,13 +265,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - if delete_refresh_tokens: - self._simple_delete_txn( - txn, - table="refresh_tokens", - keyvalues=keyvalues, - ) - items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] @@ -277,14 +273,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values.append(except_token_id) txn.execute( - "SELECT token FROM access_tokens WHERE %s" % where_clause, + "SELECT token, device_id FROM access_tokens WHERE %s" % where_clause, values ) - rows = self.cursor_to_dict(txn) + tokens_and_devices = [(r[0], r[1]) for r in txn] - for row in rows: + for token, _ in tokens_and_devices: self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (row["token"],) + txn, self.get_user_by_access_token, (token,) ) txn.execute( @@ -292,6 +288,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values ) + return tokens_and_devices + yield self.runInteraction( "user_delete_access_tokens", f, ) diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql deleted file mode 100644 index 34db0cf12b..0000000000 --- a/synapse/storage/schema/delta/23/refresh_tokens.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS refresh_tokens( - id INTEGER PRIMARY KEY, - token TEXT NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (token) -); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql deleted file mode 100644 index bb225dafbf..0000000000 --- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('refresh_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql index 290bd6da86..68c48a89a9 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device.sql +++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,4 +13,5 @@ * limitations under the License. */ -ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..e754b554f8 --- /dev/null +++ b/synapse/storage/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index a7ade69986..42e5cb6df5 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas( file TEXT NOT NULL, UNIQUE(version, file) ); + +-- a list of schema files we have loaded on behalf of dynamic modules +CREATE TABLE IF NOT EXISTS applied_module_schemas( + module_name TEXT NOT NULL, + file TEXT NOT NULL, + UNIQUE(module_name, file) +); diff --git a/synapse/util/async.py b/synapse/util/async.py index 1a884e96ee..e786fb38a9 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -278,8 +278,13 @@ class Limiter(object): if entry[0] >= self.max_count: new_defer = defer.Deferred() entry[1].append(new_defer) + + logger.info("Waiting to acquire limiter lock for key %r", key) with PreserveLoggingContext(): yield new_defer + logger.info("Acquired limiter lock for key %r", key) + else: + logger.info("Acquired uncontended limiter lock for key %r", key) entry[0] += 1 @@ -288,16 +293,21 @@ class Limiter(object): try: yield finally: + logger.info("Releasing limiter lock for key %r", key) + # We've finished executing so check if there are any things # blocked waiting to execute and start one of them entry[0] -= 1 - try: - entry[1].pop(0).callback(None) - except IndexError: - # If nothing else is executing for this key then remove it - # from the map - if entry[0] == 0: - self.key_to_defer.pop(key, None) + + if entry[1]: + next_def = entry[1].pop(0) + + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] defer.returnValue(_ctx_manager()) |