summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/profile.py82
-rw-r--r--synapse/storage/databases/main/registration.py52
-rw-r--r--synapse/storage/databases/main/room.py168
3 files changed, 150 insertions, 152 deletions
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..1681caa1f0 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="set_profile_avatar_url",
         )
 
-
-class ProfileStore(ProfileWorkerStore):
-    async def add_remote_profile_cache(
-        self, user_id: str, displayname: str, avatar_url: str
-    ) -> None:
-        """Ensure we are caching the remote user's profiles.
-
-        This should only be called when `is_subscribed_remote_profile_for_user`
-        would return true for the user.
-        """
-        await self.db_pool.simple_upsert(
-            table="remote_profile_cache",
-            keyvalues={"user_id": user_id},
-            values={
-                "displayname": displayname,
-                "avatar_url": avatar_url,
-                "last_check": self._clock.time_msec(),
-            },
-            desc="add_remote_profile_cache",
-        )
-
     async def update_remote_profile_cache(
         self, user_id: str, displayname: str, avatar_url: str
     ) -> int:
@@ -138,6 +117,31 @@ class ProfileStore(ProfileWorkerStore):
                 desc="delete_remote_profile_cache",
             )
 
+    async def is_subscribed_remote_profile_for_user(self, user_id):
+        """Check whether we are interested in a remote user's profile.
+        """
+        res = await self.db_pool.simple_select_one_onecol(
+            table="group_users",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            return True
+
+        res = await self.db_pool.simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            return True
+
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
     ) -> Dict[str, str]:
@@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore):
             _get_remote_profile_cache_entries_that_expire_txn,
         )
 
-    async def is_subscribed_remote_profile_for_user(self, user_id):
-        """Check whether we are interested in a remote user's profile.
-        """
-        res = await self.db_pool.simple_select_one_onecol(
-            table="group_users",
-            keyvalues={"user_id": user_id},
-            retcol="user_id",
-            allow_none=True,
-            desc="should_update_remote_profile_cache_for_user",
-        )
 
-        if res:
-            return True
+class ProfileStore(ProfileWorkerStore):
+    async def add_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> None:
+        """Ensure we are caching the remote user's profiles.
 
-        res = await self.db_pool.simple_select_one_onecol(
-            table="group_invites",
+        This should only be called when `is_subscribed_remote_profile_for_user`
+        would return true for the user.
+        """
+        await self.db_pool.simple_upsert(
+            table="remote_profile_cache",
             keyvalues={"user_id": user_id},
-            retcol="user_id",
-            allow_none=True,
-            desc="should_update_remote_profile_cache_for_user",
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="add_remote_profile_cache",
         )
-
-        if res:
-            return True
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a003e30f9..4c843b7679 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -862,6 +862,32 @@ class RegistrationWorkerStore(SQLBaseStore):
             values={"expiration_ts_ms": expiration_ts, "email_sent": False},
         )
 
+    async def get_user_pending_deactivation(self) -> Optional[str]:
+        """
+        Gets one user from the table of users waiting to be parted from all the rooms
+        they're in.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            "users_pending_deactivation",
+            keyvalues={},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_users_pending_deactivation",
+        )
+
+    async def del_user_pending_deactivation(self, user_id: str) -> None:
+        """
+        Removes the given user to the table of users who need to be parted from all the
+        rooms they're in, effectively marking that user as fully deactivated.
+        """
+        # XXX: This should be simple_delete_one but we failed to put a unique index on
+        # the table, so somehow duplicate entries have ended up in it.
+        await self.db_pool.simple_delete(
+            "users_pending_deactivation",
+            keyvalues={"user_id": user_id},
+            desc="del_user_pending_deactivation",
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1371,32 +1397,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_user_pending_deactivation",
         )
 
-    async def del_user_pending_deactivation(self, user_id: str) -> None:
-        """
-        Removes the given user to the table of users who need to be parted from all the
-        rooms they're in, effectively marking that user as fully deactivated.
-        """
-        # XXX: This should be simple_delete_one but we failed to put a unique index on
-        # the table, so somehow duplicate entries have ended up in it.
-        await self.db_pool.simple_delete(
-            "users_pending_deactivation",
-            keyvalues={"user_id": user_id},
-            desc="del_user_pending_deactivation",
-        )
-
-    async def get_user_pending_deactivation(self) -> Optional[str]:
-        """
-        Gets one user from the table of users waiting to be parted from all the rooms
-        they're in.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            "users_pending_deactivation",
-            keyvalues={},
-            retcol="user_id",
-            allow_none=True,
-            desc="get_users_pending_deactivation",
-        )
-
     async def validate_threepid_session(
         self, session_id: str, client_secret: str, token: str, current_ts: int
     ) -> Optional[str]:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index c0f2af0785..e83d961c20 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -869,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
 
+    async def get_rooms_for_retention_period_in_range(
+        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+    ) -> Dict[str, dict]:
+        """Retrieves all of the rooms within the given retention range.
+
+        Optionally includes the rooms which don't have a retention policy.
+
+        Args:
+            min_ms: Duration in milliseconds that define the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms: Duration in milliseconds that define the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null: Whether to include rooms which retention policy is NULL
+                in the returned set.
+
+        Returns:
+            The rooms within this range, along with their retention
+            policy. The key is "room_id", and maps to a dict describing the retention
+            policy associated with this room ID. The keys for this nested dict are
+            "min_lifetime" (int|None), and "max_lifetime" (int|None).
+        """
+
+        def get_rooms_for_retention_period_in_range_txn(txn):
+            range_conditions = []
+            args = []
+
+            if min_ms is not None:
+                range_conditions.append("max_lifetime > ?")
+                args.append(min_ms)
+
+            if max_ms is not None:
+                range_conditions.append("max_lifetime <= ?")
+                args.append(max_ms)
+
+            # Do a first query which will retrieve the rooms that have a retention policy
+            # in their current state.
+            sql = """
+                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                """
+
+            if len(range_conditions):
+                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+                if include_null:
+                    sql += " OR max_lifetime IS NULL"
+
+            txn.execute(sql, args)
+
+            rows = self.db_pool.cursor_to_dict(txn)
+            rooms_dict = {}
+
+            for row in rows:
+                rooms_dict[row["room_id"]] = {
+                    "min_lifetime": row["min_lifetime"],
+                    "max_lifetime": row["max_lifetime"],
+                }
+
+            if include_null:
+                # If required, do a second query that retrieves all of the rooms we know
+                # of so we can handle rooms with no retention policy.
+                sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+                txn.execute(sql)
+
+                rows = self.db_pool.cursor_to_dict(txn)
+
+                # If a room isn't already in the dict (i.e. it doesn't have a retention
+                # policy in its state), add it with a null policy.
+                for row in rows:
+                    if row["room_id"] not in rooms_dict:
+                        rooms_dict[row["room_id"]] = {
+                            "min_lifetime": None,
+                            "max_lifetime": None,
+                        }
+
+            return rooms_dict
+
+        return await self.db_pool.runInteraction(
+            "get_rooms_for_retention_period_in_range",
+            get_rooms_for_retention_period_in_range_txn,
+        )
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1446,88 +1529,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             self.is_room_blocked,
             (room_id,),
         )
-
-    async def get_rooms_for_retention_period_in_range(
-        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
-        """Retrieves all of the rooms within the given retention range.
-
-        Optionally includes the rooms which don't have a retention policy.
-
-        Args:
-            min_ms: Duration in milliseconds that define the lower limit of
-                the range to handle (exclusive). If None, doesn't set a lower limit.
-            max_ms: Duration in milliseconds that define the upper limit of
-                the range to handle (inclusive). If None, doesn't set an upper limit.
-            include_null: Whether to include rooms which retention policy is NULL
-                in the returned set.
-
-        Returns:
-            The rooms within this range, along with their retention
-            policy. The key is "room_id", and maps to a dict describing the retention
-            policy associated with this room ID. The keys for this nested dict are
-            "min_lifetime" (int|None), and "max_lifetime" (int|None).
-        """
-
-        def get_rooms_for_retention_period_in_range_txn(txn):
-            range_conditions = []
-            args = []
-
-            if min_ms is not None:
-                range_conditions.append("max_lifetime > ?")
-                args.append(min_ms)
-
-            if max_ms is not None:
-                range_conditions.append("max_lifetime <= ?")
-                args.append(max_ms)
-
-            # Do a first query which will retrieve the rooms that have a retention policy
-            # in their current state.
-            sql = """
-                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
-                INNER JOIN current_state_events USING (event_id, room_id)
-                """
-
-            if len(range_conditions):
-                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
-                if include_null:
-                    sql += " OR max_lifetime IS NULL"
-
-            txn.execute(sql, args)
-
-            rows = self.db_pool.cursor_to_dict(txn)
-            rooms_dict = {}
-
-            for row in rows:
-                rooms_dict[row["room_id"]] = {
-                    "min_lifetime": row["min_lifetime"],
-                    "max_lifetime": row["max_lifetime"],
-                }
-
-            if include_null:
-                # If required, do a second query that retrieves all of the rooms we know
-                # of so we can handle rooms with no retention policy.
-                sql = "SELECT DISTINCT room_id FROM current_state_events"
-
-                txn.execute(sql)
-
-                rows = self.db_pool.cursor_to_dict(txn)
-
-                # If a room isn't already in the dict (i.e. it doesn't have a retention
-                # policy in its state), add it with a null policy.
-                for row in rows:
-                    if row["room_id"] not in rooms_dict:
-                        rooms_dict[row["room_id"]] = {
-                            "min_lifetime": None,
-                            "max_lifetime": None,
-                        }
-
-            return rooms_dict
-
-        rooms = await self.db_pool.runInteraction(
-            "get_rooms_for_retention_period_in_range",
-            get_rooms_for_retention_period_in_range_txn,
-        )
-
-        return rooms