summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/deviceinbox.py69
1 files changed, 32 insertions, 37 deletions
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 68116b0394..57202a5bda 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -37,9 +37,21 @@ class DeviceInboxStore(SQLBaseStore):
             inserted.
         """
 
-        def select_devices_txn(txn, user_id, devices):
-            if not devices:
-                return []
+        with self._device_inbox_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_messages_to_device_inbox",
+                self._add_messages_to_device_inbox_txn,
+                stream_id,
+                messages_by_user_then_device,
+            )
+
+        defer.returnValue(self._device_inbox_id_gen.get_current_token())
+
+    def _add_messages_to_device_inbox_txn(self, txn, stream_id,
+                                          messages_by_user_then_device):
+        local_users_and_devices = set()
+        for user_id, messages_by_device in messages_by_user_then_device.items():
+            devices = messages_by_device.keys()
             sql = (
                 "SELECT user_id, device_id FROM devices"
                 " WHERE user_id = ? AND device_id IN ("
@@ -48,41 +60,24 @@ class DeviceInboxStore(SQLBaseStore):
             )
             # TODO: Maybe this needs to be done in batches if there are
             # too many local devices for a given user.
-            args = [user_id] + devices
-            txn.execute(sql, args)
-            return [tuple(row) for row in txn.fetchall()]
-
-        def add_messages_to_device_inbox_txn(txn, stream_id):
-            local_users_and_devices = set()
-            for user_id, messages_by_device in messages_by_user_then_device.items():
-                local_users_and_devices.update(
-                    select_devices_txn(txn, user_id, messages_by_device.keys())
-                )
-
-            sql = (
-                "INSERT INTO device_inbox"
-                " (user_id, device_id, stream_id, message_json)"
-                " VALUES (?,?,?,?)"
-            )
-            rows = []
-            for user_id, messages_by_device in messages_by_user_then_device.items():
-                for device_id, message in messages_by_device.items():
-                    message_json = ujson.dumps(message)
-                    # Only insert into the local inbox if the device exists on
-                    # this server
-                    if (user_id, device_id) in local_users_and_devices:
-                        rows.append((user_id, device_id, stream_id, message_json))
-
-            txn.executemany(sql, rows)
-
-        with self._device_inbox_id_gen.get_next() as stream_id:
-            yield self.runInteraction(
-                "add_messages_to_device_inbox",
-                add_messages_to_device_inbox_txn,
-                stream_id
-            )
+            txn.execute(sql, [user_id] + devices)
+            local_users_and_devices.update(map(tuple, txn.fetchall()))
 
-        defer.returnValue(self._device_inbox_id_gen.get_current_token())
+        sql = (
+            "INSERT INTO device_inbox"
+            " (user_id, device_id, stream_id, message_json)"
+            " VALUES (?,?,?,?)"
+        )
+        rows = []
+        for user_id, messages_by_device in messages_by_user_then_device.items():
+            for device_id, message in messages_by_device.items():
+                message_json = ujson.dumps(message)
+                # Only insert into the local inbox if the device exists on
+                # this server
+                if (user_id, device_id) in local_users_and_devices:
+                    rows.append((user_id, device_id, stream_id, message_json))
+
+        txn.executemany(sql, rows)
 
     def get_new_messages_for_device(
         self, user_id, device_id, last_stream_id, current_stream_id, limit=100