summary refs log tree commit diff
path: root/tests/api
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api')
-rw-r--r--tests/api/test_auth.py30
-rw-r--r--tests/api/test_filtering.py21
2 files changed, 25 insertions, 26 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 70d928defe..474c5c418f 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -51,8 +51,8 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, _, _) = yield self.auth.get_user_by_req(request)
-        self.assertEquals(user.to_string(), self.test_user)
+        requester = yield self.auth.get_user_by_req(request)
+        self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
@@ -86,8 +86,8 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, _, _) = yield self.auth.get_user_by_req(request)
-        self.assertEquals(user.to_string(), self.test_user)
+        requester = yield self.auth.get_user_by_req(request)
+        self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
@@ -121,8 +121,8 @@ class AuthTestCase(unittest.TestCase):
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, _, _) = yield self.auth.get_user_by_req(request)
-        self.assertEquals(user.to_string(), masquerading_user_id)
+        requester = yield self.auth.get_user_by_req(request)
+        self.assertEquals(requester.user.to_string(), masquerading_user_id)
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
         masquerading_user_id = "@doppelganger:matrix.org"
@@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
         user = user_info["user"]
         self.assertEqual(UserID.from_string(user_id), user)
 
@@ -171,7 +171,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("guest = true")
         serialized = macaroon.serialize()
 
-        user_info = yield self.auth._get_user_from_macaroon(serialized)
+        user_info = yield self.auth.get_user_from_macaroon(serialized)
         user = user_info["user"]
         is_guest = user_info["is_guest"]
         self.assertEqual(UserID.from_string(user_id), user)
@@ -192,7 +192,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user,))
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("User mismatch", cm.exception.msg)
 
@@ -212,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("type = access")
 
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("No user caveat", cm.exception.msg)
 
@@ -234,7 +234,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("user_id = %s" % (user,))
 
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("Invalid macaroon", cm.exception.msg)
 
@@ -257,7 +257,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("cunning > fox")
 
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("Invalid macaroon", cm.exception.msg)
 
@@ -285,11 +285,11 @@ class AuthTestCase(unittest.TestCase):
 
         self.hs.clock.now = 5000 # seconds
 
-        yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        yield self.auth.get_user_from_macaroon(macaroon.serialize())
         # TODO(daniel): Turn on the check that we validate expiration, when we
         # validate expiration (and remove the above line, which will start
         # throwing).
         # with self.assertRaises(AuthError) as cm:
-        #     yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        #     yield self.auth.get_user_from_macaroon(macaroon.serialize())
         # self.assertEqual(401, cm.exception.code)
         # self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 9f9af2d783..ceb0089268 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,26 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 from tests import unittest
 from twisted.internet import defer
 
-from mock import Mock, NonCallableMock
+from mock import Mock
 from tests.utils import (
     MockHttpResource, DeferredMockCallable, setup_test_homeserver
 )
 
 from synapse.types import UserID
-from synapse.api.filtering import FilterCollection, Filter
+from synapse.api.filtering import Filter
+from synapse.events import FrozenEvent
 
 user_localpart = "test_user"
 # MockEvent = namedtuple("MockEvent", "sender type room_id")
 
 
 def MockEvent(**kwargs):
-    ev = NonCallableMock(spec_set=kwargs.keys())
-    ev.configure_mock(**kwargs)
-    return ev
+    return FrozenEvent(kwargs)
 
 
 class FilteringTestCase(unittest.TestCase):
@@ -384,19 +382,20 @@ class FilteringTestCase(unittest.TestCase):
                 "types": ["m.*"]
             }
         }
-        user = UserID.from_string("@" + user_localpart + ":test")
+
         filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart,
+            user_localpart=user_localpart + "2",
             user_filter=user_filter_json,
         )
         event = MockEvent(
+            event_id="$asdasd:localhost",
             sender="@foo:bar",
             type="custom.avatar.3d.crazy",
         )
         events = [event]
 
         user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart,
+            user_localpart=user_localpart + "2",
             filter_id=filter_id,
         )
 
@@ -504,4 +503,4 @@ class FilteringTestCase(unittest.TestCase):
             filter_id=filter_id,
         )
 
-        self.assertEquals(filter.filter_json, user_filter_json)
+        self.assertEquals(filter.get_filter_json(), user_filter_json)