summary refs log tree commit diff
path: root/synapse/storage/databases/main/media_repository.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/media_repository.py')
-rw-r--r--synapse/storage/databases/main/media_repository.py90
1 files changed, 88 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 3f80a64dc5..149135b8b5 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -49,13 +49,14 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
 class LocalMedia:
     media_id: str
     media_type: str
-    media_length: int
+    media_length: Optional[int]
     upload_name: str
     created_ts: int
     url_cache: Optional[str]
     last_access_ts: int
     quarantined_by: Optional[str]
     safe_from_quarantine: bool
+    user_id: Optional[str]
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -149,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             self._drop_media_index_without_method,
         )
 
+        if hs.config.media.can_load_media_repo:
+            self.unused_expiration_time: Optional[
+                int
+            ] = hs.config.media.unused_expiration_time
+        else:
+            self.unused_expiration_time = None
+
     async def _drop_media_index_without_method(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -202,6 +210,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "url_cache",
                 "last_access_ts",
                 "safe_from_quarantine",
+                "user_id",
             ),
             allow_none=True,
             desc="get_local_media",
@@ -218,6 +227,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             url_cache=row[5],
             last_access_ts=row[6],
             safe_from_quarantine=row[7],
+            user_id=row[8],
         )
 
     async def get_local_media_by_user_paginate(
@@ -272,7 +282,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     url_cache,
                     last_access_ts,
                     quarantined_by,
-                    safe_from_quarantine
+                    safe_from_quarantine,
+                    user_id
                 FROM local_media_repository
                 WHERE user_id = ?
                 ORDER BY {order_by_column} {order}, media_id ASC
@@ -295,6 +306,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     last_access_ts=row[6],
                     quarantined_by=row[7],
                     safe_from_quarantine=bool(row[8]),
+                    user_id=row[9],
                 )
                 for row in txn
             ]
@@ -392,6 +404,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     @trace
+    async def store_local_media_id(
+        self,
+        media_id: str,
+        time_now_ms: int,
+        user_id: UserID,
+    ) -> None:
+        await self.db_pool.simple_insert(
+            "local_media_repository",
+            {
+                "media_id": media_id,
+                "created_ts": time_now_ms,
+                "user_id": user_id.to_string(),
+            },
+            desc="store_local_media_id",
+        )
+
+    @trace
     async def store_local_media(
         self,
         media_id: str,
@@ -416,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
+    async def update_local_media(
+        self,
+        media_id: str,
+        media_type: str,
+        upload_name: Optional[str],
+        media_length: int,
+        user_id: UserID,
+        url_cache: Optional[str] = None,
+    ) -> None:
+        await self.db_pool.simple_update_one(
+            "local_media_repository",
+            keyvalues={
+                "user_id": user_id.to_string(),
+                "media_id": media_id,
+            },
+            updatevalues={
+                "media_type": media_type,
+                "upload_name": upload_name,
+                "media_length": media_length,
+                "url_cache": url_cache,
+            },
+            desc="update_local_media",
+        )
+
     async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
         """Mark a local media as safe or unsafe from quarantining."""
         await self.db_pool.simple_update_one(
@@ -425,6 +478,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
+    async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]:
+        """Count the number of pending media for a user.
+
+        Returns:
+            A tuple of two integers: the total pending media requests and the earliest
+            expiration timestamp.
+        """
+
+        def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]:
+            sql = """
+            SELECT COUNT(*), MIN(created_ts)
+            FROM local_media_repository
+            WHERE user_id = ?
+                AND created_ts > ?
+                AND media_length IS NULL
+            """
+            assert self.unused_expiration_time is not None
+            txn.execute(
+                sql,
+                (
+                    user_id.to_string(),
+                    self._clock.time_msec() - self.unused_expiration_time,
+                ),
+            )
+            row = txn.fetchone()
+            if not row:
+                return 0, 0
+            return row[0], (row[1] + self.unused_expiration_time if row[1] else 0)
+
+        return await self.db_pool.runInteraction(
+            "get_pending_media", get_pending_media_txn
+        )
+
     async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns: