diff options
26 files changed, 446 insertions, 70 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 2ec10516fd..f1d2c7a765 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,17 @@ +Changes in synapse v0.10.0-r2 (2015-09-16) +========================================== + +* Fix bug where we always fetched remote server signing keys instead of using + ones in our cache. +* Fix adding threepids to an existing account. +* Fix bug with invinting over federation where remote server was already in + the room. (PR #281, SYN-392) + +Changes in synapse v0.10.0-r1 (2015-09-08) +========================================== + +* Fix bug with python packaging + Changes in synapse v0.10.0 (2015-09-03) ======================================= diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py index b889ac7fa7..4fa8792a5f 100644 --- a/scripts-dev/check_auth.py +++ b/scripts-dev/check_auth.py @@ -56,10 +56,9 @@ if __name__ == '__main__': js = json.load(args.json) - auth = Auth(Mock()) check_auth( auth, [FrozenEvent(d) for d in js["auth_chain"]], - [FrozenEvent(d) for d in js["pdus"]], + [FrozenEvent(d) for d in js.get("pdus", [])], ) diff --git a/synapse/__init__.py b/synapse/__init__.py index d85bb3dce0..d62294e6bb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.10.0" +__version__ = "0.10.0-r2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 49a068afb1..847ff60671 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -23,6 +23,7 @@ from synapse.util.logutils import log_function from synapse.types import UserID, EventID import logging +import pymacaroons logger = logging.getLogger(__name__) @@ -40,6 +41,12 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 + self._KNOWN_CAVEAT_PREFIXES = set([ + "gen = ", + "type = ", + "time < ", + "user_id = ", + ]) def check(self, event, auth_events): """ Checks if this event is correctly authed. @@ -65,6 +72,14 @@ class Auth(object): # FIXME return True + creation_event = auth_events.get((EventTypes.Create, ""), None) + + if not creation_event: + raise SynapseError( + 403, + "Room %r does not exist" % (event.room_id,) + ) + # FIXME: Temp hack if event.type == EventTypes.Aliases: return True @@ -410,7 +425,7 @@ class Auth(object): except KeyError: pass # normal users won't have the user_id query parameter set. - user_info = yield self.get_user_by_access_token(access_token) + user_info = yield self._get_user_by_access_token(access_token) user = user_info["user"] token_id = user_info["token_id"] @@ -437,7 +452,7 @@ class Auth(object): ) @defer.inlineCallbacks - def get_user_by_access_token(self, token): + def _get_user_by_access_token(self, token): """ Get a registered user's ID. Args: @@ -447,6 +462,86 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ + try: + ret = yield self._get_user_from_macaroon(token) + except AuthError: + # TODO(daniel): Remove this fallback when all existing access tokens + # have been re-issued as macaroons. + ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) + + @defer.inlineCallbacks + def _get_user_from_macaroon(self, macaroon_str): + try: + macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) + self._validate_macaroon(macaroon) + + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) + # This codepath exists so that we can actually return a + # token ID, because we use token IDs in place of device + # identifiers throughout the codebase. + # TODO(daniel): Remove this fallback when device IDs are + # properly implemented. + ret = yield self._look_up_user_by_access_token(macaroon_str) + if ret["user"] != user: + logger.error( + "Macaroon user (%s) != DB user (%s)", + user, + ret["user"] + ) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "User mismatch in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + defer.returnValue(ret) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN + ) + + def _validate_macaroon(self, macaroon): + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = access") + v.satisfy_general(lambda c: c.startswith("user_id = ")) + v.satisfy_general(self._verify_expiry) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + v = pymacaroons.Verifier() + v.satisfy_general(self._verify_recognizes_caveats) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + def _verify_expiry(self, caveat): + prefix = "time < " + if not caveat.startswith(prefix): + return False + # TODO(daniel): Enable expiry check when clients actually know how to + # refresh tokens. (And remember to enable the tests) + return True + expiry = int(caveat[len(prefix):]) + now = self.hs.get_clock().time_msec() + return now < expiry + + def _verify_recognizes_caveats(self, caveat): + first_space = caveat.find(" ") + if first_space < 0: + return False + second_space = caveat.find(" ", first_space + 1) + if second_space < 0: + return False + return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES + + @defer.inlineCallbacks + def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) if not ret: raise AuthError( @@ -457,7 +552,6 @@ class Auth(object): "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } - defer.returnValue(user_info) @defer.inlineCallbacks diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index c23f853230..15c0a4a003 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -16,10 +16,23 @@ import sys sys.dont_write_bytecode = True -from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS +from synapse.python_dependencies import ( + check_requirements, DEPENDENCY_LINKS, MissingRequirementError +) if __name__ == '__main__': - check_requirements() + try: + check_requirements() + except MissingRequirementError as e: + message = "\n".join([ + "Missing Requirement: %s" % (e.message,), + "To install run:", + " pip install --upgrade --force \"%s\"" % (e.dependency,), + "", + ]) + sys.stderr.writelines(message) + sys.exit(1) + from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import ( @@ -221,7 +234,7 @@ class SynapseHomeServer(HomeServer): listener_config, root_resource, ), - self.tls_context_factory, + self.tls_server_context_factory, interface=bind_address ) else: @@ -365,7 +378,6 @@ def setup(config_options): Args: config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`. - should_run (bool): Whether to start the reactor. Returns: HomeServer @@ -388,7 +400,7 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - tls_context_factory = context_factory.ServerContextFactory(config) + tls_server_context_factory = context_factory.ServerContextFactory(config) database_engine = create_engine(config.database_config["name"]) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection @@ -396,7 +408,7 @@ def setup(config_options): hs = SynapseHomeServer( config.server_name, db_config=config.database_config, - tls_context_factory=tls_context_factory, + tls_server_context_factory=tls_server_context_factory, config=config, content_addr=config.content_addr, version_string=version_string, diff --git a/synapse/config/logger.py b/synapse/config/logger.py index fa542623b7..daca698d0c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -21,6 +21,7 @@ import logging.config import yaml from string import Template import os +import signal DEFAULT_LOG_CONFIG = Template(""" @@ -142,6 +143,19 @@ class LoggingConfig(Config): handler = logging.handlers.RotatingFileHandler( self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 ) + + def sighup(signum, stack): + logger.info("Closing log file due to SIGHUP") + handler.doRollover() + logger.info("Opened new log file due to SIGHUP") + + # TODO(paul): obviously this is a terrible mechanism for + # stealing SIGHUP, because it means no other part of synapse + # can use it instead. If we want to catch SIGHUP anywhere + # else as well, I'd suggest we find a nicer way to broadcast + # it around. + if getattr(signal, "SIGHUP"): + signal.signal(signal.SIGHUP, sighup) else: handler = logging.StreamHandler() handler.setFormatter(formatter) diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4751d39bc9..e6023a718d 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -42,6 +42,14 @@ class TlsConfig(Config): config.get("tls_dh_params_path"), "tls_dh_params" ) + # This config option applies to non-federation HTTP clients + # (e.g. for talking to recaptcha, identity servers, and such) + # It should never be used in production, and is intended for + # use only when running tests. + self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get( + "use_insecure_ssl_client_just_for_testing_do_not_use" + ) + def default_config(self, config_dir_path, server_name): base_key_name = os.path.join(config_dir_path, server_name) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index a692cdbe55..1b1b31c5c0 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -162,7 +162,9 @@ class Keyring(object): def remove_deferreds(res, server_name, group_id): server_to_gids[server_name].discard(group_id) if not server_to_gids[server_name]: - server_to_deferred.pop(server_name).callback(None) + d = server_to_deferred.pop(server_name, None) + if d: + d.callback(None) return res for g_id, deferred in deferreds.items(): @@ -200,8 +202,15 @@ class Keyring(object): else: break - for server_name, deferred in server_to_deferred: - self.key_downloads[server_name] = ObservableDeferred(deferred) + for server_name, deferred in server_to_deferred.items(): + d = ObservableDeferred(deferred) + self.key_downloads[server_name] = d + + def rm(r, server_name): + self.key_downloads.pop(server_name, None) + return r + + d.addBoth(rm, server_name) def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): """Takes a dict of KeyGroups and tries to find at least one key for @@ -220,9 +229,8 @@ class Keyring(object): merged_results = {} missing_keys = { - group.server_name: key_id + group.server_name: set(group.key_ids) for group in group_id_to_group.values() - for key_id in group.key_ids } for fn in key_fetch_fns: @@ -279,16 +287,15 @@ class Keyring(object): def get_keys_from_store(self, server_name_and_key_ids): res = yield defer.gatherResults( [ - self.store.get_server_verify_keys(server_name, key_ids) + self.store.get_server_verify_keys( + server_name, key_ids + ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, ).addErrback(unwrapFirstError) - defer.returnValue(dict(zip( - [server_name for server_name, _ in server_name_and_key_ids], - res - ))) + defer.returnValue(dict(res)) @defer.inlineCallbacks def get_keys_from_perspectives(self, server_name_and_key_ids): @@ -463,7 +470,7 @@ class Keyring(object): continue (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_context_factory, + server_name, self.hs.tls_server_context_factory, path=(b"/_matrix/key/v2/server/%s" % ( urllib.quote(requested_key_id), )).encode("ascii"), @@ -597,7 +604,7 @@ class Keyring(object): # Try to fetch the key from the remote server. (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_context_factory + server_name, self.hs.tls_server_context_factory ) # Check the response. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 59f687e0f1..793b3fcd8b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -19,7 +19,6 @@ from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import LoginError, Codes -from synapse.http.client import SimpleHttpClient from synapse.util.async import run_on_reactor from twisted.web.client import PartialDownloadError @@ -187,7 +186,7 @@ class AuthHandler(BaseHandler): # TODO: get this from the homeserver rather than creating a new one for # each request try: - client = SimpleHttpClient(self.hs) + client = self.hs.get_simple_http_client() resp_body = yield client.post_urlencoded_get_json( self.hs.config.recaptcha_siteverify_api, args={ diff --git a/synapse/http/client.py b/synapse/http/client.py index 4b8fd3d3a3..0933388c04 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from OpenSSL import SSL +from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import CodeMessageException from synapse.util.logcontext import preserve_context_over_fn @@ -19,7 +21,7 @@ import synapse.metrics from canonicaljson import encode_canonical_json -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor, ssl from twisted.web.client import ( Agent, readBody, FileBodyProducer, PartialDownloadError, HTTPConnectionPool, @@ -59,7 +61,12 @@ class SimpleHttpClient(object): # 'like a browser' pool = HTTPConnectionPool(reactor) pool.maxPersistentPerHost = 10 - self.agent = Agent(reactor, pool=pool) + self.agent = Agent( + reactor, + pool=pool, + connectTimeout=15, + contextFactory=hs.get_http_client_context_factory() + ) self.version_string = hs.version_string def request(self, method, uri, *args, **kwargs): @@ -252,3 +259,18 @@ def _print_ex(e): _print_ex(ex) else: logger.exception(e) + + +class InsecureInterceptableContextFactory(ssl.ContextFactory): + """ + Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain. + + Do not use this since it allows an attacker to intercept your communications. + """ + + def __init__(self): + self._context = SSL.Context(SSL.SSLv23_METHOD) + self._context.set_verify(VERIFY_NONE, lambda *_: None) + + def getContext(self, hostname, port): + return self._context diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 1c9e552788..b50a0c445c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -57,14 +57,14 @@ incoming_responses_counter = metrics.register_counter( class MatrixFederationEndpointFactory(object): def __init__(self, hs): - self.tls_context_factory = hs.tls_context_factory + self.tls_server_context_factory = hs.tls_server_context_factory def endpointForURI(self, uri): destination = uri.netloc return matrix_federation_endpoint( reactor, destination, timeout=10, - ssl_context_factory=self.tls_context_factory + ssl_context_factory=self.tls_server_context_factory ) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index d7bcad8a8a..943d637459 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -17,7 +17,7 @@ from __future__ import absolute_import import logging -from resource import getrusage, getpagesize, RUSAGE_SELF +from resource import getrusage, RUSAGE_SELF import functools import os import stat @@ -100,7 +100,6 @@ def render_all(): # process resource usage rusage = None -PAGE_SIZE = getpagesize() def update_resource_metrics(): @@ -113,8 +112,8 @@ resource_metrics = get_metrics_for("process.resource") resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000) resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000) -# pages -resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE) +# kilobytes +resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024) TYPES = { stat.S_IFSOCK: "SOCK", @@ -131,6 +130,10 @@ def _process_fds(): counts = {(k,): 0 for k in TYPES.values()} counts[("other",)] = 0 + # Not every OS will have a /proc/self/fd directory + if not os.path.exists("/proc/self/fd"): + return counts + for fd in os.listdir("/proc/self/fd"): try: s = os.stat("/proc/self/fd/%s" % (fd)) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 795ef27182..e95316720e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -18,18 +18,18 @@ from distutils.version import LooseVersion logger = logging.getLogger(__name__) REQUIREMENTS = { + "frozendict>=0.4": ["frozendict"], "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], - "Twisted>=15.1.0": ["twisted>=15.1.0"], + "pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], + "Twisted>=15.1.0": ["twisted>=15.1.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], - "pynacl>=0.3.0": ["nacl>=0.3.0"], "daemonize": ["daemonize"], "py-bcrypt": ["bcrypt"], - "frozendict>=0.4": ["frozendict"], "pillow": ["PIL"], "pydenticon": ["pydenticon"], "ujson": ["ujson"], @@ -60,7 +60,10 @@ DEPENDENCY_LINKS = { class MissingRequirementError(Exception): - pass + def __init__(self, message, module_name, dependency): + super(MissingRequirementError, self).__init__(message) + self.module_name = module_name + self.dependency = dependency def check_requirements(config=None): @@ -88,7 +91,7 @@ def check_requirements(config=None): ) raise MissingRequirementError( "Can't import %r which is part of %r" - % (module_name, dependency) + % (module_name, dependency), module_name, dependency ) version = getattr(module, "__version__", None) file_path = getattr(module, "__file__", None) @@ -101,23 +104,25 @@ def check_requirements(config=None): if version is None: raise MissingRequirementError( "Version of %r isn't set as __version__ of module %r" - % (dependency, module_name) + % (dependency, module_name), module_name, dependency ) if LooseVersion(version) < LooseVersion(required_version): raise MissingRequirementError( "Version of %r in %r is too old. %r < %r" - % (dependency, file_path, version, required_version) + % (dependency, file_path, version, required_version), + module_name, dependency ) elif version_test == "==": if version is None: raise MissingRequirementError( "Version of %r isn't set as __version__ of module %r" - % (dependency, module_name) + % (dependency, module_name), module_name, dependency ) if LooseVersion(version) != LooseVersion(required_version): raise MissingRequirementError( "Unexpected version of %r in %r. %r != %r" - % (dependency, file_path, version, required_version) + % (dependency, file_path, version, required_version), + module_name, dependency ) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index b5edffdb60..4692ba413c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -96,6 +96,7 @@ class ThreepidRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() + self.auth_handler = hs.get_handlers().auth_handler @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/server.py b/synapse/server.py index 4d1fb1cbf6..8424798b1b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -19,7 +19,9 @@ # partial one for unit test mocking. # Imports required for the default HomeServer() implementation +from twisted.web.client import BrowserLikePolicyForHTTPS from synapse.federation import initialize_http_replication +from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers @@ -87,6 +89,8 @@ class BaseHomeServer(object): 'pusherpool', 'event_builder_factory', 'filtering', + 'http_client_context_factory', + 'simple_http_client', ] def __init__(self, hostname, **kwargs): @@ -174,6 +178,17 @@ class HomeServer(BaseHomeServer): def build_auth(self): return Auth(self) + def build_http_client_context_factory(self): + config = self.get_config() + return ( + InsecureInterceptableContextFactory() + if config.use_insecure_ssl_client_just_for_testing_do_not_use + else BrowserLikePolicyForHTTPS() + ) + + def build_simple_http_client(self): + return SimpleHttpClient(self) + def build_v1auth(self): orf = Auth(self) # Matrix spec makes no reference to what HTTP status code is returned, diff --git a/synapse/state.py b/synapse/state.py index 1fe4d066bd..ed36f844cb 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.util.logutils import log_function -from synapse.util.async import run_on_reactor from synapse.util.caches.expiringcache import ExpiringCache from synapse.api.constants import EventTypes from synapse.api.errors import AuthError @@ -119,8 +118,6 @@ class StateHandler(object): Returns: an EventContext """ - yield run_on_reactor() - context = EventContext() if outlier: diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 989ad340b0..c1cabbaa60 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -303,6 +303,15 @@ class EventFederationStore(SQLBaseStore): ], ) + self._update_extremeties(txn, events) + + def _update_extremeties(self, txn, events): + """Updates the event_*_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are are now being persisted as non-outliers. + """ events_by_room = {} for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index fba837f461..0a477e3122 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -281,6 +281,8 @@ class EventsStore(SQLBaseStore): (False, event.event_id,) ) + self._update_extremeties(txn, [event]) + events_and_contexts = filter( lambda ec: ec[0] not in to_remove, events_and_contexts diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/schema/delta/23/drop_state_index.sql new file mode 100644 index 0000000000..07d0ea5cb2 --- /dev/null +++ b/synapse/storage/schema/delta/23/drop_state_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 22fc804331..c96273480d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -19,17 +19,21 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.errors import AuthError +from synapse.types import UserID +from tests.utils import setup_test_homeserver + +import pymacaroons class AuthTestCase(unittest.TestCase): + @defer.inlineCallbacks def setUp(self): self.state_handler = Mock() self.store = Mock() - self.hs = Mock() + self.hs = yield setup_test_homeserver(handlers=None) self.hs.get_datastore = Mock(return_value=self.store) - self.hs.get_state_handler = Mock(return_value=self.state_handler) self.auth = Auth(self.hs) self.test_user = "@foo:bar" @@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = Mock(return_value=[""]) d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) + + @defer.inlineCallbacks + def test_get_user_from_macaroon(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_user_db_mismatch(self): + self.store.get_user_by_access_token = Mock( + return_value={"name": "@percy:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("User mismatch", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_missing_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("No user caveat", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_wrong_key(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key + "wrong") + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_unknown_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("cunning > fox") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_expired(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("time < 1") # ms + + self.hs.clock.now = 5000 # seconds + + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # TODO(daniel): Turn on the check that we validate expiration, when we + # validate expiration (and remove the above line, which will start + # throwing). + # with self.assertRaises(AuthError) as cm: + # yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # self.assertEqual(401, cm.exception.code) + # self.assertIn("Invalid macaroon", cm.exception.msg) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 91547bdd06..2ee3da0b34 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -76,7 +76,7 @@ class PresenceStateTestCase(unittest.TestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token presence.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ed0ac8d5c8..a2123be81b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 1c4519406d..6395ce79db 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c472d53043..85096a0326 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase): self.mock_resource = None self.auth_user_id = None - def mock_get_user_by_access_token(self, token=None): - return self.auth_user_id - @defer.inlineCallbacks def create_room_as(self, room_creator, is_public=True, tok=None): temp_id = self.auth_user_id diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index ef972a53aa..f45570a1c0 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase): "user": UserID.from_string(self.USER_ID), "token_id": 1, } - hs.get_auth().get_user_by_access_token = _get_user_by_access_token + hs.get_auth()._get_user_by_access_token = _get_user_by_access_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) diff --git a/tests/test_state.py b/tests/test_state.py index 5845358754..55f37c521f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase): nodes={ "START": DictObj( type=EventTypes.Create, - state_key="creator", - content={"membership": "@user_id:example.com"}, + state_key="", + content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( @@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase): nodes={ "START": DictObj( type=EventTypes.Create, - state_key="creator", - content={"membership": "@user_id:example.com"}, + state_key="", + content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( @@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase): def test_resolve_message_conflict(self): event = create_event(type="test_message", name="event") + creation = create_event( + type=EventTypes.Create, state_key="" + ) + old_state_1 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), @@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 5) + self.assertEqual(len(context.current_state), 6) self.assertIsNone(context.state_group) @@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase): def test_resolve_state_conflict(self): event = create_event(type="test4", state_key="", name="event") + creation = create_event( + type=EventTypes.Create, state_key="" + ) + old_state_1 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), @@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 5) + self.assertEqual(len(context.current_state), 6) self.assertIsNone(context.state_group) @@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase): } ) + creation = create_event( + type=EventTypes.Create, state_key="", + content={"creator": "@foo:bar"} + ) + old_state_1 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=2), ] context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_2[1], context.current_state[("test1", "1")]) + self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) # Reverse the depth to make sure we are actually using the depths # during state resolution. old_state_1 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=1), ] context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_1[1], context.current_state[("test1", "1")]) + self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" |