diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/rest/client/v1/admin.py | 22 | ||||
-rw-r--r-- | synapse/state.py | 36 | ||||
-rw-r--r-- | synapse/storage/events.py | 57 | ||||
-rw-r--r-- | synapse/storage/room.py | 153 |
4 files changed, 186 insertions, 82 deletions
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 5022808ea9..0615e5d807 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -289,6 +289,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet): defer.returnValue((200, {"num_quarantined": num_quarantined})) +class ListMediaInRoom(ClientV1RestServlet): + """Lists all of the media in a given room. + """ + PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media") + + def __init__(self, hs): + super(ListMediaInRoom, self).__init__(hs) + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + if not is_admin: + raise AuthError(403, "You are not a server admin") + + local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) + + defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs})) + + class ResetPasswordRestServlet(ClientV1RestServlet): """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. @@ -487,3 +508,4 @@ def register_servlets(hs, http_server): SearchUsersRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) + ListMediaInRoom(hs).register(http_server) diff --git a/synapse/state.py b/synapse/state.py index 5004e0f913..273f9911ca 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -58,7 +58,11 @@ class _StateCacheEntry(object): __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] def __init__(self, state, state_group, prev_group=None, delta_ids=None): + # dict[(str, str), str] map from (type, state_key) to event_id self.state = frozendict(state) + + # the ID of a state group if one and only one is involved. + # otherwise, None otherwise? self.state_group = state_group self.prev_group = prev_group @@ -275,11 +279,12 @@ class StateHandler(object): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. + Args: + room_id (str): + event_ids (list[str]): + Returns: - a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Deferred[_StateCacheEntry]: resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -482,8 +487,8 @@ def resolve_events_with_state_map(state_sets, state_map): state_sets. Returns - dict[(str, str), synapse.events.FrozenEvent]: - a map from (type, state_key) to event. + dict[(str, str), str]: + a map from (type, state_key) to event_id. """ if len(state_sets) == 1: return state_sets[0] @@ -505,6 +510,21 @@ def _seperate(state_sets): """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. + + Args: + state_sets(list[dict[(str, str), str]]): + List of dicts of (type, state_key) -> event_id, which are the + different state groups to resolve. + + Returns: + (dict[(str, str), str], dict[(str, str), set[str]]): + A tuple of (unconflicted_state, conflicted_state), where: + + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. + + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. """ unconflicted_state = dict(state_sets[0]) conflicted_state = {} @@ -545,8 +565,8 @@ def resolve_events_with_factory(state_sets, state_map_factory): a Deferred of dict of event_id to event. Returns - Deferred[dict[(str, str), synapse.events.FrozenEvent]]: - a map from (type, state_key) to event. + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. """ if len(state_sets) == 1: defer.returnValue(state_sets[0]) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 33fccfa7a8..dd28c2efe3 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -386,11 +386,18 @@ class EventsStore(SQLBaseStore): if all_single_prev_not_state: continue - state = yield self._calculate_state_delta( - room_id, ev_ctx_rm, new_latest_event_ids + logger.info( + "Calculating state delta for room %s", room_id, ) - if state: - current_state_for_room[room_id] = state + current_state = yield self._get_new_state_after_events( + ev_ctx_rm, new_latest_event_ids, + ) + if current_state is not None: + delta = yield self._calculate_state_delta( + room_id, current_state, + ) + if delta is not None: + current_state_for_room[room_id] = delta yield self.runInteraction( "persist_events", @@ -467,20 +474,22 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): - """Calculate the new state deltas for a room. + def _get_new_state_after_events(self, events_context, new_latest_event_ids): + """Calculate the current state dict after adding some new events to + a room - Assumes that we are only persisting events for one room at a time. + Args: + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. Returns: - 3-tuple (to_delete, to_insert, new_state) where both are state dicts, - i.e. (type, state_key) -> event_id. `to_delete` are the entries to - first be deleted from current_state_events, `to_insert` are entries - to insert. `new_state` is the full set of state. - May return None if there are no changes to be applied. + Deferred[dict[(str,str), str]|None]: + None if there are no changes to the room state, or + a dict of (type, state_key) -> event_id]. """ - # Now we need to work out the different state sets for - # each state extremities state_sets = [] state_groups = set() missing_event_ids = [] @@ -523,12 +532,12 @@ class EventsStore(SQLBaseStore): state_sets.extend(group_to_state.itervalues()) if not new_latest_event_ids: - current_state = {} + defer.returnValue({}) elif was_updated: if len(state_sets) == 1: # If there is only one state set, then we know what the current # state is. - current_state = state_sets[0] + defer.returnValue(state_sets[0]) else: # We work out the current state by passing the state sets to the # state resolution algorithm. It may ask for some events, including @@ -537,8 +546,7 @@ class EventsStore(SQLBaseStore): # up in the db. logger.info( - "Resolving state for %s with %i state sets", - room_id, len(state_sets), + "Resolving state with %i state sets", len(state_sets), ) events_map = {ev.event_id: ev for ev, _ in events_context} @@ -567,9 +575,22 @@ class EventsStore(SQLBaseStore): state_sets, state_map_factory=get_events, ) + defer.returnValue(current_state) else: return + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 3-tuple (to_delete, to_insert, new_state) where both are state dicts, + i.e. (type, state_key) -> event_id. `to_delete` are the entries to + first be deleted from current_state_events, `to_insert` are entries + to insert. `new_state` is the full set of state. + """ existing_state = yield self.get_current_state_ids(room_id) existing_events = set(existing_state.itervalues()) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 23688430b7..cf2c4dae39 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -533,73 +533,114 @@ class RoomStore(SQLBaseStore): ) self.is_room_blocked.invalidate((room_id,)) + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines the associated media """ - def _get_media_ids_in_room(txn): - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 - next_token = self.get_current_events_token() + 1 + # Now update all the tables to set the quarantined_by flag - total_media_quarantined = 0 + txn.executemany(""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, ((quarantined_by, media_id) for media_id in local_mxcs)) - while next_token: - sql = """ - SELECT stream_ordering, content FROM events - WHERE room_id = ? - AND stream_ordering < ? - AND contains_url = ? AND outlier = ? - ORDER BY stream_ordering DESC - LIMIT ? + txn.executemany( """ - txn.execute(sql, (room_id, next_token, True, False, 100)) - - next_token = None - local_media_mxcs = [] - remote_media_mxcs = [] - for stream_ordering, content_json in txn: - next_token = stream_ordering - content = json.loads(content_json) - - content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") - - for url in (content_url, thumbnail_url): - if not url: - continue - matches = mxc_re.match(url) - if matches: - hostname = matches.group(1) - media_id = matches.group(2) - if hostname == self.hostname: - local_media_mxcs.append(media_id) - else: - remote_media_mxcs.append((hostname, media_id)) - - # Now update all the tables to set the quarantined_by flag - - txn.executemany(""" - UPDATE local_media_repository + UPDATE remote_media_cache SET quarantined_by = ? - WHERE media_id = ? - """, ((quarantined_by, media_id) for media_id in local_media_mxcs)) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_media_mxcs - ) + WHERE media_origin = ? AND media_id = ? + """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs ) + ) - total_media_quarantined += len(local_media_mxcs) - total_media_quarantined += len(remote_media_mxcs) + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) return total_media_quarantined - return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room) + return self.runInteraction( + "quarantine_media_in_room", + _quarantine_media_in_room_txn, + ) + + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + txn (cursor) + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + next_token = self.get_current_events_token() + 1 + local_media_mxcs = [] + remote_media_mxcs = [] + + while next_token: + sql = """ + SELECT stream_ordering, content FROM events + WHERE room_id = ? + AND stream_ordering < ? + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? + """ + txn.execute(sql, (room_id, next_token, True, False, 100)) + + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + content = json.loads(content_json) + + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs |