diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 97c82c150e..b2c76440b7 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -53,7 +53,7 @@ import urlparse
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
@@ -75,6 +75,7 @@ class MediaRepository(object):
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
+ self.recently_accessed_locals = set()
# List of StorageProviders where we should search for media and
# potentially upload to.
@@ -99,19 +100,34 @@ class MediaRepository(object):
)
self.clock.looping_call(
- self._update_recently_accessed_remotes,
- UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+ self._update_recently_accessed,
+ UPDATE_RECENTLY_ACCESSED_TS,
)
@defer.inlineCallbacks
- def _update_recently_accessed_remotes(self):
- media = self.recently_accessed_remotes
+ def _update_recently_accessed(self):
+ remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
+ local_media = self.recently_accessed_locals
+ self.recently_accessed_locals = set()
+
yield self.store.update_cached_last_access_time(
- media, self.clock.time_msec()
+ local_media, remote_media, self.clock.time_msec()
)
+ def mark_recently_accessed(self, server_name, media_id):
+ """Mark the given media as recently accessed.
+
+ Args:
+ server_name (str|None): Origin server of media, or None if local
+ media_id (str): The media ID of the content
+ """
+ if server_name:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ else:
+ self.recently_accessed_locals.add(media_id)
+
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
@@ -173,6 +189,8 @@ class MediaRepository(object):
respond_404(request)
return
+ self.mark_recently_accessed(None, media_id)
+
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
@@ -204,7 +222,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
- self.recently_accessed_remotes.add((server_name, media_id))
+ self.mark_recently_accessed(server_name, media_id)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
|