summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/database.py101
-rw-r--r--synapse/storage/databases/main/client_ips.py66
2 files changed, 132 insertions, 35 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 12750d9b89..5eb545c86e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1268,6 +1268,7 @@ class DatabasePool:
         value_names: Collection[str],
         value_values: Collection[Collection[Any]],
         desc: str,
+        lock: bool = True,
     ) -> None:
         """
         Upsert, many times.
@@ -1279,6 +1280,8 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
+            lock: True to lock the table when doing the upsert. Unused if the database engine
+                supports native upserts.
         """
 
         # We can autocommit if we are going to use native upserts
@@ -1286,7 +1289,7 @@ class DatabasePool:
             self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
         )
 
-        return await self.runInteraction(
+        await self.runInteraction(
             desc,
             self.simple_upsert_many_txn,
             table,
@@ -1294,6 +1297,7 @@ class DatabasePool:
             key_values,
             value_names,
             value_values,
+            lock=lock,
             db_autocommit=autocommit,
         )
 
@@ -1305,6 +1309,7 @@ class DatabasePool:
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
         value_values: Iterable[Iterable[Any]],
+        lock: bool = True,
     ) -> None:
         """
         Upsert, many times.
@@ -1316,6 +1321,8 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
+            lock: True to lock the table when doing the upsert. Unused if the database engine
+                supports native upserts.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
@@ -1323,7 +1330,7 @@ class DatabasePool:
             )
         else:
             return self.simple_upsert_many_txn_emulated(
-                txn, table, key_names, key_values, value_names, value_values
+                txn, table, key_names, key_values, value_names, value_values, lock=lock
             )
 
     def simple_upsert_many_txn_emulated(
@@ -1334,6 +1341,7 @@ class DatabasePool:
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
         value_values: Iterable[Iterable[Any]],
+        lock: bool = True,
     ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
@@ -1345,17 +1353,24 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
+            lock: True to lock the table when doing the upsert.
         """
         # No value columns, therefore make a blank list so that the following
         # zip() works correctly.
         if not value_names:
             value_values = [() for x in range(len(key_values))]
 
+        if lock:
+            # Lock the table just once, to prevent it being done once per row.
+            # Note that, according to Postgres' documentation, once obtained,
+            # the lock is held for the remainder of the current transaction.
+            self.engine.lock_table(txn, "user_ips")
+
         for keyv, valv in zip(key_values, value_values):
             _keys = {x: y for x, y in zip(key_names, keyv)}
             _vals = {x: y for x, y in zip(value_names, valv)}
 
-            self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+            self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
 
     def simple_upsert_many_txn_native_upsert(
         self,
@@ -1792,6 +1807,86 @@ class DatabasePool:
 
         return txn.rowcount
 
+    async def simple_update_many(
+        self,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+        desc: str,
+    ) -> None:
+        """
+        Update, many times, using batching where possible.
+        If the keys don't match anything, nothing will be updated.
+
+        Args:
+            table: The table to update
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The names of value columns to update.
+            value_values: A list of each row's value column values.
+        """
+
+        await self.runInteraction(
+            desc,
+            self.simple_update_many_txn,
+            table,
+            key_names,
+            key_values,
+            value_names,
+            value_values,
+        )
+
+    @staticmethod
+    def simple_update_many_txn(
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Collection[Iterable[Any]],
+    ) -> None:
+        """
+        Update, many times, using batching where possible.
+        If the keys don't match anything, nothing will be updated.
+
+        Args:
+            table: The table to update
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The names of value columns to update.
+            value_values: A list of each row's value column values.
+        """
+
+        if len(value_values) != len(key_values):
+            raise ValueError(
+                f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
+            )
+
+        # List of tuples of (value values, then key values)
+        # (This matches the order needed for the query)
+        args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
+
+        for ks, vs in zip(key_values, value_values):
+            args.append(tuple(vs) + tuple(ks))
+
+        # 'col1 = ?, col2 = ?, ...'
+        set_clause = ", ".join(f"{n} = ?" for n in value_names)
+
+        if key_names:
+            # 'WHERE col3 = ? AND col4 = ? AND col5 = ?'
+            where_clause = "WHERE " + (" AND ".join(f"{n} = ?" for n in key_names))
+        else:
+            where_clause = ""
+
+        # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
+        sql = f"""
+            UPDATE {table} SET {set_clause} {where_clause}
+        """
+
+        txn.execute_batch(sql, args)
+
     async def simple_update_one(
         self,
         table: str,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8480ea4e1c..0df160d2b0 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -616,9 +616,10 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
         to_update = self._batch_row_update
         self._batch_row_update = {}
 
-        await self.db_pool.runInteraction(
-            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
-        )
+        if to_update:
+            await self.db_pool.runInteraction(
+                "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+            )
 
     def _update_client_ips_batch_txn(
         self,
@@ -629,42 +630,43 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
             self._update_on_this_worker
         ), "This worker is not designated to update client IPs"
 
-        if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
-            not self.database_engine.can_native_upsert
-        ):
-            self.database_engine.lock_table(txn, "user_ips")
+        # Keys and values for the `user_ips` upsert.
+        user_ips_keys = []
+        user_ips_values = []
+
+        # Keys and values for the `devices` update.
+        devices_keys = []
+        devices_values = []
 
         for entry in to_update.items():
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
-
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="user_ips",
-                keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
-                values={
-                    "user_agent": user_agent,
-                    "device_id": device_id,
-                    "last_seen": last_seen,
-                },
-                lock=False,
-            )
+            user_ips_keys.append((user_id, access_token, ip))
+            user_ips_values.append((user_agent, device_id, last_seen))
 
             # Technically an access token might not be associated with
             # a device so we need to check.
             if device_id:
-                # this is always an update rather than an upsert: the row should
-                # already exist, and if it doesn't, that may be because it has been
-                # deleted, and we don't want to re-create it.
-                self.db_pool.simple_update_txn(
-                    txn,
-                    table="devices",
-                    keyvalues={"user_id": user_id, "device_id": device_id},
-                    updatevalues={
-                        "user_agent": user_agent,
-                        "last_seen": last_seen,
-                        "ip": ip,
-                    },
-                )
+                devices_keys.append((user_id, device_id))
+                devices_values.append((user_agent, last_seen, ip))
+
+        self.db_pool.simple_upsert_many_txn(
+            txn,
+            table="user_ips",
+            key_names=("user_id", "access_token", "ip"),
+            key_values=user_ips_keys,
+            value_names=("user_agent", "device_id", "last_seen"),
+            value_values=user_ips_values,
+        )
+
+        if devices_values:
+            self.db_pool.simple_update_many_txn(
+                txn,
+                table="devices",
+                key_names=("user_id", "device_id"),
+                key_values=devices_keys,
+                value_names=("user_agent", "last_seen", "ip"),
+                value_values=devices_values,
+            )
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]