summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2023-04-26 16:03:26 -0700
committerGitHub <noreply@github.com>2023-04-26 16:03:26 -0700
commit301b4156d5574521e4fa3df8fed2f8a1c8617745 (patch)
tree2f9f837444c8802451f51b1a165f6bcdf181fcfd /synapse/storage/databases/main
parentAdd a module API to send an HTTP push notification (#15387) (diff)
downloadsynapse-301b4156d5574521e4fa3df8fed2f8a1c8617745.tar.xz
Add column `full_user_id` to tables `profiles` and `user_filters`. (#15458)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/filtering.py47
-rw-r--r--synapse/storage/databases/main/profile.py42
-rw-r--r--synapse/storage/databases/main/registration.py4
3 files changed, 76 insertions, 17 deletions
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 8e57c8e5a0..50516402f9 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -16,15 +16,38 @@
 from typing import Optional, Tuple, Union, cast
 
 from canonicaljson import encode_canonical_json
+from typing_extensions import TYPE_CHECKING
 
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class FilteringWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+        self.db_pool.updates.register_background_index_update(
+            "full_users_filters_unique_idx",
+            index_name="full_users_unique_idx",
+            table="user_filters",
+            columns=["full_user_id, filter_id"],
+            unique=True,
+        )
+
     @cached(num_args=2)
     async def get_user_filter(
         self, user_localpart: str, filter_id: Union[int, str]
@@ -46,7 +69,7 @@ class FilteringWorkerStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
+    async def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> int:
         def_json = encode_canonical_json(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then
@@ -56,13 +79,13 @@ class FilteringWorkerStore(SQLBaseStore):
                 "SELECT filter_id FROM user_filters "
                 "WHERE user_id = ? AND filter_json = ?"
             )
-            txn.execute(sql, (user_localpart, bytearray(def_json)))
+            txn.execute(sql, (user_id.localpart, bytearray(def_json)))
             filter_id_response = txn.fetchone()
             if filter_id_response is not None:
                 return filter_id_response[0]
 
             sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
-            txn.execute(sql, (user_localpart,))
+            txn.execute(sql, (user_id.localpart,))
             max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
             if max_id is None:
                 filter_id = 0
@@ -70,10 +93,18 @@ class FilteringWorkerStore(SQLBaseStore):
                 filter_id = max_id + 1
 
             sql = (
-                "INSERT INTO user_filters (user_id, filter_id, filter_json)"
-                "VALUES(?, ?, ?)"
+                "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
+                "VALUES(?, ?, ?, ?)"
+            )
+            txn.execute(
+                sql,
+                (
+                    user_id.to_string(),
+                    user_id.localpart,
+                    filter_id,
+                    bytearray(def_json),
+                ),
             )
-            txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
 
             return filter_id
 
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index a1747f04ce..b109f8c07f 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -11,14 +11,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.roommember import ProfileInfo
+from synapse.types import UserID
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 
 class ProfileWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+        self.db_pool.updates.register_background_index_update(
+            "profiles_full_user_id_key_idx",
+            index_name="profiles_full_user_id_key",
+            table="profiles",
+            columns=["full_user_id"],
+            unique=True,
+        )
+
     async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
         try:
             profile = await self.db_pool.simple_select_one(
@@ -54,28 +74,36 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="get_profile_avatar_url",
         )
 
-    async def create_profile(self, user_localpart: str) -> None:
+    async def create_profile(self, user_id: UserID) -> None:
+        user_localpart = user_id.localpart
         await self.db_pool.simple_insert(
-            table="profiles", values={"user_id": user_localpart}, desc="create_profile"
+            table="profiles",
+            values={"user_id": user_localpart, "full_user_id": user_id.to_string()},
+            desc="create_profile",
         )
 
     async def set_profile_displayname(
-        self, user_localpart: str, new_displayname: Optional[str]
+        self, user_id: UserID, new_displayname: Optional[str]
     ) -> None:
+        user_localpart = user_id.localpart
         await self.db_pool.simple_upsert(
             table="profiles",
             keyvalues={"user_id": user_localpart},
-            values={"displayname": new_displayname},
+            values={
+                "displayname": new_displayname,
+                "full_user_id": user_id.to_string(),
+            },
             desc="set_profile_displayname",
         )
 
     async def set_profile_avatar_url(
-        self, user_localpart: str, new_avatar_url: Optional[str]
+        self, user_id: UserID, new_avatar_url: Optional[str]
     ) -> None:
+        user_localpart = user_id.localpart
         await self.db_pool.simple_upsert(
             table="profiles",
             keyvalues={"user_id": user_localpart},
-            values={"avatar_url": new_avatar_url},
+            values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
             desc="set_profile_avatar_url",
         )
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 717237e024..676d03bb7e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -2414,8 +2414,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             # *obviously* the 'profiles' table uses localpart for user_id
             # while everything else uses the full mxid.
             txn.execute(
-                "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
-                (user_id_obj.localpart, create_profile_with_displayname),
+                "INSERT INTO profiles(full_user_id, user_id, displayname) VALUES (?,?,?)",
+                (user_id, user_id_obj.localpart, create_profile_with_displayname),
             )
 
         if self.hs.config.stats.stats_enabled: