diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 23688430b7..cd6899a4b5 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -533,73 +533,92 @@ class RoomStore(SQLBaseStore):
)
self.is_room_blocked.invalidate((room_id,))
+ def get_media_mxcs_in_room(self, room_id):
+ """Build the MXC URIs of the media referenced by events in a room.
+
+ Returns whatever `runInteraction` yields for a callback that returns a
+ `(local_mxcs, remote_mxcs)` pair of lists of "mxc://<host>/<media_id>".
+ """
+ def _collect_mxcs(txn):
+ local_ids, remote_ids = self._get_media_ids_in_room(txn, room_id)
+ # Local media IDs carry no origin, so they are paired with our hostname.
+ local_mxcs = ["mxc://%s/%s" % (self.hostname, mid) for mid in local_ids]
+ remote_mxcs = ["mxc://%s/%s" % (host, mid) for host, mid in remote_ids]
+ return local_mxcs, remote_mxcs
+
+ return self.runInteraction("get_media_ids_in_room", _collect_mxcs)
+
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
- """For a room loops through all events with media and quarantines
- the associated media
- """
+ """Quarantine every local and remote media item referenced by events in
+ a room. Returns (via `runInteraction`) the number of media references
+ found, which is not necessarily the number of rows updated."""
- def _get_media_ids_in_room(txn):
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+ def _quarantine_media_in_room(txn):
+ local_media_mxcs, remote_media_mxcs = self._get_media_ids_in_room(txn, room_id)
+ total_media_quarantined = 0
- next_token = self.get_current_events_token() + 1
+ # Now update all the tables to set the quarantined_by flag
- total_media_quarantined = 0
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
- while next_token:
- sql = """
- SELECT stream_ordering, content FROM events
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
+ # Each remote entry is an (origin, media_id) pair: bind both, one "?" each.
+ txn.executemany(
"""
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- local_media_mxcs = []
- remote_media_mxcs = []
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- content = json.loads(content_json)
-
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany("""
- UPDATE local_media_repository
+ UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_id = ?
- """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_media_mxcs
- )
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_media_mxcs
)
+ )
- total_media_quarantined += len(local_media_mxcs)
- total_media_quarantined += len(remote_media_mxcs)
+ total_media_quarantined += len(local_media_mxcs) + len(remote_media_mxcs)
return total_media_quarantined
- return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+ return self.runInteraction("quarantine_media_in_room", _quarantine_media_in_room)
+
+ def _get_media_ids_in_room(self, txn, room_id):
+ """Find the media referenced by the URL-bearing events of a room.
+ Args:
+ txn: cursor used to run the query and iterate over its rows
+ room_id: room whose (non-outlier) events are scanned
+ Returns:
+ A pair: the list of media IDs whose host is self.hostname, and the
+ list of (hostname, media_id) tuples for every other host.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+ local_media_mxcs, remote_media_mxcs = [], []
+ sql = """
+ SELECT stream_ordering, content FROM events
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ # Page backwards 100 rows at a time; an empty page leaves next_token None.
+ next_token = self.get_current_events_token() + 1
+ while next_token:
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ content = json.loads(content_json)
+ # "info" comes from the event sender and need not be a dict.
+ info = content.get("info")
+ thumbnail_url = info.get("thumbnail_url") if isinstance(info, dict) else None
+ for url in (content.get("url"), thumbnail_url):
+ matches = mxc_re.match(url) if url else None
+ if matches:
+ hostname, media_id = matches.groups()
+ if hostname == self.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+ return local_media_mxcs, remote_media_mxcs
|