diff options
-rw-r--r-- | synapse/api/auth.py | 29 | ||||
-rw-r--r-- | synapse/metrics/metric.py | 3 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 54 | ||||
-rw-r--r-- | synapse/storage/client_ips.py | 3 | ||||
-rw-r--r-- | tests/api/test_auth.py | 10 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 13 |
6 files changed, 86 insertions, 26 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ff7d816cfc..eca8513905 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -586,6 +586,10 @@ class Auth(object): token_id = user_info["token_id"] is_guest = user_info["is_guest"] + # device_id may not be present if get_user_by_access_token has been + # stubbed out. + device_id = user_info.get("device_id") + ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( "User-Agent", @@ -597,7 +601,8 @@ class Auth(object): user=user, access_token=access_token, ip=ip_addr, - user_agent=user_agent + user_agent=user_agent, + device_id=device_id, ) if is_guest and not allow_guest: @@ -695,6 +700,7 @@ class Auth(object): "user": user, "is_guest": True, "token_id": None, + "device_id": None, } elif rights == "delete_pusher": # We don't store these tokens in the database @@ -702,13 +708,20 @@ class Auth(object): "user": user, "is_guest": False, "token_id": None, + "device_id": None, } else: - # This codepath exists so that we can actually return a - # token ID, because we use token IDs in place of device - # identifiers throughout the codebase. - # TODO(daniel): Remove this fallback when device IDs are - # properly implemented. + # This codepath exists for several reasons: + # * so that we can actually return a token ID, which is used + # in some parts of the schema (where we probably ought to + # use device IDs instead) + # * the only way we currently have to invalidate an + # access_token is by removing it from the database, so we + # have to check here that it is still in the db + # * some attributes (notably device_id) aren't stored in the + # macaroon. They probably should be. + # TODO: build the dictionary from the macaroon once the + # above are fixed ret = yield self._look_up_user_by_access_token(macaroon_str) if ret["user"] != user: logger.error( @@ -782,10 +795,14 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", errcode=Codes.UNKNOWN_TOKEN ) + # we use ret.get() below because *lots* of unit tests stub out + # get_user_by_access_token in a way where it only returns a couple of + # the fields. user_info = { "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "device_id": ret.get("device_id"), } defer.returnValue(user_info) diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index d100841a7f..7becbe0491 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -180,6 +180,9 @@ class MemoryUsageMetric(object): self.memory_snapshots[:] = self.memory_snapshots[-max_size:] def render(self): + if not self.memory_snapshots: + return [] + max_rss = max(self.memory_snapshots) min_rss = min(self.memory_snapshots) sum_rss = sum(self.memory_snapshots) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b7e03ea9d1..d401722224 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -93,6 +93,7 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -145,7 +146,7 @@ class RegisterRestServlet(RestServlet): if isinstance(desired_username, basestring): result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0] + desired_username, request.args["access_token"][0], body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -155,7 +156,7 @@ class RegisterRestServlet(RestServlet): # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( - desired_username, desired_password, body["mac"] + desired_username, desired_password, body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -236,7 +237,7 @@ class RegisterRestServlet(RestServlet): add_email = True result = yield self._create_registration_details( - registered_user_id + registered_user_id, body ) if add_email and result and LoginType.EMAIL_IDENTITY in result: @@ -252,14 +253,14 @@ class RegisterRestServlet(RestServlet): return 200, {} @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token): + def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id))) + defer.returnValue((yield self._create_registration_details(user_id, body))) @defer.inlineCallbacks - def _do_shared_secret_registration(self, username, password, mac): + def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -267,7 +268,7 @@ class RegisterRestServlet(RestServlet): # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface - got_mac = str(mac) + got_mac = str(body["mac"]) want_mac = hmac.new( key=self.hs.config.registration_shared_secret, @@ -284,7 +285,7 @@ class RegisterRestServlet(RestServlet): localpart=username, password=password, generate_token=False, ) - result = yield self._create_registration_details(user_id) + result = yield self._create_registration_details(user_id, body) defer.returnValue(result) @defer.inlineCallbacks @@ -358,35 +359,58 @@ class RegisterRestServlet(RestServlet): defer.returnValue() @defer.inlineCallbacks - def _create_registration_details(self, user_id): + def _create_registration_details(self, user_id, body): """Complete registration of newly-registered user - Issues access_token and refresh_token, and builds the success response - body. + Allocates device_id if one was not given; also creates access_token + and refresh_token. Args: (str) user_id: full canonical @user:id - + (object) body: dictionary supplied to /register call, from + which we pull device_id and initial_device_name Returns: defer.Deferred: (object) dictionary for response from /register """ + device_id = yield self._register_device(user_id, body) access_token = yield self.auth_handler.issue_access_token( - user_id + user_id, device_id=device_id ) refresh_token = yield self.auth_handler.issue_refresh_token( - user_id + user_id, device_id=device_id ) - defer.returnValue({ "user_id": user_id, "access_token": access_token, "home_server": self.hs.hostname, "refresh_token": refresh_token, + "device_id": device_id, }) + def _register_device(self, user_id, body): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) body: dictionary supplied to /register call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + # register the user's device + device_id = body.get("device_id") + initial_display_name = body.get("initial_device_display_name") + device_id = self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + return device_id + @defer.inlineCallbacks def _do_guest_registration(self): if not self.hs.config.allow_guest_access: diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 07161496ca..365f08650d 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -38,7 +38,7 @@ class ClientIpStore(SQLBaseStore): super(ClientIpStore, self).__init__(hs) @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user.to_string(), access_token, ip) @@ -62,6 +62,7 @@ class ClientIpStore(SQLBaseStore): "access_token": access_token, "ip": ip, "user_agent": user_agent, + "device_id": device_id, }, values={ "last_seen": now, diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 960c23d631..e91723ca3d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -45,6 +45,7 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", + "device_id": "device", } self.store.get_user_by_access_token = Mock(return_value=user_info) @@ -143,7 +144,10 @@ class AuthTestCase(unittest.TestCase): # TODO(danielwh): Remove this mock when we remove the # get_user_by_access_token fallback. self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org"} + return_value={ + "name": "@baldrick:matrix.org", + "device_id": "device", + } ) user_id = "@baldrick:matrix.org" @@ -158,6 +162,10 @@ class AuthTestCase(unittest.TestCase): user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + # TODO: device_id should come from the macaroon, but currently comes + # from the db. + self.assertEqual(user_info["device_id"], "device") + @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): user_id = "@baldrick:matrix.org" diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index ccbb8776d3..3bd7065e32 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -30,6 +30,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler = Mock() self.identity_handler = Mock() self.login_handler = Mock() + self.device_handler = Mock() # do the dance to hook it up to the hs global self.handlers = Mock( @@ -42,6 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) + self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.config.enable_registration = True # init the thing we're testing @@ -107,9 +109,11 @@ class RegisterRestServletTestCase(unittest.TestCase): def test_POST_user_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" + device_id = "frogfone" self.request_data = json.dumps({ "username": "kermit", - "password": "monkey" + "password": "monkey", + "device_id": device_id, }) self.registration_handler.check_username = Mock(return_value=True) self.auth_result = (True, None, { @@ -118,18 +122,21 @@ class RegisterRestServletTestCase(unittest.TestCase): }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) self.auth_handler.issue_access_token = Mock(return_value=token) + self.device_handler.check_device_registered = \ + Mock(return_value=device_id) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, + "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) self.auth_handler.issue_access_token.assert_called_once_with( - user_id) + user_id, device_id=device_id) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False |