diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 65ee1452ce..f8ea1e2c69 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,6 +23,7 @@ from synapse.util.logutils import log_function
from synapse.types import UserID
import logging
+import pymacaroons
logger = logging.getLogger(__name__)
@@ -40,6 +41,12 @@ class Auth(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+ self._KNOWN_CAVEAT_PREFIXES = set([
+ "gen = ",
+ "type = ",
+ "time < ",
+ "user_id = ",
+ ])
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
@@ -359,8 +366,8 @@ class Auth(object):
except KeyError:
pass # normal users won't have the user_id query parameter set.
- user_info = yield self.get_user_by_access_token(access_token)
- user = user_info["user"]
+ user_info = yield self._get_user_by_access_token(access_token)
+ user_id = user_info["user_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request)
@@ -368,17 +375,17 @@ class Auth(object):
"User-Agent",
default=[""]
)[0]
- if user and access_token and ip_addr:
+ if user_id and access_token and ip_addr:
self.store.insert_client_ip(
- user=user,
+ user=user_id,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent
)
- request.authenticated_entity = user.to_string()
+ request.authenticated_entity = user_id.to_string()
- defer.returnValue((user, token_id,))
+ defer.returnValue((user_id, token_id,))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -386,7 +393,7 @@ class Auth(object):
)
@defer.inlineCallbacks
- def get_user_by_access_token(self, token):
+ def _get_user_by_access_token(self, token):
""" Get a registered user's ID.
Args:
@@ -396,6 +403,86 @@ class Auth(object):
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
+ try:
+ ret = yield self._get_user_from_macaroon(token)
+ except AuthError:
+ # TODO(daniel): Remove this fallback when all existing access tokens
+ # have been re-issued as macaroons.
+ ret = yield self._look_up_user_by_access_token(token)
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def _get_user_from_macaroon(self, macaroon_str):
+ try:
+ macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
+ self._validate_macaroon(macaroon)
+
+ user_prefix = "user_id = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(user_prefix):
+ user_id = UserID.from_string(caveat.caveat_id[len(user_prefix):])
+ # This codepath exists so that we can actually return a
+ # token ID, because we use token IDs in place of device
+ # identifiers throughout the codebase.
+ # TODO(daniel): Remove this fallback when device IDs are
+ # properly implemented.
+ ret = yield self._look_up_user_by_access_token(macaroon_str)
+ if ret["user_id"] != user_id:
+ logger.error(
+ "Macaroon user (%s) != DB user (%s)",
+ user_id,
+ ret["user_id"]
+ )
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "User mismatch in macaroon",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+ defer.returnValue(ret)
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+ except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
+ def _validate_macaroon(self, macaroon):
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = access")
+ v.satisfy_general(lambda c: c.startswith("user_id = "))
+ v.satisfy_general(self._verify_expiry)
+ v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+ v = pymacaroons.Verifier()
+ v.satisfy_general(self._verify_recognizes_caveats)
+ v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+ def _verify_expiry(self, caveat):
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ # TODO(daniel): Enable expiry check when clients actually know how to
+ # refresh tokens. (And remember to enable the tests)
+ return True
+ expiry = int(caveat[len(prefix):])
+ now = self.hs.get_clock().time_msec()
+ return now < expiry
+
+ def _verify_recognizes_caveats(self, caveat):
+ first_space = caveat.find(" ")
+ if first_space < 0:
+ return False
+ second_space = caveat.find(" ", first_space + 1)
+ if second_space < 0:
+ return False
+ return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
+
+ @defer.inlineCallbacks
+ def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
raise AuthError(
@@ -403,10 +490,9 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
user_info = {
- "user": UserID.from_string(ret.get("name")),
+ "user_id": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
-
defer.returnValue(user_info)
@defer.inlineCallbacks
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 22fc804331..1ba85d6f83 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -14,22 +14,27 @@
# limitations under the License.
from tests import unittest
from twisted.internet import defer
+from twisted.trial.unittest import FailTest
from mock import Mock
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
+from synapse.types import UserID
+from tests.utils import setup_test_homeserver
+
+import pymacaroons
class AuthTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
def setUp(self):
self.state_handler = Mock()
self.store = Mock()
- self.hs = Mock()
+ self.hs = yield setup_test_homeserver(handlers=None)
self.hs.get_datastore = Mock(return_value=self.store)
- self.hs.get_state_handler = Mock(return_value=self.state_handler)
self.auth = Auth(self.hs)
self.test_user = "@foo:bar"
@@ -133,3 +138,136 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ user_id = user_info["user_id"]
+ self.assertEqual(UserID.from_string(user), user_id)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_user_db_mismatch(self):
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@percy:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("User mismatch", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_missing_caveat(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("No user caveat", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_wrong_key(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key + "wrong")
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_unknown_caveat(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ macaroon.add_first_party_caveat("cunning > fox")
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_expired(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ self.todo = (FailTest, "Token expiry isn't currently enabled",)
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ macaroon.add_first_party_caveat("time < 1") # ms
+
+ self.hs.clock.now = 5000 # seconds
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 91547bdd06..d8d1416f59 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -72,11 +72,11 @@ class PresenceStateTestCase(unittest.TestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(myid),
+ "user_id": UserID.from_string(myid),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
@@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(myid),
+ "user_id": UserID.from_string(myid),
"token_id": 1,
}
@@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase):
]
)
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 34ab47d02e..be1d52f720 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -56,10 +56,10 @@ class RoomPermissionsTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -441,10 +441,10 @@ class RoomsMemberListTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -519,10 +519,10 @@ class RoomsCreateTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -610,11 +610,11 @@ class RoomTopicTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -715,10 +715,10 @@ class RoomMemberStateTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -840,10 +840,10 @@ class RoomMessagesTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -935,10 +935,10 @@ class RoomInitialSyncTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 1c4519406d..da6fc975f7 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -63,11 +63,11 @@ class RoomTypingTestCase(RestTestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user_id": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c472d53043..85096a0326 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase):
self.mock_resource = None
self.auth_user_id = None
- def mock_get_user_by_access_token(self, token=None):
- return self.auth_user_id
-
@defer.inlineCallbacks
def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index ef972a53aa..7d0f77a3ee 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -45,10 +45,10 @@ class V2AlphaRestTestCase(unittest.TestCase):
def _get_user_by_access_token(token=None):
return {
- "user": UserID.from_string(self.USER_ID),
+ "user_id": UserID.from_string(self.USER_ID),
"token_id": 1,
}
- hs.get_auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)
|