summary refs log tree commit diff
path: root/synapse/storage/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/room.py')
-rw-r--r--synapse/storage/room.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 07366f66b6..e9c1549c00 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -24,6 +24,7 @@ from .engines import PostgresEngine, Sqlite3Engine
 import collections
 import logging
 import ujson as json
+import re
 
 logger = logging.getLogger(__name__)
 
@@ -531,3 +532,72 @@ class RoomStore(SQLBaseStore):
             desc="block_room",
         )
         self.is_room_blocked.invalidate((room_id,))
+
+    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+        """For a room loops through all events with media and quarantines
+        the associated media
+        """
+        def _get_media_ids_in_room(txn):
+            mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+            next_token = self.get_current_events_token() + 1
+
+            total_media_quarantined = 0
+
+            while next_token:
+                sql = """
+                    SELECT stream_ordering, content FROM events
+                    WHERE room_id = ?
+                        AND stream_ordering < ?
+                        AND contains_url = ? AND outlier = ?
+                    ORDER BY stream_ordering DESC
+                    LIMIT ?
+                """
+                txn.execute(sql, (room_id, next_token, True, False, 100))
+
+                next_token = None
+                local_media_mxcs = []
+                remote_media_mxcs = []
+                for stream_ordering, content_json in txn:
+                    next_token = stream_ordering
+                    content = json.loads(content_json)
+
+                    url = content.get("url")
+                    if not url:
+                        continue
+
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+                # Now update all the tables to set the quarantined_by flag
+
+                txn.executemany("""
+                    UPDATE local_media_repository
+                    SET quarantined_by = ?
+                    WHERE media_id = ?
+                """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
+
+                txn.executemany(
+                    """
+                        UPDATE remote_media_cache
+                        SET quarantined_by = ?
+                        WHERE media_origin AND media_id = ?
+                    """,
+                    (
+                        (quarantined_by, origin, media_id)
+                        for origin, media_id in remote_media_mxcs
+                    )
+                )
+
+                total_media_quarantined += len(local_media_mxcs)
+                total_media_quarantined += len(remote_media_mxcs)
+
+            return total_media_quarantined
+
+        return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)