diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ffbc13bb8d..62c32cae5e 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -169,7 +169,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# which sets these values to 10000, but as we're overriding the entire
# rc_login dict here, we need to set this manually as well
"account": {"per_second": 10000, "burst_count": 10000},
- }
+ },
+ "experimental_features": {"msc4041_enabled": True},
}
)
def test_POST_ratelimiting_per_address(self) -> None:
@@ -189,12 +190,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
+ retry_header = channel.headers.getRawHeaders("Retry-After")
else:
self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
- self.assertTrue(retry_after_ms < 6000)
+ self.assertLess(retry_after_ms, 6000)
+ assert retry_header
+ self.assertLessEqual(int(retry_header[0]), 6)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@@ -217,7 +221,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# which sets these values to 10000, but as we're overriding the entire
# rc_login dict here, we need to set this manually as well
"address": {"per_second": 10000, "burst_count": 10000},
- }
+ },
+ "experimental_features": {"msc4041_enabled": True},
}
)
def test_POST_ratelimiting_per_account(self) -> None:
@@ -234,12 +239,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
+ retry_header = channel.headers.getRawHeaders("Retry-After")
else:
self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
- self.assertTrue(retry_after_ms < 6000)
+ self.assertLess(retry_after_ms, 6000)
+ assert retry_header
+ self.assertLessEqual(int(retry_header[0]), 6)
self.reactor.advance(retry_after_ms / 1000.0)
@@ -262,7 +270,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# rc_login dict here, we need to set this manually as well
"address": {"per_second": 10000, "burst_count": 10000},
"failed_attempts": {"per_second": 0.17, "burst_count": 5},
- }
+ },
+ "experimental_features": {"msc4041_enabled": True},
}
)
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
@@ -279,12 +288,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
if i == 5:
self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
+ retry_header = channel.headers.getRawHeaders("Retry-After")
else:
self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
- self.assertTrue(retry_after_ms < 6000)
+ self.assertLess(retry_after_ms, 6000)
+ assert retry_header
+ self.assertLessEqual(int(retry_header[0]), 6)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|