summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/profile.py151
-rw-r--r--synapse/storage/databases/main/registration.py80
-rw-r--r--synapse/storage/databases/main/room.py23
3 files changed, 246 insertions, 8 deletions
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py

index ba7075caa5..1a34118f7e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,11 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo +from synapse.types import UserID +from synapse.util.caches.descriptors import cached + +BATCH_SIZE = 100 class ProfileWorkerStore(SQLBaseStore): @@ -38,6 +43,7 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) + @cached(max_entries=5000) async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", @@ -46,6 +52,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) + @cached(max_entries=5000) async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", @@ -54,6 +61,58 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) + async def get_latest_profile_replication_batch_number(self): + def f(txn): + txn.execute("SELECT MAX(batch) as maxbatch FROM profiles") + rows = self.db_pool.cursor_to_dict(txn) + return rows[0]["maxbatch"] + + return await self.db_pool.runInteraction( + "get_latest_profile_replication_batch_number", f + ) + + async def get_profile_batch(self, batchnum): + return await self.db_pool.simple_select_list( + table="profiles", + keyvalues={"batch": batchnum}, + retcols=("user_id", "displayname", "avatar_url", "active"), + desc="get_profile_batch", + ) + + async def assign_profile_batch(self): + def f(txn): + sql = ( + "UPDATE profiles SET batch = " + "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) " + "WHERE user_id in (" + " SELECT user_id FROM profiles WHERE batch is NULL limit ?" + ")" + ) + txn.execute(sql, (BATCH_SIZE,)) + return txn.rowcount + + return await self.db_pool.runInteraction("assign_profile_batch", f) + + async def get_replication_hosts(self): + def f(txn): + txn.execute( + "SELECT host, last_synced_batch FROM profile_replication_status" + ) + rows = self.db_pool.cursor_to_dict(txn) + return {r["host"]: r["last_synced_batch"] for r in rows} + + return await self.db_pool.runInteraction("get_replication_hosts", f) + + async def update_replication_batch_for_host( + self, host: str, last_synced_batch: int + ): + return await self.db_pool.simple_upsert( + table="profile_replication_status", + keyvalues={"host": host}, + values={"last_synced_batch": last_synced_batch}, + desc="update_replication_batch_for_host", + ) + async def get_from_remote_profile_cache( self, user_id: str ) -> Optional[Dict[str, Any]]: @@ -71,32 +130,99 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: Optional[str] + self, user_localpart: str, new_displayname: Optional[str], batchnum: int ) -> None: + # Invalidate the read cache for this user + self.get_profile_displayname.invalidate((user_localpart,)) + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - values={"displayname": new_displayname}, + values={"displayname": new_displayname, "batch": batchnum}, desc="set_profile_displayname", + lock=False, # we can do this because user_id has a unique index ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: Optional[str] + self, user_localpart: str, new_avatar_url: Optional[str], batchnum: int ) -> None: + # Invalidate the read cache for this user + self.get_profile_avatar_url.invalidate((user_localpart,)) + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - values={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url, "batch": batchnum}, desc="set_profile_avatar_url", + lock=False, # we can do this because user_id has a unique index + ) + + async def set_profiles_active( + self, + users: List[UserID], + active: bool, + hide: bool, + batchnum: int, + ) -> None: + """Given a set of users, set active and hidden flags on them. + + Args: + users: A list of UserIDs + active: Whether to set the users to active or inactive + hide: Whether to hide the users (withold from replication). If + False and active is False, users will have their profiles + erased + batchnum: The batch number, used for profile replication + """ + # Convert list of localparts to list of tuples containing localparts + user_localparts = [(user.localpart,) for user in users] + + # Generate list of value tuples for each user + value_names = ("active", "batch") + values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple] + + if not active and not hide: + # we are deactivating for real (not in hide mode) + # so clear the profile information + value_names += ("avatar_url", "displayname") + values = [v + (None, None) for v in values] + + return await self.db_pool.runInteraction( + "set_profiles_active", + self.db_pool.simple_upsert_many_txn, + table="profiles", + key_names=("user_id",), + key_values=user_localparts, + value_names=value_names, + value_values=values, + ) + + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + await self.db_pool.simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", ) async def update_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> int: - return await self.db_pool.simple_update( + return await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, - updatevalues={ + values={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), @@ -163,6 +289,17 @@ class ProfileWorkerStore(SQLBaseStore): class ProfileStore(ProfileWorkerStore): + def __init__(self, database, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "profile_replication_status_host_index", + index_name="profile_replication_status_idx", + table="profile_replication_status", + columns=["host"], + unique=True, + ) + async def add_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> None: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e31c5864ac..b8bdfc721e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -272,6 +272,37 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "set_account_validity_for_user", set_account_validity_for_user_txn ) + async def get_expired_users(self): + """Get UserIDs of all expired users. + + Users who are not active, or do not have profile information, are + excluded from the results. + + Returns: + Deferred[List[UserID]]: List of expired user IDs + """ + + def get_expired_users_txn(txn, now_ms): + # We need to use pattern matching as profiles.user_id is confusingly just the + # user's localpart, whereas account_validity.user_id is a full user ID + sql = """ + SELECT av.user_id from account_validity AS av + LEFT JOIN profiles as p + ON av.user_id LIKE '%%' || p.user_id || ':%%' + WHERE expiration_ts_ms <= ? + AND p.active = 1 + """ + txn.execute(sql, (now_ms,)) + rows = txn.fetchall() + + return [UserID.from_string(row[0]) for row in rows] + + res = await self.db_pool.runInteraction( + "get_expired_users", get_expired_users_txn, self._clock.time_msec() + ) + + return res + async def set_renewal_token_for_user( self, user_id: str, renewal_token: str ) -> None: @@ -392,6 +423,55 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="delete_account_validity_for_user", ) + async def get_info_for_users( + self, + user_ids: List[str], + ): + """Return the user info for a given set of users + + Args: + user_ids: A list of users to return information about + + Returns: + Deferred[Dict[str, bool]]: A dictionary mapping each user ID to + a dict with the following keys: + * expired - whether this is an expired user + * deactivated - whether this is a deactivated user + """ + # Get information of all our local users + def _get_info_for_users_txn(txn): + rows = [] + + for user_id in user_ids: + sql = """ + SELECT u.name, u.deactivated, av.expiration_ts_ms + FROM users as u + LEFT JOIN account_validity as av + ON av.user_id = u.name + WHERE u.name = ? + """ + + txn.execute(sql, (user_id,)) + row = txn.fetchone() + if row: + rows.append(row) + + return rows + + info_rows = await self.db_pool.runInteraction( + "get_info_for_users", _get_info_for_users_txn + ) + + return { + user_id: { + "expired": ( + expiration is not None and self._clock.time_msec() >= expiration + ), + "deactivated": deactivated == 1, + } + for user_id, deactivated, expiration in info_rows + } + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9f0d64a325..bcf25f298e 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import collections import logging from abc import abstractmethod @@ -358,6 +357,23 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def is_room_published(self, room_id: str) -> bool: + """Check whether a room has been published in the local public room + directory. + + Args: + room_id + Returns: + Whether the room is currently published in the room directory + """ + # Get room information + room_info = await self.get_room(room_id) + if not room_info: + return False + + # Check the is_public value + return room_info.get("is_public", False) + async def get_rooms_paginate( self, start: int, @@ -621,6 +637,11 @@ class RoomWorkerStore(SQLBaseStore): Returns: dict[int, int]: "min_lifetime" and "max_lifetime" for this room. """ + # If the room retention feature is disabled, return a policy with no minimum nor + # maximum, in order not to filter out events we should filter out when sending to + # the client. + if not self.config.retention_enabled: + return {"min_lifetime": None, "max_lifetime": None} def get_retention_policy_for_room_txn(txn): txn.execute(