summary refs log tree commit diff
path: root/tests/api
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api')
-rw-r--r--tests/api/test_auth.py190
-rw-r--r--tests/api/test_filtering.py69
2 files changed, 210 insertions, 49 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4f83db5e84..70d928defe 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,17 +19,21 @@ from mock import Mock
 
 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
+from synapse.types import UserID
+from tests.utils import setup_test_homeserver
+
+import pymacaroons
 
 
 class AuthTestCase(unittest.TestCase):
 
+    @defer.inlineCallbacks
     def setUp(self):
         self.state_handler = Mock()
         self.store = Mock()
 
-        self.hs = Mock()
+        self.hs = yield setup_test_homeserver(handlers=None)
         self.hs.get_datastore = Mock(return_value=self.store)
-        self.hs.get_state_handler = Mock(return_value=self.state_handler)
         self.auth = Auth(self.hs)
 
         self.test_user = "@foo:bar"
@@ -40,21 +44,19 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_app_service_by_token = Mock(return_value=None)
         user_info = {
             "name": self.test_user,
-            "device_id": "nothing",
             "token_id": "ditto",
-            "admin": False
         }
-        self.store.get_user_by_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, info) = yield self.auth.get_user_by_req(request)
+        (user, _, _) = yield self.auth.get_user_by_req(request)
         self.assertEquals(user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -66,11 +68,9 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_app_service_by_token = Mock(return_value=None)
         user_info = {
             "name": self.test_user,
-            "device_id": "nothing",
             "token_id": "ditto",
-            "admin": False
         }
-        self.store.get_user_by_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -81,17 +81,17 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_valid_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, info) = yield self.auth.get_user_by_req(request)
+        (user, _, _) = yield self.auth.get_user_by_req(request)
         self.assertEquals(user.to_string(), self.test_user)
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_missing_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -115,13 +115,13 @@ class AuthTestCase(unittest.TestCase):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, info) = yield self.auth.get_user_by_req(request)
+        (user, _, _) = yield self.auth.get_user_by_req(request)
         self.assertEquals(user.to_string(), masquerading_user_id)
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
@@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -137,3 +137,159 @@ class AuthTestCase(unittest.TestCase):
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user_id = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        user = user_info["user"]
+        self.assertEqual(UserID.from_string(user_id), user)
+
+    @defer.inlineCallbacks
+    def test_get_guest_user_from_macaroon(self):
+        user_id = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        macaroon.add_first_party_caveat("guest = true")
+        serialized = macaroon.serialize()
+
+        user_info = yield self.auth._get_user_from_macaroon(serialized)
+        user = user_info["user"]
+        is_guest = user_info["is_guest"]
+        self.assertEqual(UserID.from_string(user_id), user)
+        self.assertTrue(is_guest)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_user_db_mismatch(self):
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@percy:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("User mismatch", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_missing_caveat(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("No user caveat", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_wrong_key(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key + "wrong")
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("Invalid macaroon", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_unknown_caveat(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        macaroon.add_first_party_caveat("cunning > fox")
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("Invalid macaroon", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_expired(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        macaroon.add_first_party_caveat("time < 1") # ms
+
+        self.hs.clock.now = 5000 # seconds
+
+        yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        # TODO(daniel): Turn on the check that we validate expiration, when we
+        # validate expiration (and remove the above line, which will start
+        # throwing).
+        # with self.assertRaises(AuthError) as cm:
+        #     yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        # self.assertEqual(401, cm.exception.code)
+        # self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 65b2f590c8..9f9af2d783 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -23,10 +23,17 @@ from tests.utils import (
 )
 
 from synapse.types import UserID
-from synapse.api.filtering import Filter
+from synapse.api.filtering import FilterCollection, Filter
 
 user_localpart = "test_user"
-MockEvent = namedtuple("MockEvent", "sender type room_id")
+# MockEvent = namedtuple("MockEvent", "sender type room_id")
+
+
+def MockEvent(**kwargs):
+    ev = NonCallableMock(spec_set=kwargs.keys())
+    ev.configure_mock(**kwargs)
+    return ev
+
 
 class FilteringTestCase(unittest.TestCase):
 
@@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase):
         )
 
         self.filtering = hs.get_filtering()
-        self.filter = Filter({})
 
         self.datastore = hs.get_datastore()
 
@@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase):
             type="m.room.message",
             room_id="!foo:bar"
         )
+
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_types_works_with_wildcards(self):
@@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_types_works_with_unknowns(self):
@@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_works_with_literals(self):
@@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_works_with_wildcards(self):
@@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_works_with_unknowns(self):
@@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_takes_priority_over_types(self):
@@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_senders_works_with_literals(self):
@@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_senders_works_with_unknowns(self):
@@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_senders_works_with_literals(self):
@@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_senders_works_with_unknowns(self):
@@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_senders_takes_priority_over_senders(self):
@@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_rooms_works_with_literals(self):
@@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_rooms_works_with_unknowns(self):
@@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_rooms_works_with_literals(self):
@@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_rooms_works_with_unknowns(self):
@@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_rooms_takes_priority_over_rooms(self):
@@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event(self):
@@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event_bad_sender(self):
@@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event_bad_room(self):
@@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!piggyshouse:muppets"  # nope
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event_bad_type(self):
@@ -341,13 +348,13 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     @defer.inlineCallbacks
-    def test_filter_public_user_data_match(self):
+    def test_filter_presence_match(self):
         user_filter_json = {
-            "public_user_data": {
+            "presence": {
                 "types": ["m.*"]
             }
         }
@@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(
             sender="@foo:bar",
             type="m.profile",
-            room_id="!foo:bar"
         )
         events = [event]
 
@@ -368,13 +374,13 @@ class FilteringTestCase(unittest.TestCase):
             filter_id=filter_id,
         )
 
-        results = user_filter.filter_public_user_data(events=events)
+        results = user_filter.filter_presence(events=events)
         self.assertEquals(events, results)
 
     @defer.inlineCallbacks
-    def test_filter_public_user_data_no_match(self):
+    def test_filter_presence_no_match(self):
         user_filter_json = {
-            "public_user_data": {
+            "presence": {
                 "types": ["m.*"]
             }
         }
@@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(
             sender="@foo:bar",
             type="custom.avatar.3d.crazy",
-            room_id="!foo:bar"
         )
         events = [event]
 
@@ -395,7 +400,7 @@ class FilteringTestCase(unittest.TestCase):
             filter_id=filter_id,
         )
 
-        results = user_filter.filter_public_user_data(events=events)
+        results = user_filter.filter_presence(events=events)
         self.assertEquals([], results)
 
     @defer.inlineCallbacks