diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 624bf5ada2..587be7b2e7 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, 403)
+
+ def test_complete_operation_unknown_session(self):
+ """
+ Attempting to mark an invalid session as complete should error.
+ """
+
+ # Make the initial request to register. (Later on a different password
+ # will be used.)
+ request, channel = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.render(request)
+
+ # Returns a 401 as per the spec
+ self.assertEqual(request.code, 401)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ request, channel = self.make_request(
+ "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ # Attempt to complete an unknown session, which should return an error.
+ unknown_session = session + "unknown"
+ request, channel = self.make_request(
+ "POST",
+ "auth/m.login.recaptcha/fallback/web?session="
+ + unknown_session
+ + "&g-recaptcha-response=a",
+ )
+ self.render(request)
+ self.assertEqual(request.code, 400)
diff --git a/tests/utils.py b/tests/utils.py
index 037cb134f0..f9be62b499 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -512,8 +512,8 @@ class MockClock(object):
return t
- def looping_call(self, function, interval):
- self.loopers.append([function, interval / 1000.0, self.now])
+ def looping_call(self, function, interval, *args, **kwargs):
+ self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -543,9 +543,9 @@ class MockClock(object):
self.timers.append(t)
for looped in self.loopers:
- func, interval, last = looped
+ func, interval, last, args, kwargs = looped
if last + interval < self.now:
- func()
+ func(*args, **kwargs)
looped[2] = self.now
def advance_time_msec(self, ms):
|