summary refs log tree commit diff
path: root/synapse/storage/deviceinbox.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/deviceinbox.py')
-rw-r--r--synapse/storage/deviceinbox.py52
1 files changed, 43 insertions, 9 deletions
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index bde3b5cbbc..2714519d21 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -20,6 +20,8 @@ from twisted.internet import defer
 
 from .background_updates import BackgroundUpdateStore
 
+from synapse.util.caches.expiringcache import ExpiringCache
+
 
 logger = logging.getLogger(__name__)
 
@@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
             self._background_drop_index_device_inbox,
         )
 
+        # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no op deletions.
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
     @defer.inlineCallbacks
     def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
                                      remote_messages_by_destination):
@@ -167,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 )
                 txn.execute(sql, (user_id,))
                 message_json = ujson.dumps(messages_by_device["*"])
-                for row in txn.fetchall():
+                for row in txn:
                     # Add the message for all devices for this user on this
                     # server.
                     device = row[0]
@@ -184,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 # TODO: Maybe this needs to be done in batches if there are
                 # too many local devices for a given user.
                 txn.execute(sql, [user_id] + devices)
-                for row in txn.fetchall():
+                for row in txn:
                     # Only insert into the local inbox if the device exists on
                     # this server
                     device = row[0]
@@ -240,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 user_id, device_id, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
                 messages.append(ujson.loads(row[1]))
             if len(messages) < limit:
@@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             "get_new_messages_for_device", get_new_messages_for_device_txn,
         )
 
+    @defer.inlineCallbacks
     def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
         """
         Args:
@@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
         Returns:
             A deferred that resolves to the number of messages deleted.
         """
+        # If we have cached the last stream id we've deleted up to, we can
+        # check if there is likely to be anything that needs deleting
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), None
+        )
+        if last_deleted_stream_id:
+            has_changed = self._device_inbox_stream_cache.has_entity_changed(
+                user_id, last_deleted_stream_id
+            )
+            if not has_changed:
+                defer.returnValue(0)
+
         def delete_messages_for_device_txn(txn):
             sql = (
                 "DELETE FROM device_inbox"
@@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        return self.runInteraction(
+        count = yield self.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
+        # Update the cache, ensuring that we only ever increase the value
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), 0
+        )
+        self._last_device_delete_cache[(user_id, device_id)] = max(
+            last_deleted_stream_id, up_to_stream_id
+        )
+
+        defer.returnValue(count)
+
     def get_all_new_device_messages(self, last_pos, current_pos, limit):
         """
         Args:
@@ -306,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 " ORDER BY stream_id ASC"
             )
             txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn.fetchall())
+            rows.extend(txn)
 
             return rows
 
@@ -323,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
         """
         Args:
             destination(str): The name of the remote server.
-            last_stream_id(int): The last position of the device message stream
+            last_stream_id(int|long): The last position of the device message stream
                 that the server sent up to.
-            current_stream_id(int): The current position of the device
+            current_stream_id(int|long): The current position of the device
                 message stream.
         Returns:
-            Deferred ([dict], int): List of messages for the device and where
+            Deferred ([dict], int|long): List of messages for the device and where
                 in the stream the messages got to.
         """
 
@@ -350,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 destination, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
                 messages.append(ujson.loads(row[1]))
             if len(messages) < limit: