summary refs log tree commit diff
path: root/tests/api
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api')
-rw-r--r--tests/api/test_auth.py69
-rw-r--r--tests/api/test_filtering.py36
2 files changed, 66 insertions, 39 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..5d45689c8c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
 
+        self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
         self.store.is_support_user = Mock(return_value=defer.succeed(False))
 
     @defer.inlineCallbacks
     def test_get_user_by_req_user_valid_token(self):
         user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
-        self.store.get_user_by_access_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(
+            return_value=defer.succeed(user_info)
+        )
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_user_missing_token(self):
         user_info = {"name": self.test_user, "token_id": "ditto"}
-        self.store.get_user_by_access_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(
+            return_value=defer.succeed(user_info)
+        )
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, MissingClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
             token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "131.111.8.42"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_missing_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, MissingClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = Mock(
+            return_value=defer.succeed({"is_guest": False})
+        )
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         self.failureResultOf(d, AuthError)
 
     @defer.inlineCallbacks
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = Mock(
-            return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+            return_value=defer.succeed(
+                {"name": "@baldrick:matrix.org", "device_id": "device"}
+            )
         )
 
         user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_guest_user_from_macaroon(self):
-        self.store.get_user_by_id = Mock(return_value={"is_guest": True})
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
 
         def get_user(tok):
             if token != tok:
-                return None
-            return {
-                "name": USER_ID,
-                "is_guest": False,
-                "token_id": 1234,
-                "device_id": "DEVICE",
-            }
+                return defer.succeed(None)
+            return defer.succeed(
+                {
+                    "name": USER_ID,
+                    "is_guest": False,
+                    "token_id": 1234,
+                    "device_id": "DEVICE",
+                }
+            )
 
         self.store.get_user_by_access_token = get_user
-        self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+        self.store.get_user_by_id = Mock(
+            return_value=defer.succeed({"is_guest": False})
+        )
 
         # check the token works
         request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..1fab1d6b69 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(sender="@foo:bar", type="m.profile")
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_presence(events=events)
@@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart + "2", filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart + "2", filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_presence(events=events)
@@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_room_state(events=events)
@@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_room_state(events)
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
         self.assertEquals(
             user_filter_json,
             (
-                yield self.datastore.get_user_filter(
-                    user_localpart=user_localpart, filter_id=0
+                yield defer.ensureDeferred(
+                    self.datastore.get_user_filter(
+                        user_localpart=user_localpart, filter_id=0
+                    )
                 )
             ),
         )
@@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
             user_localpart=user_localpart, user_filter=user_filter_json
         )
 
-        filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         self.assertEquals(filter.get_filter_json(), user_filter_json)