summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/media_repository.py2
-rw-r--r--synapse/storage/data_stores/main/room.py248
2 files changed, 130 insertions, 120 deletions
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index a63c53dc44..5b5832214a 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -34,6 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.admin import register_servlets_for_media_repo
@@ -47,6 +48,7 @@ logger = logging.getLogger("synapse.app.media_repository")
 
 
 class MediaRepositorySlavedStore(
+    RoomStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
     SlavedClientIpStore,
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 79cfd39194..8636d75030 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -366,6 +366,134 @@ class RoomWorkerStore(SQLBaseStore):
 
         defer.returnValue(row)
 
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            local_media_mxcs = []
+            remote_media_mxcs = []
+
+            # Convert the IDs to MXC URIs
+            for media_id in local_mxcs:
+                local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+            for hostname, media_id in remote_mxcs:
+                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+            return local_media_mxcs, remote_media_mxcs
+
+        return self.db.runInteraction(
+            "get_media_ids_in_room", _get_media_mxcs_in_room_txn
+        )
+
+    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+        """For a room loops through all events with media and quarantines
+        the associated media
+        """
+
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
+
+            # Now update all the tables to set the quarantined_by flag
+
+            txn.executemany(
+                """
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """,
+                ((quarantined_by, media_id) for media_id in local_mxcs),
+            )
+
+            txn.executemany(
+                """
+                    UPDATE remote_media_cache
+                    SET quarantined_by = ?
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
+                ),
+            )
+
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
+
+            return total_media_quarantined
+
+        return self.db.runInteraction(
+            "quarantine_media_in_room", _quarantine_media_in_room_txn
+        )
+
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        sql = """
+            SELECT stream_ordering, json FROM events
+            JOIN event_json USING (room_id, event_id)
+            WHERE room_id = ?
+                %(where_clause)s
+                AND contains_url = ? AND outlier = ?
+            ORDER BY stream_ordering DESC
+            LIMIT ?
+        """
+        txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100))
+
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
+        while True:
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                event_json = json.loads(content_json)
+                content = event_json["content"]
+                content_url = content.get("url")
+                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hs.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+            if next_token is None:
+                # We've gone through the whole room, so we're finished.
+                break
+
+            txn.execute(
+                sql % {"where_clause": "AND stream_ordering < ?"},
+                (room_id, next_token, True, False, 100),
+            )
+
+        return local_media_mxcs, remote_media_mxcs
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -810,126 +938,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             (room_id,),
         )
 
-    def get_media_mxcs_in_room(self, room_id):
-        """Retrieves all the local and remote media MXC URIs in a given room
-
-        Args:
-            room_id (str)
-
-        Returns:
-            The local and remote media as a lists of tuples where the key is
-            the hostname and the value is the media ID.
-        """
-
-        def _get_media_mxcs_in_room_txn(txn):
-            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
-            local_media_mxcs = []
-            remote_media_mxcs = []
-
-            # Convert the IDs to MXC URIs
-            for media_id in local_mxcs:
-                local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
-            for hostname, media_id in remote_mxcs:
-                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
-
-            return local_media_mxcs, remote_media_mxcs
-
-        return self.db.runInteraction(
-            "get_media_ids_in_room", _get_media_mxcs_in_room_txn
-        )
-
-    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
-        """For a room loops through all events with media and quarantines
-        the associated media
-        """
-
-        def _quarantine_media_in_room_txn(txn):
-            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
-            total_media_quarantined = 0
-
-            # Now update all the tables to set the quarantined_by flag
-
-            txn.executemany(
-                """
-                UPDATE local_media_repository
-                SET quarantined_by = ?
-                WHERE media_id = ?
-            """,
-                ((quarantined_by, media_id) for media_id in local_mxcs),
-            )
-
-            txn.executemany(
-                """
-                    UPDATE remote_media_cache
-                    SET quarantined_by = ?
-                    WHERE media_origin = ? AND media_id = ?
-                """,
-                (
-                    (quarantined_by, origin, media_id)
-                    for origin, media_id in remote_mxcs
-                ),
-            )
-
-            total_media_quarantined += len(local_mxcs)
-            total_media_quarantined += len(remote_mxcs)
-
-            return total_media_quarantined
-
-        return self.db.runInteraction(
-            "quarantine_media_in_room", _quarantine_media_in_room_txn
-        )
-
-    def _get_media_mxcs_in_room_txn(self, txn, room_id):
-        """Retrieves all the local and remote media MXC URIs in a given room
-
-        Args:
-            txn (cursor)
-            room_id (str)
-
-        Returns:
-            The local and remote media as a lists of tuples where the key is
-            the hostname and the value is the media ID.
-        """
-        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
-        next_token = self.get_current_events_token() + 1
-        local_media_mxcs = []
-        remote_media_mxcs = []
-
-        while next_token:
-            sql = """
-                SELECT stream_ordering, json FROM events
-                JOIN event_json USING (room_id, event_id)
-                WHERE room_id = ?
-                    AND stream_ordering < ?
-                    AND contains_url = ? AND outlier = ?
-                ORDER BY stream_ordering DESC
-                LIMIT ?
-            """
-            txn.execute(sql, (room_id, next_token, True, False, 100))
-
-            next_token = None
-            for stream_ordering, content_json in txn:
-                next_token = stream_ordering
-                event_json = json.loads(content_json)
-                content = event_json["content"]
-                content_url = content.get("url")
-                thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
-                for url in (content_url, thumbnail_url):
-                    if not url:
-                        continue
-                    matches = mxc_re.match(url)
-                    if matches:
-                        hostname = matches.group(1)
-                        media_id = matches.group(2)
-                        if hostname == self.hs.hostname:
-                            local_media_mxcs.append(media_id)
-                        else:
-                            remote_media_mxcs.append((hostname, media_id))
-
-        return local_media_mxcs, remote_media_mxcs
-
     @defer.inlineCallbacks
     def get_rooms_for_retention_period_in_range(
         self, min_ms, max_ms, include_null=False