summary refs log tree commit diff
path: root/synapse/storage/deviceinbox.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/deviceinbox.py')
-rw-r--r--synapse/storage/deviceinbox.py29
1 files changed, 16 insertions, 13 deletions
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index f640e73714..2821eb89c9 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -242,7 +242,7 @@ class DeviceInboxStore(SQLBaseStore):
             device_id(str): The recipient device_id.
             up_to_stream_id(int): Where to delete messages up to.
         Returns:
-            A deferred that resolves when the messages have been deleted.
+            A deferred that resolves to the number of messages deleted.
         """
         def delete_messages_for_device_txn(txn):
             sql = (
@@ -251,6 +251,7 @@ class DeviceInboxStore(SQLBaseStore):
                 " AND stream_id <= ?"
             )
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
+            return txn.rowcount
 
         return self.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
@@ -269,27 +270,29 @@ class DeviceInboxStore(SQLBaseStore):
             return defer.succeed([])
 
         def get_all_new_device_messages_txn(txn):
+            # We limit like this as we might have multiple rows per stream_id, and
+            # we want to make sure we always get all entries for any stream_id
+            # we return.
+            upper_pos = min(current_pos, last_pos + limit)
             sql = (
-                "SELECT stream_id FROM device_inbox"
+                "SELECT stream_id, user_id"
+                " FROM device_inbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY stream_id"
                 " ORDER BY stream_id ASC"
-                " LIMIT ?"
             )
-            txn.execute(sql, (last_pos, current_pos, limit))
-            stream_ids = txn.fetchall()
-            if not stream_ids:
-                return []
-            max_stream_id_in_limit = stream_ids[-1]
+            txn.execute(sql, (last_pos, upper_pos))
+            rows = txn.fetchall()
 
             sql = (
-                "SELECT stream_id, user_id, device_id, message_json"
-                " FROM device_inbox"
+                "SELECT stream_id, destination"
+                " FROM device_federation_outbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC"
             )
-            txn.execute(sql, (last_pos, max_stream_id_in_limit))
-            return txn.fetchall()
+            txn.execute(sql, (last_pos, upper_pos))
+            rows.extend(txn.fetchall())
+
+            return rows
 
         return self.runInteraction(
             "get_all_new_device_messages", get_all_new_device_messages_txn