summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/api/test_auth.py69
-rw-r--r--tests/api/test_filtering.py36
-rw-r--r--tests/handlers/test_typing.py4
-rw-r--r--tests/rest/admin/test_user.py10
-rw-r--r--tests/rest/client/v1/test_profile.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/test_typing.py6
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py2
-rw-r--r--tests/unittest.py24
9 files changed, 93 insertions, 68 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..5d45689c8c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
 
+        self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
         self.store.is_support_user = Mock(return_value=defer.succeed(False))
 
     @defer.inlineCallbacks
     def test_get_user_by_req_user_valid_token(self):
         user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
-        self.store.get_user_by_access_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(
+            return_value=defer.succeed(user_info)
+        )
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_user_missing_token(self):
         user_info = {"name": self.test_user, "token_id": "ditto"}
-        self.store.get_user_by_access_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(
+            return_value=defer.succeed(user_info)
+        )
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, MissingClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
             token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "131.111.8.42"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_missing_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, MissingClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = Mock(
+            return_value=defer.succeed({"is_guest": False})
+        )
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         self.failureResultOf(d, AuthError)
 
     @defer.inlineCallbacks
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = Mock(
-            return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+            return_value=defer.succeed(
+                {"name": "@baldrick:matrix.org", "device_id": "device"}
+            )
         )
 
         user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_guest_user_from_macaroon(self):
-        self.store.get_user_by_id = Mock(return_value={"is_guest": True})
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
 
         def get_user(tok):
             if token != tok:
-                return None
-            return {
-                "name": USER_ID,
-                "is_guest": False,
-                "token_id": 1234,
-                "device_id": "DEVICE",
-            }
+                return defer.succeed(None)
+            return defer.succeed(
+                {
+                    "name": USER_ID,
+                    "is_guest": False,
+                    "token_id": 1234,
+                    "device_id": "DEVICE",
+                }
+            )
 
         self.store.get_user_by_access_token = get_user
-        self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+        self.store.get_user_by_id = Mock(
+            return_value=defer.succeed({"is_guest": False})
+        )
 
         # check the token works
         request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..1fab1d6b69 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(sender="@foo:bar", type="m.profile")
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_presence(events=events)
@@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart + "2", filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart + "2", filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_presence(events=events)
@@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_room_state(events=events)
@@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_room_state(events)
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
         self.assertEquals(
             user_filter_json,
             (
-                yield self.datastore.get_user_filter(
-                    user_localpart=user_localpart, filter_id=0
+                yield defer.ensureDeferred(
+                    self.datastore.get_user_filter(
+                        user_localpart=user_localpart, filter_id=0
+                    )
                 )
             ),
         )
@@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
             user_localpart=user_localpart, user_filter=user_filter_json
         )
 
-        filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         self.assertEquals(filter.get_filter_json(), user_filter_json)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5878f74175..b7d0adb10e 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.room_members = []
 
-        def check_user_in_room(room_id, user_id):
+        async def check_user_in_room(room_id, user_id):
             if user_id not in [u.to_string() for u in self.room_members]:
                 raise AuthError(401, "User is not in the room")
-            return defer.succeed(None)
+            return None
 
         hs.get_auth().check_user_in_room = check_user_in_room
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f16eef15f7..17d0aae2e9 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -20,6 +20,8 @@ import urllib.parse
 
 from mock import Mock
 
+from twisted.internet import defer
+
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.api.errors import HttpResponseException, ResourceLimitError
@@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         store = self.hs.get_datastore()
 
         # Set monthly active users to the limit
-        store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+        store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(self.hs.config.max_mau_value)
+        )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
         self.get_failure(
@@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            return_value=self.hs.config.max_mau_value
+            return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            return_value=self.hs.config.max_mau_value
+            return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
             profile_handler=self.mock_handler,
         )
 
-        def _get_user_by_req(request=None, allow_guest=False):
-            return defer.succeed(synapse.types.create_requester(myid))
+        async def _get_user_by_req(request=None, allow_guest=False):
+            return synapse.types.create_requester(myid)
 
         hs.get_auth().get_user_by_req = _get_user_by_req
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5ccda8b2bd..ef6b775ed2 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,8 +23,6 @@ from urllib import parse as urlparse
 
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.handlers.pagination import PurgeStatus
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
 
         self.hs.get_federation_handler = Mock(return_value=Mock())
 
-        def _insert_client_ip(*args, **kwargs):
-            return defer.succeed(None)
+        async def _insert_client_ip(*args, **kwargs):
+            return None
 
         self.hs.get_datastore().insert_client_ip = _insert_client_ip
 
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def get_user_by_access_token(token=None, allow_guest=False):
+        async def get_user_by_access_token(token=None, allow_guest=False):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
-        def _insert_client_ip(*args, **kwargs):
-            return defer.succeed(None)
+        async def _insert_client_ip(*args, **kwargs):
+            return None
 
         hs.get_datastore().insert_client_ip = _insert_client_ip
 
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 7f70353b0d..3f88abe3d2 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         self.user_id = "@user_id:test"
 
     def test_server_notice_only_sent_once(self):
-        self.store.get_monthly_active_count = Mock(return_value=1000)
+        self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
 
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(1000)
diff --git a/tests/unittest.py b/tests/unittest.py
index 2152c693f2..d0bba3ddef 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase):
         if hasattr(self, "user_id"):
             if self.hijack_auth:
 
-                def get_user_by_access_token(token=None, allow_guest=False):
-                    return succeed(
-                        {
-                            "user": UserID.from_string(self.helper.auth_user_id),
-                            "token_id": 1,
-                            "is_guest": False,
-                        }
-                    )
-
-                def get_user_by_req(request, allow_guest=False, rights="access"):
-                    return succeed(
-                        create_requester(
-                            UserID.from_string(self.helper.auth_user_id), 1, False, None
-                        )
+                async def get_user_by_access_token(token=None, allow_guest=False):
+                    return {
+                        "user": UserID.from_string(self.helper.auth_user_id),
+                        "token_id": 1,
+                        "is_guest": False,
+                    }
+
+                async def get_user_by_req(request, allow_guest=False, rights="access"):
+                    return create_requester(
+                        UserID.from_string(self.helper.auth_user_id), 1, False, None
                     )
 
                 self.hs.get_auth().get_user_by_req = get_user_by_req