diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 539584288d..bb135166ce 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -463,14 +463,46 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
+ # This consists of two queries:
+ #
+ # 1. The sub-query searches for the next N devices and joins
+ # against user_ips to find the max last_seen associated with
+ # that device.
+ # 2. The outer query then joins again against user_ips on
+ # user/device/last_seen. This *should* hopefully only
+ # return one row, but if it does return more than one then
+ # we'll just end up updating the same device row multiple
+ # times, which is fine.
+
+ if self.database_engine.supports_tuple_comparison:
+ where_clause = "(user_id, device_id) > (?, ?)"
+ where_args = [last_user_id, last_device_id]
+ else:
+ # We explicitly do a `user_id >= ? AND (...)` here to ensure
+ # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
+ # makes it hard for query optimiser to tell that it can use the
+ # index on user_id
+ where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
+ where_args = [last_user_id, last_user_id, last_device_id]
+
sql = """
- SELECT u.last_seen, u.ip, u.user_agent, user_id, device_id FROM devices
- INNER JOIN user_ips AS u USING (user_id, device_id)
- WHERE user_id > ? OR (user_id = ? AND device_id > ?)
- ORDER BY user_id ASC, device_id ASC
- LIMIT ?
- """
- txn.execute(sql, (last_user_id, last_user_id, last_device_id, batch_size))
+ SELECT
+ last_seen, ip, user_agent, user_id, device_id
+ FROM (
+ SELECT
+ user_id, device_id, MAX(u.last_seen) AS last_seen
+ FROM devices
+ INNER JOIN user_ips AS u USING (user_id, device_id)
+ WHERE %(where_clause)s
+ GROUP BY user_id, device_id
+ ORDER BY user_id ASC, device_id ASC
+ LIMIT ?
+ ) c
+ INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
+ """ % {
+ "where_clause": where_clause
+ }
+ txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
|