summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py40
-rw-r--r--tests/utils.py8
2 files changed, 44 insertions, 4 deletions
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 624bf5ada2..587be7b2e7 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         )
         self.render(request)
         self.assertEqual(channel.code, 403)
+
+    def test_complete_operation_unknown_session(self):
+        """
+        Attempting to mark an invalid session as complete should error.
+        """
+
+        # Make the initial request to register. (Later on a different password
+        # will be used.)
+        request, channel = self.make_request(
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )
+        self.render(request)
+
+        # Returns a 401 as per the spec
+        self.assertEqual(request.code, 401)
+        # Grab the session
+        session = channel.json_body["session"]
+        # Assert our configured public key is being given
+        self.assertEqual(
+            channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+        )
+
+        request, channel = self.make_request(
+            "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        # Attempt to complete an unknown session, which should return an error.
+        unknown_session = session + "unknown"
+        request, channel = self.make_request(
+            "POST",
+            "auth/m.login.recaptcha/fallback/web?session="
+            + unknown_session
+            + "&g-recaptcha-response=a",
+        )
+        self.render(request)
+        self.assertEqual(request.code, 400)
diff --git a/tests/utils.py b/tests/utils.py
index 037cb134f0..f9be62b499 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -512,8 +512,8 @@ class MockClock(object):
 
         return t
 
-    def looping_call(self, function, interval):
-        self.loopers.append([function, interval / 1000.0, self.now])
+    def looping_call(self, function, interval, *args, **kwargs):
+        self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
 
     def cancel_call_later(self, timer, ignore_errs=False):
         if timer[2]:
@@ -543,9 +543,9 @@ class MockClock(object):
                 self.timers.append(t)
 
         for looped in self.loopers:
-            func, interval, last = looped
+            func, interval, last, args, kwargs = looped
             if last + interval < self.now:
-                func()
+                func(*args, **kwargs)
                 looped[2] = self.now
 
     def advance_time_msec(self, ms):