diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2e07cddfce..8e5608b3ba 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -539,14 +539,32 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
- # Create a user to expire
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
username = "kermit"
user_id = self.register_user(username, "monkey")
self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
- self.pump(1000)
- self.reactor.advance(1000)
- self.pump()
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
# Expire the user
url = "/_matrix/client/unstable/admin/account_validity/validity"
@@ -563,17 +581,16 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
# Wait for the background job to run which hides expired users in the directory
- self.pump(60 * 60 * 1000)
-
- # Mock the homeserver's HTTP client
- post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.reactor.advance(60 * 60 * 1000)
# Check if the homeserver has replicated the user's profile to the identity server
- self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
payload = post_json.call_args[0][1]
batch = payload.get("batch")
- self.assertNotEquals(batch, None, batch)
+
+ self.assertIsNotNone(batch, batch)
self.assertEquals(len(batch), 1, batch)
+
replicated_user_id = list(batch.keys())[0]
self.assertEquals(replicated_user_id, user_id, replicated_user_id)
|