diff --git a/changelog.d/47.misc b/changelog.d/47.misc
new file mode 100644
index 0000000000..1d6596d788
--- /dev/null
+++ b/changelog.d/47.misc
@@ -0,0 +1 @@
+Improve performance of `mark_expired_users_as_inactive` background job.
\ No newline at end of file
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index a6c907b9c9..0c2bcda4d0 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -281,23 +281,21 @@ class AccountValidityHandler(object):
if self._show_users_in_user_directory:
# Show the user in the directory again by setting them to active
await self.profile_handler.set_active(
- UserID.from_string(user_id), True, True
+ [UserID.from_string(user_id)], True, True
)
return expiration_ts
@defer.inlineCallbacks
def _mark_expired_users_as_inactive(self):
- """Iterate over expired users. Mark them as inactive in order to hide them from the
- user directory.
+ """Iterate over active, expired users. Mark them as inactive in order to hide them
+ from the user directory.
Returns:
Deferred
"""
- # Get expired users
- expired_user_ids = yield self.store.get_expired_users()
- expired_users = [UserID.from_string(user_id) for user_id in expired_user_ids]
+ # Get active, expired users
+ active_expired_users = yield self.store.get_expired_users()
- # Mark each one as non-active
- for user in expired_users:
- yield self.profile_handler.set_active(user, False, True)
+ # Mark each as non-active
+ yield self.profile_handler.set_active(active_expired_users, False, True)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index f624c2a3f9..fe62c3f973 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -106,7 +106,7 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None)
user = UserID.from_string(user_id)
- await self._profile_handler.set_active(user, False, False)
+ await self._profile_handler.set_active([user], False, False)
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index bca0d8d380..1880bb2dd9 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import List
from six import raise_from
from six.moves import range
@@ -67,6 +68,7 @@ class BaseProfileHandler(BaseHandler):
self.max_avatar_size = hs.config.max_avatar_size
self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+ self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to
if hs.config.worker_app is None:
self.clock.looping_call(
@@ -293,29 +295,42 @@ class BaseProfileHandler(BaseHandler):
run_in_background(self._replicate_profiles)
@defer.inlineCallbacks
- def set_active(self, target_user, active, hide):
+ def set_active(
+ self, users: List[UserID], active: bool, hide: bool,
+ ):
"""
- Sets the 'active' flag on a user profile. If set to false, the user
- account is considered deactivated or hidden.
+ Sets the 'active' flag on a set of user profiles. If set to false, the
+ accounts are considered deactivated or hidden.
If 'hide' is true, then we interpret active=False as a request to try to
- hide the user rather than deactivating it. This means withholding the
- profile from replication (and mark it as inactive) rather than clearing
- the profile from the HS DB. Note that unlike set_displayname and
- set_avatar_url, this does *not* perform authorization checks! This is
- because the only place it's used currently is in account deactivation
- where we've already done these checks anyway.
+ hide the users rather than deactivating them. This means withholding the
+ profiles from replication (and marking them as inactive) rather than clearing
+ the profiles from the HS DB.
+
+ Note that unlike set_displayname and set_avatar_url, this does *not*
+ perform authorization checks! This is because the only place it's used
+ currently is in account deactivation where we've already done these
+ checks anyway.
+
+ Args:
+ users: The users to modify
+ active: Whether to set the user to active or inactive
+ hide: Whether to hide the user (withhold from replication). If
+ False and active is False, user will have their profile
+ erased
+
+ Returns:
+ Deferred
"""
- if len(self.hs.config.replicate_user_profiles_to) > 0:
+ if len(self.replicate_user_profiles_to) > 0:
cur_batchnum = (
yield self.store.get_latest_profile_replication_batch_number()
)
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
- yield self.store.set_profile_active(
- target_user.localpart, active, hide, new_batchnum
- )
+
+ yield self.store.set_profiles_active(users, active, hide, new_batchnum)
# start a profile replication push
run_in_background(self._replicate_profiles)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c7e3861054..8f9841117a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -281,7 +281,7 @@ class RegistrationHandler(BaseHandler):
yield self.store.add_account_data_for_user(
user_id, "im.vector.hide_profile", {"hide_profile": True}
)
- yield self.profile_handler.set_active(user, False, True)
+ yield self.profile_handler.set_active([user], False, True)
return user_id
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index ddb011d864..d31ec7c29d 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -55,7 +55,7 @@ class AccountDataServlet(RestServlet):
if account_data_type == "im.vector.hide_profile":
user = UserID.from_string(user_id)
hide_profile = body.get("hide_profile")
- await self._profile_handler.set_active(user, not hide_profile, True)
+ await self._profile_handler.set_active([user], not hide_profile, True)
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 7c69996041..cd2472feb0 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -14,11 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.roommember import ProfileInfo
+from synapse.types import UserID
from synapse.util.caches.descriptors import cached
BATCH_SIZE = 100
@@ -149,19 +152,43 @@ class ProfileWorkerStore(SQLBaseStore):
lock=False, # we can do this because user_id has a unique index
)
- def set_profile_active(self, user_localpart, active, hide, batchnum):
- values = {"active": int(active), "batch": batchnum}
+ def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ):
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+ hide: Whether to hide the users (withhold from replication). If
+ False and active is False, users will have their profiles
+ erased
+ batchnum: The batch number, used for profile replication
+
+ Returns:
+ Deferred
+ """
+ # Convert list of UserIDs to list of 1-tuples containing their localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
if not active and not hide:
# we are deactivating for real (not in hide mode)
- # so clear the profile.
- values["avatar_url"] = None
- values["displayname"] = None
- return self.db.simple_upsert(
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
+ return self.db.runInteraction(
+ "set_profiles_active",
+ self.db.simple_upsert_many_txn,
table="profiles",
- keyvalues={"user_id": user_localpart},
- values=values,
- desc="set_profile_active",
- lock=False, # we can do this because user_id has a unique index
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
)
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index e91634b322..b07c44d87a 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -159,25 +159,34 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_expired_users(self):
- """Get IDs of all expired users
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
Returns:
- Deferred[list[str]]: List of expired user IDs
+ Deferred[List[UserID]]: List of expired user IDs
"""
def get_expired_users_txn(txn, now_ms):
+ # profiles.user_id is confusingly just the user's localpart, whereas
+ # account_validity.user_id is a full user ID, so match '@localpart:%'
sql = """
- SELECT user_id from account_validity
- WHERE expiration_ts_ms <= ?
+ SELECT av.user_id from account_validity AS av
+ INNER JOIN profiles as p
+ ON av.user_id LIKE '@' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
"""
txn.execute(sql, (now_ms,))
rows = txn.fetchall()
- return [row[0] for row in rows]
+
+ return [UserID.from_string(row[0]) for row in rows]
res = yield self.db.runInteraction(
"get_expired_users", get_expired_users_txn, self.clock.time_msec()
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2e07cddfce..8e5608b3ba 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -539,14 +539,32 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
- # Create a user to expire
+ # Grab the homeserver's mocked HTTP client method so we can inspect its calls
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
username = "kermit"
user_id = self.register_user(username, "monkey")
self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
- self.pump(1000)
- self.reactor.advance(1000)
- self.pump()
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
# Expire the user
url = "/_matrix/client/unstable/admin/account_validity/validity"
@@ -563,17 +581,16 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
# Wait for the background job to run which hides expired users in the directory
- self.pump(60 * 60 * 1000)
-
- # Mock the homeserver's HTTP client
- post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.reactor.advance(60 * 60 * 1000)
# Check if the homeserver has replicated the user's profile to the identity server
- self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
payload = post_json.call_args[0][1]
batch = payload.get("batch")
- self.assertNotEquals(batch, None, batch)
+
+ self.assertIsNotNone(batch, batch)
self.assertEquals(len(batch), 1, batch)
+
replicated_user_id = list(batch.keys())[0]
self.assertEquals(replicated_user_id, user_id, replicated_user_id)
|