diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index c96273480d..70d928defe 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, _) = yield self.auth.get_user_by_req(request)
+ (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
@@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, _) = yield self.auth.get_user_by_req(request)
+ (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self):
@@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, _) = yield self.auth.get_user_by_req(request)
+ (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
@@ -159,6 +159,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
+ def test_get_guest_user_from_macaroon(self):
+ user_id = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ macaroon.add_first_party_caveat("guest = true")
+ serialized = macaroon.serialize()
+
+ user_info = yield self.auth._get_user_from_macaroon(serialized)
+ user = user_info["user"]
+ is_guest = user_info["is_guest"]
+ self.assertEqual(UserID.from_string(user_id), user)
+ self.assertTrue(is_guest)
+
+ @defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self):
self.store.get_user_by_access_token = Mock(
return_value={"name": "@percy:matrix.org"}
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 0e3b922246..3e0f294630 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -86,10 +86,11 @@ class PresenceStateTestCase(unittest.TestCase):
return defer.succeed([])
self.datastore.get_presence_list = get_presence_list
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(myid),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@@ -173,10 +174,11 @@ class PresenceListTestCase(unittest.TestCase):
)
self.datastore.has_presence_state = has_presence_state
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(myid),
"token_id": 1,
+ "is_guest": False,
}
hs.handlers.room_member_handler = Mock(
@@ -291,8 +293,8 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000
- def _get_user_by_req(req=None):
- return (UserID.from_string(myid), "")
+ def _get_user_by_req(req=None, allow_guest=False):
+ return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 929e5e5dd4..adcc1d1969 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase):
replication_layer=Mock(),
)
- def _get_user_by_req(request=None):
- return (UserID.from_string(myid), "")
+ def _get_user_by_req(request=None, allow_guest=False):
+ return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 93896dd076..b43563fa4b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase):
self.auth_user_id = self.user_id
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 6395ce79db..8433585616 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
+ "is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index f45570a1c0..fa9e17ec4f 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase):
resource_for_federation=self.mock_resource,
)
- def _get_user_by_access_token(token=None):
+ def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
+ "is_guest": False,
}
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b57006fcb4..dbf9700e6a 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(),
start,
end,
- None, # Is currently ignored
)
self.assertEqual(1, len(results))
@@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(),
start,
end,
- None, # Is currently ignored
)
self.assertEqual(1, len(results))
@@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(),
start,
end,
- None, # Is currently ignored
)
self.assertEqual(1, len(results))
@@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(),
start,
end,
- None, # Is currently ignored
)
self.assertEqual(1, len(results))
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index a658a789aa..e5c2c5cc8e 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(),
start,
end,
- None, # Is currently ignored
)
self.assertEqual(1, len(results))
@@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_alice.to_string(),
start,
end,
- None, # Is currently ignored
)
self.assertEqual(1, len(results))
@@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(),
start,
end,
- None, # Is currently ignored
)
# We should not get the message, as it happened *after* bob left.
@@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(),
start,
end,
- None, # Is currently ignored
)
# We should not get the message, as it happened *after* bob left.
diff --git a/tests/utils.py b/tests/utils.py
index 4da51291a4..ca2c33cf8e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -335,7 +335,7 @@ class MemoryDataStore(object):
]
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
- room_id=None, limit=0, with_feedback=False):
+ limit=0, with_feedback=False):
return ([], from_key) # TODO
def get_joined_hosts_for_room(self, room_id):
|