diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4f83db5e84..70d928defe 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,17 +19,21 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
+from synapse.types import UserID
+from tests.utils import setup_test_homeserver
+
+import pymacaroons
class AuthTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
def setUp(self):
self.state_handler = Mock()
self.store = Mock()
- self.hs = Mock()
+ self.hs = yield setup_test_homeserver(handlers=None)
self.hs.get_datastore = Mock(return_value=self.store)
- self.hs.get_state_handler = Mock(return_value=self.state_handler)
self.auth = Auth(self.hs)
self.test_user = "@foo:bar"
@@ -40,21 +44,19 @@ class AuthTestCase(unittest.TestCase):
self.store.get_app_service_by_token = Mock(return_value=None)
user_info = {
"name": self.test_user,
- "device_id": "nothing",
"token_id": "ditto",
- "admin": False
}
- self.store.get_user_by_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, info) = yield self.auth.get_user_by_req(request)
+ (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -66,11 +68,9 @@ class AuthTestCase(unittest.TestCase):
self.store.get_app_service_by_token = Mock(return_value=None)
user_info = {
"name": self.test_user,
- "device_id": "nothing",
"token_id": "ditto",
- "admin": False
}
- self.store.get_user_by_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -81,17 +81,17 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, info) = yield self.auth.get_user_by_req(request)
+ (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -115,13 +115,13 @@ class AuthTestCase(unittest.TestCase):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, info) = yield self.auth.get_user_by_req(request)
+ (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
@@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -137,3 +137,159 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user_id = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ user = user_info["user"]
+ self.assertEqual(UserID.from_string(user_id), user)
+
+ @defer.inlineCallbacks
+ def test_get_guest_user_from_macaroon(self):
+ user_id = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ macaroon.add_first_party_caveat("guest = true")
+ serialized = macaroon.serialize()
+
+ user_info = yield self.auth._get_user_from_macaroon(serialized)
+ user = user_info["user"]
+ is_guest = user_info["is_guest"]
+ self.assertEqual(UserID.from_string(user_id), user)
+ self.assertTrue(is_guest)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_user_db_mismatch(self):
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@percy:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("User mismatch", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_missing_caveat(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("No user caveat", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_wrong_key(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key + "wrong")
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_unknown_caveat(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ macaroon.add_first_party_caveat("cunning > fox")
+
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_expired(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user,))
+ macaroon.add_first_party_caveat("time < 1") # ms
+
+ self.hs.clock.now = 5000 # seconds
+
+ yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ # TODO(daniel): Turn on the check that we validate expiration, when we
+ # validate expiration (and remove the above line, which will start
+ # throwing).
+ # with self.assertRaises(AuthError) as cm:
+ # yield self.auth._get_user_from_macaroon(macaroon.serialize())
+ # self.assertEqual(401, cm.exception.code)
+ # self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 65b2f590c8..9f9af2d783 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -23,10 +23,17 @@ from tests.utils import (
)
from synapse.types import UserID
-from synapse.api.filtering import Filter
+from synapse.api.filtering import FilterCollection, Filter
user_localpart = "test_user"
-MockEvent = namedtuple("MockEvent", "sender type room_id")
+# MockEvent = namedtuple("MockEvent", "sender type room_id")
+
+
+def MockEvent(**kwargs):
+ ev = NonCallableMock(spec_set=kwargs.keys())
+ ev.configure_mock(**kwargs)
+ return ev
+
class FilteringTestCase(unittest.TestCase):
@@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase):
)
self.filtering = hs.get_filtering()
- self.filter = Filter({})
self.datastore = hs.get_datastore()
@@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase):
type="m.room.message",
room_id="!foo:bar"
)
+
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_types_works_with_wildcards(self):
@@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_types_works_with_unknowns(self):
@@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_works_with_literals(self):
@@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_works_with_wildcards(self):
@@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_works_with_unknowns(self):
@@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_takes_priority_over_types(self):
@@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_senders_works_with_literals(self):
@@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_senders_works_with_unknowns(self):
@@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_senders_works_with_literals(self):
@@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_senders_works_with_unknowns(self):
@@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_senders_takes_priority_over_senders(self):
@@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_rooms_works_with_literals(self):
@@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_rooms_works_with_unknowns(self):
@@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_rooms_works_with_literals(self):
@@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_rooms_works_with_unknowns(self):
@@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_rooms_takes_priority_over_rooms(self):
@@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event(self):
@@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event_bad_sender(self):
@@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event_bad_room(self):
@@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!piggyshouse:muppets" # nope
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event_bad_type(self):
@@ -341,13 +348,13 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
@defer.inlineCallbacks
- def test_filter_public_user_data_match(self):
+ def test_filter_presence_match(self):
user_filter_json = {
- "public_user_data": {
+ "presence": {
"types": ["m.*"]
}
}
@@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(
sender="@foo:bar",
type="m.profile",
- room_id="!foo:bar"
)
events = [event]
@@ -368,13 +374,13 @@ class FilteringTestCase(unittest.TestCase):
filter_id=filter_id,
)
- results = user_filter.filter_public_user_data(events=events)
+ results = user_filter.filter_presence(events=events)
self.assertEquals(events, results)
@defer.inlineCallbacks
- def test_filter_public_user_data_no_match(self):
+ def test_filter_presence_no_match(self):
user_filter_json = {
- "public_user_data": {
+ "presence": {
"types": ["m.*"]
}
}
@@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(
sender="@foo:bar",
type="custom.avatar.3d.crazy",
- room_id="!foo:bar"
)
events = [event]
@@ -395,7 +400,7 @@ class FilteringTestCase(unittest.TestCase):
filter_id=filter_id,
)
- results = user_filter.filter_public_user_data(events=events)
+ results = user_filter.filter_presence(events=events)
self.assertEquals([], results)
@defer.inlineCallbacks
|