summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py55
1 files changed, 36 insertions, 19 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 49edbb9e06..b0811a4cf1 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         #
         # For each duplicate, we delete all the existing rows and put one back.
 
-        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
         last_row = progress.get(
             "last_row",
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
-                [(x, last_row[x]) for x in KEY_COLS]
+                [
+                    ("stream_id", last_row["stream_id"]),
+                    ("destination", last_row["destination"]),
+                    ("user_id", last_row["user_id"]),
+                    ("device_id", last_row["device_id"]),
+                ]
             )
-            sql = """
+            sql = f"""
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
                 FROM device_lists_outbound_pokes
-                WHERE %s
-                GROUP BY %s
+                WHERE {clause}
+                GROUP BY stream_id, destination, user_id, device_id
                 HAVING count(*) > 1
-                ORDER BY %s
+                ORDER BY stream_id, destination, user_id, device_id
                 LIMIT ?
-                """ % (
-                clause,  # WHERE
-                ",".join(KEY_COLS),  # GROUP BY
-                ",".join(KEY_COLS),  # ORDER BY
-            )
+                """
             txn.execute(sql, args + [batch_size])
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
-            row = None
-            for row in rows:
+            stream_id, destination, user_id, device_id = None, None, None, None
+            for stream_id, destination, user_id, device_id, _ in rows:
                 self.db_pool.simple_delete_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    {x: row[x] for x in KEY_COLS},
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
                 )
 
-                row["sent"] = False
                 self.db_pool.simple_insert_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    row,
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "sent": False,
+                    },
                 )
 
-            if row:
+            if rows:
                 self.db_pool.updates._background_update_progress_txn(
                     txn,
                     BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
-                    {"last_row": row},
+                    {
+                        "last_row": {
+                            "stream_id": stream_id,
+                            "destination": destination,
+                            "user_id": user_id,
+                            "device_id": device_id,
+                        }
+                    },
                 )
 
             return len(rows)