summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2022-04-01 16:10:31 +0100
committerGitHub <noreply@github.com>2022-04-01 16:10:31 +0100
commit33ebee47e4e96a2b6fdf72091769e59034dc550f (patch)
treee77cad2918fd13cd14100e8e7f2e0856900a69ca
parentDefault to `private` room visibility rather than `public` when a client does ... (diff)
downloadsynapse-33ebee47e4e96a2b6fdf72091769e59034dc550f.tar.xz
Remove redundant `get_success` calls in test code (#12346)
There are a bunch of places we call get_success on an immediate value, which is unnecessary. Let's rip them out, and remove the redundant functionality in get_success and friends.
Diffstat (limited to '')
-rw-r--r--changelog.d/12346.misc1
-rw-r--r--tests/handlers/test_deactivate_account.py25
-rw-r--r--tests/handlers/test_profile.py2
-rw-r--r--tests/handlers/test_sync.py4
-rw-r--r--tests/module_api/test_api.py20
-rw-r--r--tests/replication/slave/storage/test_events.py4
-rw-r--r--tests/rest/admin/test_server_notice.py12
-rw-r--r--tests/rest/client/test_rooms.py2
-rw-r--r--tests/storage/test_id_generators.py12
-rw-r--r--tests/storage/test_redaction.py62
-rw-r--r--tests/storage/test_stream.py4
-rw-r--r--tests/test_visibility.py26
-rw-r--r--tests/unittest.py21
13 files changed, 74 insertions, 121 deletions
diff --git a/changelog.d/12346.misc b/changelog.d/12346.misc
new file mode 100644
index 0000000000..6561b3be82
--- /dev/null
+++ b/changelog.d/12346.misc
@@ -0,0 +1 @@
+Remove redundant `get_success` calls in test code.
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 3a10791226..7586e472b5 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -44,21 +44,20 @@ class DeactivateAccountTestCase(HomeserverTestCase):
         Deactivates the account `self.user` using `self.token` and asserts
         that it returns a 200 success code.
         """
-        req = self.get_success(
-            self.make_request(
-                "POST",
-                "account/deactivate",
-                {
-                    "auth": {
-                        "type": "m.login.password",
-                        "user": self.user,
-                        "password": "pass",
-                    },
-                    "erase": True,
+        req = self.make_request(
+            "POST",
+            "account/deactivate",
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "user": self.user,
+                    "password": "pass",
                 },
-                access_token=self.token,
-            )
+                "erase": True,
+            },
+            access_token=self.token,
         )
+
         self.assertEqual(req.code, HTTPStatus.OK, req)
 
     def test_global_account_data_deleted_upon_deactivation(self) -> None:
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 1ec105c373..f88c725a42 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -59,7 +59,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        self.get_success(self.register_user(self.frank.localpart, "frankpassword"))
+        self.register_user(self.frank.localpart, "frankpassword")
 
         self.handler = hs.get_profile_handler()
 
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 3aedc0767b..865b8b7e47 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -158,9 +158,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             )
 
         # Blow away caches (supported room versions can only change due to a restart).
-        self.get_success(
-            self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
-        )
+        self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
         self.store._get_event_cache.clear()
 
         # The rooms should be excluded from the sync response.
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index dee248801d..9fd5d59c55 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -87,24 +87,22 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertEqual(displayname, "Bobberino")
 
     def test_can_register_admin_user(self):
-        user_id = self.get_success(
-            self.register_user(
-                "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
-            )
+        user_id = self.register_user(
+            "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
         )
+
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
 
     def test_can_set_admin(self):
-        user_id = self.get_success(
-            self.register_user(
-                "alice_wants_admin",
-                "1234",
-                displayname="Alice Powerhungry",
-                admin=False,
-            )
+        user_id = self.register_user(
+            "alice_wants_admin",
+            "1234",
+            displayname="Alice Powerhungry",
+            admin=False,
         )
+
         self.get_success(self.module_api.set_user_admin(user_id, True))
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
         self.assertEqual(found_user.user_id.to_string(), user_id)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 17dc42fd37..297a9e77f8 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -268,7 +268,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         event_source = RoomEventSource(self.hs)
         event_source.store = self.slaved_store
-        current_token = self.get_success(event_source.get_current_key())
+        current_token = event_source.get_current_key()
 
         # gradually stream out the replication
         while repl_transport.buffer:
@@ -277,7 +277,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             self.pump(0)
 
             prev_token = current_token
-            current_token = self.get_success(event_source.get_current_key())
+            current_token = event_source.get_current_key()
 
             # attempt to replicate the behaviour of the sync handler.
             #
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 2c855bff99..a53463c9ba 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -214,9 +214,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(messages[0]["sender"], "@notices:test")
 
         # invalidate cache of server notices room_ids
-        self.get_success(
-            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
-        )
+        self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
 
         # send second message
         channel = self.make_request(
@@ -291,9 +289,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
         # invalidate cache of server notices room_ids
         # if server tries to send to a cached room_id the user gets the message
         # in old room
-        self.get_success(
-            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
-        )
+        self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
 
         # send second message
         channel = self.make_request(
@@ -380,9 +376,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
 
         # invalidate cache of server notices room_ids
         # if server tries to send to a cached room_id it gives an error
-        self.get_success(
-            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
-        )
+        self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
 
         # send second message
         channel = self.make_request(
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 3a9617d6da..6ff79b9e2e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -982,7 +982,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
         super().prepare(reactor, clock, hs)
         # profile changes expect that the user is actually registered
         user = UserID.from_string(self.user_id)
-        self.get_success(self.register_user(user.localpart, "supersecretpassword"))
+        self.register_user(user.localpart, "supersecretpassword")
 
     @unittest.override_config(
         {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 395396340b..2d8d1f860f 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -157,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 7})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        ctx1 = self.get_success(id_gen.get_next())
-        ctx2 = self.get_success(id_gen.get_next())
-        ctx3 = self.get_success(id_gen.get_next())
-        ctx4 = self.get_success(id_gen.get_next())
+        ctx1 = id_gen.get_next()
+        ctx2 = id_gen.get_next()
+        ctx3 = id_gen.get_next()
+        ctx4 = id_gen.get_next()
 
         s1 = self.get_success(ctx1.__aenter__())
         s2 = self.get_success(ctx2.__aenter__())
@@ -362,8 +362,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Persist two rows at once
-        ctx1 = self.get_success(id_gen.get_next())
-        ctx2 = self.get_success(id_gen.get_next())
+        ctx1 = id_gen.get_next()
+        ctx2 = id_gen.get_next()
 
         s1 = self.get_success(ctx1.__aenter__())
         s2 = self.get_success(ctx2.__aenter__())
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 03e9cc7d4a..d8d17ef379 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -119,11 +119,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         return event
 
     def test_redact(self):
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+        msg_event = self.inject_message(self.room1, self.u_alice, "t")
 
         # Check event has not been redacted:
         event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -141,9 +139,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         # Redact event
         reason = "Because I said so"
-        self.get_success(
-            self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
-        )
+        self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
 
         event = self.get_success(self.store.get_event(msg_event.event_id))
 
@@ -170,14 +166,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
     def test_redact_join(self):
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(
-            self.inject_room_member(
-                self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
-            )
+        msg_event = self.inject_room_member(
+            self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
         )
 
         event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -195,9 +187,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         # Redact event
         reason = "Because I said so"
-        self.get_success(
-            self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
-        )
+        self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
 
         # Check redaction
 
@@ -311,11 +301,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
     def test_redact_censor(self):
         """Test that a redacted event gets censored in the DB after a month"""
 
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+        msg_event = self.inject_message(self.room1, self.u_alice, "t")
 
         # Check event has not been redacted:
         event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -333,9 +321,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         # Redact event
         reason = "Because I said so"
-        self.get_success(
-            self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
-        )
+        self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
 
         event = self.get_success(self.store.get_event(msg_event.event_id))
 
@@ -381,25 +367,19 @@ class RedactionTestCase(unittest.HomeserverTestCase):
     def test_redact_redaction(self):
         """Tests that we can redact a redaction and can fetch it again."""
 
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+        msg_event = self.inject_message(self.room1, self.u_alice, "t")
 
-        first_redact_event = self.get_success(
-            self.inject_redaction(
-                self.room1, msg_event.event_id, self.u_alice, "Redacting message"
-            )
+        first_redact_event = self.inject_redaction(
+            self.room1, msg_event.event_id, self.u_alice, "Redacting message"
         )
 
-        self.get_success(
-            self.inject_redaction(
-                self.room1,
-                first_redact_event.event_id,
-                self.u_alice,
-                "Redacting redaction",
-            )
+        self.inject_redaction(
+            self.room1,
+            first_redact_event.event_id,
+            self.u_alice,
+            "Redacting redaction",
         )
 
         # Now lets jump to the future where we have censored the redaction event
@@ -414,9 +394,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
     def test_store_redacted_redaction(self):
         """Tests that we can store a redacted redaction."""
 
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index eaa0d7d749..52e41cdab4 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -110,9 +110,7 @@ class PaginationTestCase(HomeserverTestCase):
     def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
         """Make a request to /messages with a filter, returns the chunk of events."""
 
-        from_token = self.get_success(
-            self.hs.get_event_sources().get_current_token_for_pagination()
-        )
+        from_token = self.hs.get_event_sources().get_current_token_for_pagination()
 
         events, next_key = self.get_success(
             self.hs.get_datastores().main.paginate_room_events(
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index a02fd4f79a..d0230f9ebb 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -48,17 +48,15 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         #
 
         # before we do that, we persist some other events to act as state.
-        self.get_success(self._inject_visibility("@admin:hs", "joined"))
+        self._inject_visibility("@admin:hs", "joined")
         for i in range(0, 10):
-            self.get_success(self._inject_room_member("@resident%i:hs" % i))
+            self._inject_room_member("@resident%i:hs" % i)
 
         events_to_filter = []
 
         for i in range(0, 10):
             user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
-            evt = self.get_success(
-                self._inject_room_member(user, extra_content={"a": "b"})
-            )
+            evt = self._inject_room_member(user, extra_content={"a": "b"})
             events_to_filter.append(evt)
 
         filtered = self.get_success(
@@ -76,10 +74,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
     def test_filter_outlier(self) -> None:
         # outlier events must be returned, for the good of the collective federation
-        self.get_success(self._inject_room_member("@resident:remote_hs"))
-        self.get_success(self._inject_visibility("@resident:remote_hs", "joined"))
+        self._inject_room_member("@resident:remote_hs")
+        self._inject_visibility("@resident:remote_hs", "joined")
 
-        outlier = self.get_success(self._inject_outlier())
+        outlier = self._inject_outlier()
         self.assertEqual(
             self.get_success(
                 filter_events_for_server(self.storage, "remote_hs", [outlier])
@@ -88,7 +86,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         )
 
         # it should also work when there are other events in the list
-        evt = self.get_success(self._inject_message("@unerased:local_hs"))
+        evt = self._inject_message("@unerased:local_hs")
 
         filtered = self.get_success(
             filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
@@ -112,19 +110,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # change in the middle of them.
         events_to_filter = []
 
-        evt = self.get_success(self._inject_message("@unerased:local_hs"))
+        evt = self._inject_message("@unerased:local_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_message("@erased:local_hs"))
+        evt = self._inject_message("@erased:local_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_room_member("@joiner:remote_hs"))
+        evt = self._inject_room_member("@joiner:remote_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_message("@unerased:local_hs"))
+        evt = self._inject_message("@unerased:local_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_message("@erased:local_hs"))
+        evt = self._inject_message("@erased:local_hs")
         events_to_filter.append(evt)
 
         # the erasey user gets erased
diff --git a/tests/unittest.py b/tests/unittest.py
index cbe215ee83..5b19065c71 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -16,7 +16,6 @@
 import gc
 import hashlib
 import hmac
-import inspect
 import json
 import logging
 import secrets
@@ -519,33 +518,23 @@ class HomeserverTestCase(TestCase):
         self.reactor.pump([by] * 100)
 
     def get_success(self, d, by=0.0):
-        if inspect.isawaitable(d):
-            d = ensureDeferred(d)
-        if not isinstance(d, Deferred):
-            return d
+        deferred: Deferred[TV] = ensureDeferred(d)
         self.pump(by=by)
-        return self.successResultOf(d)
+        return self.successResultOf(deferred)
 
     def get_failure(self, d, exc):
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
-        if inspect.isawaitable(d):
-            d = ensureDeferred(d)
-        if not isinstance(d, Deferred):
-            return d
+        deferred: Deferred[Any] = ensureDeferred(d)
         self.pump()
-        return self.failureResultOf(d, exc)
+        return self.failureResultOf(deferred, exc)
 
     def get_success_or_raise(self, d, by=0.0):
         """Drive deferred to completion and return result or raise exception
         on failure.
         """
-
-        if inspect.isawaitable(d):
-            deferred = ensureDeferred(d)
-        if not isinstance(deferred, Deferred):
-            return d
+        deferred: Deferred[TV] = ensureDeferred(d)
 
         results: list = []
         deferred.addBoth(results.append)