diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 22fc804331..c96273480d 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,17 +19,21 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
+from synapse.types import UserID
+from tests.utils import setup_test_homeserver
+
+import pymacaroons
class AuthTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
def setUp(self):
self.state_handler = Mock()
self.store = Mock()
- self.hs = Mock()
+ self.hs = yield setup_test_homeserver(handlers=None)
self.hs.get_datastore = Mock(return_value=self.store)
- self.hs.get_state_handler = Mock(return_value=self.state_handler)
self.auth = Auth(self.hs)
self.test_user = "@foo:bar"
@@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user_id = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ user = user_info["user"]
+ self.assertEqual(UserID.from_string(user_id), user)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_user_db_mismatch(self):
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@percy:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("User mismatch", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_missing_caveat(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("No user caveat", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_wrong_key(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key + "wrong")
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_unknown_caveat(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ macaroon.add_first_party_caveat("cunning > fox")
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_expired(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ macaroon.add_first_party_caveat("time < 1") # ms
+
+ self.hs.clock.now = 5000 # seconds
+
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ # TODO(daniel): Turn on the check that we validate expiration, when we
+ # validate expiration (and remove the above line, which will start
+ # throwing).
+ # with self.assertRaises(AuthError) as cm:
+ # yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ # self.assertEqual(401, cm.exception.code)
+ # self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 91547bdd06..2ee3da0b34 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -76,7 +76,7 @@ class PresenceStateTestCase(unittest.TestCase):
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
@@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase):
]
)
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 34ab47d02e..9fb2bfb315 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 1c4519406d..6395ce79db 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase):
"token_id": 1,
}
- hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c472d53043..85096a0326 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase):
self.mock_resource = None
self.auth_user_id = None
- def mock_get_user_by_access_token(self, token=None):
- return self.auth_user_id
-
@defer.inlineCallbacks
def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index ef972a53aa..f45570a1c0 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
}
- hs.get_auth().get_user_by_access_token = _get_user_by_access_token
+ hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)
|