summary refs log tree commit diff
path: root/tests/push
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-02-14 14:03:35 -0500
committerGitHub <noreply@github.com>2023-02-14 14:03:35 -0500
commit42aea0d8af1556473b4f31f78d9facb448230a1f (patch)
treef633442e29a23705f45ca8daa148a26d12772af5 /tests/push
parentImplement MSC3966: Add a push rule condition to search for a value in an arra... (diff)
downloadsynapse-42aea0d8af1556473b4f31f78d9facb448230a1f.tar.xz
Add final type hint to tests.unittest. (#15072)
Adds a return type to HomeServerTestCase.make_homeserver and deal
with any variables which are no longer Any.
Diffstat (limited to 'tests/push')
-rw-r--r--tests/push/test_email.py51
-rw-r--r--tests/push/test_http.py45
2 files changed, 64 insertions, 32 deletions
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index ab8bb417e7..7563f33fdc 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.errors import Codes, SynapseError
+from synapse.push.emailpusher import EmailPusher
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.util import Clock
@@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
         )
+        assert user_tuple is not None
         self.token_id = user_tuple.token_id
 
         # We need to add email to account before we can create a pusher.
@@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase):
             )
         )
 
-        self.pusher = self.get_success(
+        pusher = self.get_success(
             self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.user_id,
                 access_token=self.token_id,
@@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase):
                 data={},
             )
         )
+        assert isinstance(pusher, EmailPusher)
+        self.pusher = pusher
 
         self.auth_handler = hs.get_auth_handler()
         self.store = hs.get_datastores().main
@@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase):
         )
 
         # check that the pusher for that email address has been deleted
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
 
     def test_remove_unlinked_pushers_background_job(self) -> None:
@@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase):
         self.wait_for_background_updates()
 
         # Check that all pushers with unlinked addresses were deleted
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
 
     def _check_for_mail(self) -> Tuple[Sequence, Dict]:
@@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase):
             that notification.
         """
         # Get the stream ordering before it gets sent
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         last_stream_ordering = pushers[0].last_stream_ordering
 
@@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase):
         self.pump(10)
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
@@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase):
         self.assertEqual(len(self.email_attempts), 1)
 
         # The stream ordering has increased
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 23447cc310..c280ddcdf6 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Tuple
 from unittest.mock import Mock
 
 from twisted.internet.defer import Deferred
@@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.push import PusherConfig, PusherConfigException
 from synapse.rest.client import login, push_rule, pusher, receipts, room
 from synapse.server import HomeServer
-from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import JsonDict
 from synapse.util import Clock
 
@@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
-        def test_data(data: Optional[JsonDict]) -> None:
+        def test_data(data: Any) -> None:
             self.get_failure(
                 self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=user_id,
@@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.helper.send(room, body="There!", tok=other_access_token)
 
         # Get the stream ordering before it gets sent
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         last_stream_ordering = pushers[0].last_stream_ordering
 
@@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.pump()
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
@@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.pump()
 
         # The stream ordering has increased
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
         last_stream_ordering = pushers[0].last_stream_ordering
@@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.pump()
 
         # The stream ordering has increased, again
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
@@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
         device_id = user_tuple.device_id
 
@@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
         )
 
         # Look up the user info for the access token so we can compare the device ID.
-        lookup_result: TokenLookupResult = self.get_success(
+        lookup_result = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert lookup_result is not None
 
         # Get the user's devices and check it has the correct device ID.
         channel = self.make_request("GET", "/pushers", access_token=access_token)