diff options
-rw-r--r-- | changelog.d/6664.bugfix | 1 | ||||
-rw-r--r-- | synapse/app/media_repository.py | 2 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/room.py | 248 |
3 files changed, 131 insertions, 120 deletions
diff --git a/changelog.d/6664.bugfix b/changelog.d/6664.bugfix new file mode 100644 index 0000000000..8c6a6fa1c8 --- /dev/null +++ b/changelog.d/6664.bugfix @@ -0,0 +1 @@ +Fix media repo admin APIs when using a media worker. diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index a63c53dc44..5b5832214a 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -34,6 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.admin import register_servlets_for_media_repo @@ -47,6 +48,7 @@ logger = logging.getLogger("synapse.app.media_repository") class MediaRepositorySlavedStore( + RoomStore, SlavedApplicationServiceStore, SlavedRegistrationStore, SlavedClientIpStore, diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 79cfd39194..8636d75030 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -366,6 +366,134 @@ class RoomWorkerStore(SQLBaseStore): defer.returnValue(row) + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + + return self.db.runInteraction( + "get_media_ids_in_room", _get_media_mxcs_in_room_txn + ) + + def quarantine_media_ids_in_room(self, room_id, quarantined_by): + """For a room loops through all events with media and quarantines + the associated media + """ + + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 + + # Now update all the tables to set the quarantined_by flag + + txn.executemany( + """ + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? + """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs + ), + ) + + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) + + return total_media_quarantined + + return self.db.runInteraction( + "quarantine_media_in_room", _quarantine_media_in_room_txn + ) + + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + txn (cursor) + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + sql = """ + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) + WHERE room_id = ? + %(where_clause)s + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? + """ + txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100)) + + local_media_mxcs = [] + remote_media_mxcs = [] + + while True: + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + event_json = json.loads(content_json) + content = event_json["content"] + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hs.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + if next_token is None: + # We've gone through the whole room, so we're finished. + break + + txn.execute( + sql % {"where_clause": "AND stream_ordering < ?"}, + (room_id, next_token, True, False, 100), + ) + + return local_media_mxcs, remote_media_mxcs + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -810,126 +938,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): (room_id,), ) - def get_media_mxcs_in_room(self, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - - def _get_media_mxcs_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - local_media_mxcs = [] - remote_media_mxcs = [] - - # Convert the IDs to MXC URIs - for media_id in local_mxcs: - local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) - for hostname, media_id in remote_mxcs: - remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) - - return local_media_mxcs, remote_media_mxcs - - return self.db.runInteraction( - "get_media_ids_in_room", _get_media_mxcs_in_room_txn - ) - - def quarantine_media_ids_in_room(self, room_id, quarantined_by): - """For a room loops through all events with media and quarantines - the associated media - """ - - def _quarantine_media_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - total_media_quarantined = 0 - - # Now update all the tables to set the quarantined_by flag - - txn.executemany( - """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """, - ((quarantined_by, media_id) for media_id in local_mxcs), - ) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_mxcs - ), - ) - - total_media_quarantined += len(local_mxcs) - total_media_quarantined += len(remote_mxcs) - - return total_media_quarantined - - return self.db.runInteraction( - "quarantine_media_in_room", _quarantine_media_in_room_txn - ) - - def _get_media_mxcs_in_room_txn(self, txn, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - txn (cursor) - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - - next_token = self.get_current_events_token() + 1 - local_media_mxcs = [] - remote_media_mxcs = [] - - while next_token: - sql = """ - SELECT stream_ordering, json FROM events - JOIN event_json USING (room_id, event_id) - WHERE room_id = ? - AND stream_ordering < ? - AND contains_url = ? AND outlier = ? - ORDER BY stream_ordering DESC - LIMIT ? - """ - txn.execute(sql, (room_id, next_token, True, False, 100)) - - next_token = None - for stream_ordering, content_json in txn: - next_token = stream_ordering - event_json = json.loads(content_json) - content = event_json["content"] - content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") - - for url in (content_url, thumbnail_url): - if not url: - continue - matches = mxc_re.match(url) - if matches: - hostname = matches.group(1) - media_id = matches.group(2) - if hostname == self.hs.hostname: - local_media_mxcs.append(media_id) - else: - remote_media_mxcs.append((hostname, media_id)) - - return local_media_mxcs, remote_media_mxcs - @defer.inlineCallbacks def get_rooms_for_retention_period_in_range( self, min_ms, max_ms, include_null=False |