diff --git a/changelog.d/13.feature b/changelog.d/13.feature
new file mode 100644
index 0000000000..c2d2e93abf
--- /dev/null
+++ b/changelog.d/13.feature
@@ -0,0 +1 @@
+Hide expired users from the user directory, and optionally re-add them on renewal.
\ No newline at end of file
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 396f0059f7..947237d7da 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -42,6 +42,8 @@ class AccountValidityHandler(object):
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
# Don't do email-specific configuration if renewal by email is disabled.
@@ -74,6 +76,12 @@ class AccountValidityHandler(object):
30 * 60 * 1000,
)
+ # Check every hour to remove expired users from the user directory
+ self.clock.looping_call(
+ self._mark_expired_users_as_inactive,
+ 60 * 60 * 1000,
+ )
+
@defer.inlineCallbacks
def send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
@@ -261,4 +269,28 @@ class AccountValidityHandler(object):
email_sent=email_sent,
)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ yield self.profile_handler.set_active(UserID.from_string(user_id), True, True)
+
defer.returnValue(expiration_ts)
+
+ @defer.inlineCallbacks
+ def _mark_expired_users_as_inactive(self):
+ """Iterate over expired users. Mark them as inactive in order to hide them from the
+ user directory.
+
+ Returns:
+ Deferred
+ """
+ # Get expired users
+ expired_user_ids = yield self.store.get_expired_users()
+ expired_users = [
+ UserID.from_string(user_id)
+ for user_id in expired_user_ids
+ ]
+
+ # Mark each one as non-active
+ for user in expired_users:
+ yield self.profile_handler.set_active(user, False, True)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 0b3c656e90..028848cf89 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -152,6 +152,29 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def get_expired_users(self):
+ """Get IDs of all expired users
+
+ Returns:
+ Deferred[list[str]]: List of expired user IDs
+ """
+ def get_expired_users_txn(txn, now_ms):
+ sql = """
+ SELECT user_id from account_validity
+ WHERE expiration_ts_ms <= ?
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+ return [row[0] for row in rows]
+
+ res = yield self.runInteraction(
+ "get_expired_users",
+ get_expired_users_txn,
+ self.clock.time_msec(),
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
"""Defines a renewal token for a given user.
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index af1e600591..a5c7aaa9c0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -352,6 +352,141 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=[
+ "post_json_get_json",
+ ])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ simple_http_client=mock_http_client,
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Create a user to expire
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+
+ self.pump(1000)
+ self.reactor.advance(1000)
+ self.pump()
+
+ # Expire the user
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.pump(60 * 60 * 1000)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
|