diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..1681caa1f0 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="set_profile_avatar_url",
)
-
-class ProfileStore(ProfileWorkerStore):
- async def add_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
- ) -> None:
- """Ensure we are caching the remote user's profiles.
-
- This should only be called when `is_subscribed_remote_profile_for_user`
- would return true for the user.
- """
- await self.db_pool.simple_upsert(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- values={
- "displayname": displayname,
- "avatar_url": avatar_url,
- "last_check": self._clock.time_msec(),
- },
- desc="add_remote_profile_cache",
- )
-
async def update_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> int:
@@ -138,6 +117,31 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
+ async def is_subscribed_remote_profile_for_user(self, user_id):
+ """Check whether we are interested in a remote user's profile.
+ """
+ res = await self.db_pool.simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ return True
+
+ res = await self.db_pool.simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ return True
+
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> Dict[str, str]:
@@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore):
_get_remote_profile_cache_entries_that_expire_txn,
)
- async def is_subscribed_remote_profile_for_user(self, user_id):
- """Check whether we are interested in a remote user's profile.
- """
- res = await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
- )
- if res:
- return True
+class ProfileStore(ProfileWorkerStore):
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
+ """Ensure we are caching the remote user's profiles.
- res = await self.db_pool.simple_select_one_onecol(
- table="group_invites",
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ await self.db_pool.simple_upsert(
+ table="remote_profile_cache",
keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
)
-
- if res:
- return True
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a003e30f9..4c843b7679 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -862,6 +862,32 @@ class RegistrationWorkerStore(SQLBaseStore):
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
+ async def get_user_pending_deactivation(self) -> Optional[str]:
+ """
+ Gets one user from the table of users waiting to be parted from all the rooms
+ they're in.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "users_pending_deactivation",
+ keyvalues={},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_users_pending_deactivation",
+ )
+
+ async def del_user_pending_deactivation(self, user_id: str) -> None:
+ """
+ Removes the given user to the table of users who need to be parted from all the
+ rooms they're in, effectively marking that user as fully deactivated.
+ """
+ # XXX: This should be simple_delete_one but we failed to put a unique index on
+ # the table, so somehow duplicate entries have ended up in it.
+ await self.db_pool.simple_delete(
+ "users_pending_deactivation",
+ keyvalues={"user_id": user_id},
+ desc="del_user_pending_deactivation",
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1371,32 +1397,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_user_pending_deactivation",
)
- async def del_user_pending_deactivation(self, user_id: str) -> None:
- """
- Removes the given user to the table of users who need to be parted from all the
- rooms they're in, effectively marking that user as fully deactivated.
- """
- # XXX: This should be simple_delete_one but we failed to put a unique index on
- # the table, so somehow duplicate entries have ended up in it.
- await self.db_pool.simple_delete(
- "users_pending_deactivation",
- keyvalues={"user_id": user_id},
- desc="del_user_pending_deactivation",
- )
-
- async def get_user_pending_deactivation(self) -> Optional[str]:
- """
- Gets one user from the table of users waiting to be parted from all the rooms
- they're in.
- """
- return await self.db_pool.simple_select_one_onecol(
- "users_pending_deactivation",
- keyvalues={},
- retcol="user_id",
- allow_none=True,
- desc="get_users_pending_deactivation",
- )
-
async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index c0f2af0785..e83d961c20 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -869,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
"get_all_new_public_rooms", get_all_new_public_rooms
)
+ async def get_rooms_for_retention_period_in_range(
+ self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+ ) -> Dict[str, dict]:
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+ min_ms: Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, doesn't set a lower limit.
+ max_ms: Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, doesn't set an upper limit.
+ include_null: Whether to include rooms which retention policy is NULL
+ in the returned set.
+
+ Returns:
+ The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+ if len(range_conditions):
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.db_pool.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.db_pool.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ return await self.db_pool.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1446,88 +1529,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.is_room_blocked,
(room_id,),
)
-
- async def get_rooms_for_retention_period_in_range(
- self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, dict]:
- """Retrieves all of the rooms within the given retention range.
-
- Optionally includes the rooms which don't have a retention policy.
-
- Args:
- min_ms: Duration in milliseconds that define the lower limit of
- the range to handle (exclusive). If None, doesn't set a lower limit.
- max_ms: Duration in milliseconds that define the upper limit of
- the range to handle (inclusive). If None, doesn't set an upper limit.
- include_null: Whether to include rooms which retention policy is NULL
- in the returned set.
-
- Returns:
- The rooms within this range, along with their retention
- policy. The key is "room_id", and maps to a dict describing the retention
- policy associated with this room ID. The keys for this nested dict are
- "min_lifetime" (int|None), and "max_lifetime" (int|None).
- """
-
- def get_rooms_for_retention_period_in_range_txn(txn):
- range_conditions = []
- args = []
-
- if min_ms is not None:
- range_conditions.append("max_lifetime > ?")
- args.append(min_ms)
-
- if max_ms is not None:
- range_conditions.append("max_lifetime <= ?")
- args.append(max_ms)
-
- # Do a first query which will retrieve the rooms that have a retention policy
- # in their current state.
- sql = """
- SELECT room_id, min_lifetime, max_lifetime FROM room_retention
- INNER JOIN current_state_events USING (event_id, room_id)
- """
-
- if len(range_conditions):
- sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
- if include_null:
- sql += " OR max_lifetime IS NULL"
-
- txn.execute(sql, args)
-
- rows = self.db_pool.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": row["min_lifetime"],
- "max_lifetime": row["max_lifetime"],
- }
-
- if include_null:
- # If required, do a second query that retrieves all of the rooms we know
- # of so we can handle rooms with no retention policy.
- sql = "SELECT DISTINCT room_id FROM current_state_events"
-
- txn.execute(sql)
-
- rows = self.db_pool.cursor_to_dict(txn)
-
- # If a room isn't already in the dict (i.e. it doesn't have a retention
- # policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": None,
- "max_lifetime": None,
- }
-
- return rooms_dict
-
- rooms = await self.db_pool.runInteraction(
- "get_rooms_for_retention_period_in_range",
- get_rooms_for_retention_period_in_range_txn,
- )
-
- return rooms
|