summary refs log tree commit diff
path: root/tests/rest/media/test_media_retention.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/media/test_media_retention.py')
-rw-r--r--tests/rest/media/test_media_retention.py102
1 files changed, 38 insertions, 64 deletions
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 14af07c5af..23f227aed6 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import io
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional
+
+from matrix_common.types.mxc_uri import MXCUri
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -63,9 +65,9 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
             last_accessed_ms: Optional[int],
             is_quarantined: Optional[bool] = False,
             is_protected: Optional[bool] = False,
-        ) -> str:
+        ) -> MXCUri:
             # "Upload" some media to the local media store
-            mxc_uri = self.get_success(
+            mxc_uri: MXCUri = self.get_success(
                 media_repository.create_content(
                     media_type="text/plain",
                     upload_name=None,
@@ -75,13 +77,11 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
                 )
             )
 
-            media_id = mxc_uri.split("/")[-1]
-
             # Set the last recently accessed time for this media
             if last_accessed_ms is not None:
                 self.get_success(
                     self.store.update_cached_last_access_time(
-                        local_media=(media_id,),
+                        local_media=(mxc_uri.media_id,),
                         remote_media=(),
                         time_ms=last_accessed_ms,
                     )
@@ -92,7 +92,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
                 self.get_success(
                     self.store.quarantine_media_by_id(
                         server_name=self.hs.config.server.server_name,
-                        media_id=media_id,
+                        media_id=mxc_uri.media_id,
                         quarantined_by="@theadmin:test",
                     )
                 )
@@ -101,18 +101,18 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
                 # Mark this media as protected from quarantine
                 self.get_success(
                     self.store.mark_local_media_as_safe(
-                        media_id=media_id,
+                        media_id=mxc_uri.media_id,
                         safe=True,
                     )
                 )
 
-            return media_id
+            return mxc_uri
 
         def _cache_remote_media_and_set_attributes(
             media_id: str,
             last_accessed_ms: Optional[int],
             is_quarantined: Optional[bool] = False,
-        ) -> str:
+        ) -> MXCUri:
             # Pretend to cache some remote media
             self.get_success(
                 self.store.store_cached_remote_media(
@@ -146,7 +146,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
                     )
                 )
 
-            return media_id
+            return MXCUri(self.remote_server_name, media_id)
 
         # Start with the local media store
         self.local_recently_accessed_media = _create_media_and_set_attributes(
@@ -214,28 +214,16 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
         # Remote media should be unaffected.
         self._assert_if_mxc_uris_purged(
             purged=[
-                (
-                    self.hs.config.server.server_name,
-                    self.local_not_recently_accessed_media,
-                ),
-                (self.hs.config.server.server_name, self.local_never_accessed_media),
+                self.local_not_recently_accessed_media,
+                self.local_never_accessed_media,
             ],
             not_purged=[
-                (self.hs.config.server.server_name, self.local_recently_accessed_media),
-                (
-                    self.hs.config.server.server_name,
-                    self.local_not_recently_accessed_quarantined_media,
-                ),
-                (
-                    self.hs.config.server.server_name,
-                    self.local_not_recently_accessed_protected_media,
-                ),
-                (self.remote_server_name, self.remote_recently_accessed_media),
-                (self.remote_server_name, self.remote_not_recently_accessed_media),
-                (
-                    self.remote_server_name,
-                    self.remote_not_recently_accessed_quarantined_media,
-                ),
+                self.local_recently_accessed_media,
+                self.local_not_recently_accessed_quarantined_media,
+                self.local_not_recently_accessed_protected_media,
+                self.remote_recently_accessed_media,
+                self.remote_not_recently_accessed_media,
+                self.remote_not_recently_accessed_quarantined_media,
             ],
         )
 
@@ -261,49 +249,35 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
         # Remote media accessed <30 days ago should still exist.
         self._assert_if_mxc_uris_purged(
             purged=[
-                (self.remote_server_name, self.remote_not_recently_accessed_media),
+                self.remote_not_recently_accessed_media,
             ],
             not_purged=[
-                (self.remote_server_name, self.remote_recently_accessed_media),
-                (self.hs.config.server.server_name, self.local_recently_accessed_media),
-                (
-                    self.hs.config.server.server_name,
-                    self.local_not_recently_accessed_media,
-                ),
-                (
-                    self.hs.config.server.server_name,
-                    self.local_not_recently_accessed_quarantined_media,
-                ),
-                (
-                    self.hs.config.server.server_name,
-                    self.local_not_recently_accessed_protected_media,
-                ),
-                (
-                    self.remote_server_name,
-                    self.remote_not_recently_accessed_quarantined_media,
-                ),
-                (self.hs.config.server.server_name, self.local_never_accessed_media),
+                self.remote_recently_accessed_media,
+                self.local_recently_accessed_media,
+                self.local_not_recently_accessed_media,
+                self.local_not_recently_accessed_quarantined_media,
+                self.local_not_recently_accessed_protected_media,
+                self.remote_not_recently_accessed_quarantined_media,
+                self.local_never_accessed_media,
             ],
         )
 
     def _assert_if_mxc_uris_purged(
-        self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]]
+        self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri]
     ) -> None:
-        def _assert_mxc_uri_purge_state(
-            server_name: str, media_id: str, expect_purged: bool
-        ) -> None:
+        def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
             """Given an MXC URI, assert whether it has been purged or not."""
-            if server_name == self.hs.config.server.server_name:
+            if mxc_uri.server_name == self.hs.config.server.server_name:
                 found_media_dict = self.get_success(
-                    self.store.get_local_media(media_id)
+                    self.store.get_local_media(mxc_uri.media_id)
                 )
             else:
                 found_media_dict = self.get_success(
-                    self.store.get_cached_remote_media(server_name, media_id)
+                    self.store.get_cached_remote_media(
+                        mxc_uri.server_name, mxc_uri.media_id
+                    )
                 )
 
-            mxc_uri = f"mxc://{server_name}/{media_id}"
-
             if expect_purged:
                 self.assertIsNone(
                     found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
@@ -315,7 +289,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
                 )
 
         # Assert that the given MXC URIs have either been correctly purged or not.
-        for server_name, media_id in purged:
-            _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True)
-        for server_name, media_id in not_purged:
-            _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False)
+        for mxc_uri in purged:
+            _assert_mxc_uri_purge_state(mxc_uri, expect_purged=True)
+        for mxc_uri in not_purged:
+            _assert_mxc_uri_purge_state(mxc_uri, expect_purged=False)