diff --git a/changelog.d/12252.feature b/changelog.d/12252.feature
new file mode 100644
index 0000000000..82b9e82f86
--- /dev/null
+++ b/changelog.d/12252.feature
@@ -0,0 +1 @@
+Move `update_client_ip` background job from the main process to the background worker.
\ No newline at end of file
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 12750d9b89..5eb545c86e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1268,6 +1268,7 @@ class DatabasePool:
value_names: Collection[str],
value_values: Collection[Collection[Any]],
desc: str,
+ lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1279,6 +1280,8 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert. Unused if the database engine
+ supports native upserts.
"""
# We can autocommit if we are going to use native upserts
@@ -1286,7 +1289,7 @@ class DatabasePool:
self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
)
- return await self.runInteraction(
+ await self.runInteraction(
desc,
self.simple_upsert_many_txn,
table,
@@ -1294,6 +1297,7 @@ class DatabasePool:
key_values,
value_names,
value_values,
+ lock=lock,
db_autocommit=autocommit,
)
@@ -1305,6 +1309,7 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
+ lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1316,6 +1321,8 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert. Unused if the database engine
+ supports native upserts.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -1323,7 +1330,7 @@ class DatabasePool:
)
else:
return self.simple_upsert_many_txn_emulated(
- txn, table, key_names, key_values, value_names, value_values
+ txn, table, key_names, key_values, value_names, value_values, lock=lock
)
def simple_upsert_many_txn_emulated(
@@ -1334,6 +1341,7 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
+ lock: bool = True,
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
@@ -1345,17 +1353,24 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
if not value_names:
value_values = [() for x in range(len(key_values))]
+ if lock:
+ # Lock the table just once, to prevent it being done once per row.
+ # Note that, according to Postgres' documentation, once obtained,
+ # the lock is held for the remainder of the current transaction.
+            self.engine.lock_table(txn, table)
+
for keyv, valv in zip(key_values, value_values):
_keys = {x: y for x, y in zip(key_names, keyv)}
_vals = {x: y for x, y in zip(value_names, valv)}
- self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
def simple_upsert_many_txn_native_upsert(
self,
@@ -1792,6 +1807,86 @@ class DatabasePool:
return txn.rowcount
+ async def simple_update_many(
+ self,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ desc: str,
+ ) -> None:
+ """
+ Update, many times, using batching where possible.
+ If the keys don't match anything, nothing will be updated.
+
+ Args:
+ table: The table to update
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The names of value columns to update.
+ value_values: A list of each row's value column values.
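+                Must contain the same number of rows as key_values.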
+ """
+
+ await self.runInteraction(
+ desc,
+ self.simple_update_many_txn,
+ table,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
+ )
+
+ @staticmethod
+ def simple_update_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Collection[Iterable[Any]],
+ ) -> None:
+ """
+ Update, many times, using batching where possible.
+ If the keys don't match anything, nothing will be updated.
+
+ Args:
+ table: The table to update
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The names of value columns to update.
+ value_values: A list of each row's value column values.
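+                Must contain the same number of rows as key_values.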
+ """
+
+ if len(value_values) != len(key_values):
+ raise ValueError(
+ f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
+ )
+
+ # List of tuples of (value values, then key values)
+ # (This matches the order needed for the query)
+ args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
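+        # For example, with value_names=("value",) and key_names=("username",),
+        # each tuple is ("new value", "username"), matching the placeholder
+        # order in "SET value = ? WHERE username = ?".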
+
+ # 'col1 = ?, col2 = ?, ...'
+ set_clause = ", ".join(f"{n} = ?" for n in value_names)
+
+ if key_names:
+ # 'WHERE col3 = ? AND col4 = ? AND col5 = ?'
+ where_clause = "WHERE " + (" AND ".join(f"{n} = ?" for n in key_names))
+ else:
+ where_clause = ""
+
+ # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
+ sql = f"""
+ UPDATE {table} SET {set_clause} {where_clause}
+ """
+
+ txn.execute_batch(sql, args)
+
async def simple_update_one(
self,
table: str,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8480ea4e1c..0df160d2b0 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -616,9 +616,10 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
to_update = self._batch_row_update
self._batch_row_update = {}
- await self.db_pool.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ if to_update:
+ await self.db_pool.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(
self,
@@ -629,42 +630,43 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
self._update_on_this_worker
), "This worker is not designated to update client IPs"
- if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
- not self.database_engine.can_native_upsert
- ):
- self.database_engine.lock_table(txn, "user_ips")
+ # Keys and values for the `user_ips` upsert.
+ user_ips_keys = []
+ user_ips_values = []
+
+ # Keys and values for the `devices` update.
+ devices_keys = []
+ devices_values = []
for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
-
- self.db_pool.simple_upsert_txn(
- txn,
- table="user_ips",
- keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
- values={
- "user_agent": user_agent,
- "device_id": device_id,
- "last_seen": last_seen,
- },
- lock=False,
- )
+ user_ips_keys.append((user_id, access_token, ip))
+ user_ips_values.append((user_agent, device_id, last_seen))
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- # this is always an update rather than an upsert: the row should
- # already exist, and if it doesn't, that may be because it has been
- # deleted, and we don't want to re-create it.
- self.db_pool.simple_update_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- updatevalues={
- "user_agent": user_agent,
- "last_seen": last_seen,
- "ip": ip,
- },
- )
+ devices_keys.append((user_id, device_id))
+ devices_values.append((user_agent, last_seen, ip))
+
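+        # Any required locking of `user_ips` is now handled inside
+        # simple_upsert_many_txn, so we no longer lock the table explicitly here.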
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="user_ips",
+ key_names=("user_id", "access_token", "ip"),
+ key_values=user_ips_keys,
+ value_names=("user_agent", "device_id", "last_seen"),
+ value_values=user_ips_values,
+ )
+
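+        # These are always updates rather than upserts: the `devices` rows should
+        # already exist, and if they don't, that may be because they have been
+        # deleted and we don't want to re-create them.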
+ if devices_values:
+ self.db_pool.simple_update_many_txn(
+ txn,
+ table="devices",
+ key_names=("user_id", "device_id"),
+ key_values=devices_keys,
+ value_names=("user_agent", "last_seen", "ip"),
+ value_values=devices_values,
+ )
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 366398e39d..09cb06d614 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import secrets
-from typing import Any, Dict, Generator, List, Tuple
+from typing import Generator, Tuple
from twisted.test.proto_helpers import MemoryReactor
@@ -24,7 +24,7 @@ from synapse.util import Clock
from tests import unittest
-class UpsertManyTests(unittest.HomeserverTestCase):
+class UpdateUpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.storage = hs.get_datastores().main
@@ -46,9 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
- def _dump_to_tuple(
- self, res: List[Dict[str, Any]]
- ) -> Generator[Tuple[int, str, str], None, None]:
+ def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
+ res = self.get_success(
+ self.storage.db_pool.simple_select_list(
+ self.table_name, None, ["id, username, value"]
+ )
+ )
+
for i in res:
yield (i["id"], i["username"], i["value"])
@@ -75,13 +79,8 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
# Check results are what we expect
- res = self.get_success(
- self.storage.db_pool.simple_select_list(
- self.table_name, None, ["id, username, value"]
- )
- )
self.assertEqual(
- set(self._dump_to_tuple(res)),
+ set(self._dump_table_to_tuple()),
{(1, "user1", "hello"), (2, "user2", "there")},
)
@@ -102,12 +101,54 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
# Check results are what we expect
- res = self.get_success(
- self.storage.db_pool.simple_select_list(
- self.table_name, None, ["id, username, value"]
+ self.assertEqual(
+ set(self._dump_table_to_tuple()),
+ {(1, "user1", "hello"), (2, "user2", "bleb")},
+ )
+
+ def test_simple_update_many(self):
+ """
+ simple_update_many performs many updates at once.
+ """
+ # First add some data.
+ self.get_success(
+ self.storage.db_pool.simple_insert_many(
+ table=self.table_name,
+ keys=("id", "username", "value"),
+ values=[(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")],
+ desc="insert",
)
)
+
+ # Check the data made it to the table
self.assertEqual(
- set(self._dump_to_tuple(res)),
- {(1, "user1", "hello"), (2, "user2", "bleb")},
+ set(self._dump_table_to_tuple()),
+ {(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")},
+ )
+
+ # Now use simple_update_many
+ self.get_success(
+ self.storage.db_pool.simple_update_many(
+ table=self.table_name,
+ key_names=("username",),
+ key_values=(
+ ("alice",),
+ ("bob",),
+ ("stranger",),
+ ),
+ value_names=("value",),
+ value_values=(
+ ("aaa!",),
+ ("bbb!",),
+ ("???",),
+ ),
+ desc="update_many1",
+ )
+ )
+
+ # Check the table is how we expect:
+ # charlie has been left alone
+ self.assertEqual(
+ set(self._dump_table_to_tuple()),
+ {(1, "alice", "aaa!"), (2, "bob", "bbb!"), (3, "charlie", "C")},
)
|