diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index bde3b5cbbc..2714519d21 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -20,6 +20,8 @@ from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
+from synapse.util.caches.expiringcache import ExpiringCache
+
logger = logging.getLogger(__name__)
@@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
self._background_drop_index_device_inbox,
)
+ # Map of (user_id, device_id) to the last stream_id that has been
+ # deleted up to. This is so that we can no op deletions.
+ self._last_device_delete_cache = ExpiringCache(
+ cache_name="last_device_delete_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ )
+
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
remote_messages_by_destination):
@@ -167,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"])
- for row in txn.fetchall():
+ for row in txn:
# Add the message for all devices for this user on this
# server.
device = row[0]
@@ -184,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + devices)
- for row in txn.fetchall():
+ for row in txn:
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
@@ -240,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
@@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
+ @defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
@@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
Returns:
A deferred that resolves to the number of messages deleted.
"""
+ # If we have cached the last stream id we've deleted up to, we can
+ # check if there is likely to be anything that needs deleting
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), None
+ )
+ if last_deleted_stream_id:
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_deleted_stream_id
+ )
+ if not has_changed:
+ defer.returnValue(0)
+
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
@@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- return self.runInteraction(
+ count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
+ # Update the cache, ensuring that we only ever increase the value
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), 0
+ )
+ self._last_device_delete_cache[(user_id, device_id)] = max(
+ last_deleted_stream_id, up_to_stream_id
+ )
+
+ defer.returnValue(count)
+
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
@@ -306,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" ORDER BY stream_id ASC"
)
txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn.fetchall())
+ rows.extend(txn)
return rows
@@ -323,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
"""
Args:
destination(str): The name of the remote server.
- last_stream_id(int): The last position of the device message stream
+ last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
- current_stream_id(int): The current position of the device
+ current_stream_id(int|long): The current position of the device
message stream.
Returns:
- Deferred ([dict], int): List of messages for the device and where
+ Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
@@ -350,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
|