summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/rest/client/v1/admin.py22
-rw-r--r--synapse/state.py36
-rw-r--r--synapse/storage/events.py57
-rw-r--r--synapse/storage/room.py153
4 files changed, 186 insertions, 82 deletions
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 5022808ea9..0615e5d807 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -289,6 +289,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
         defer.returnValue((200, {"num_quarantined": num_quarantined}))
 
 
+class ListMediaInRoom(ClientV1RestServlet):
+    """Lists all of the media in a given room.
+    """
+    PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+    def __init__(self, hs):
+        super(ListMediaInRoom, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+        defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
 class ResetPasswordRestServlet(ClientV1RestServlet):
     """Post request to allow an administrator reset password for a user.
     This needs user to have administrator access in Synapse.
@@ -487,3 +508,4 @@ def register_servlets(hs, http_server):
     SearchUsersRestServlet(hs).register(http_server)
     ShutdownRoomRestServlet(hs).register(http_server)
     QuarantineMediaInRoom(hs).register(http_server)
+    ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/state.py b/synapse/state.py
index 5004e0f913..273f9911ca 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -58,7 +58,11 @@ class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
     def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+        # dict[(str, str), str] map  from (type, state_key) to event_id
         self.state = frozendict(state)
+
+        # the ID of a state group if one and only one is involved.
+        # otherwise, None otherwise?
         self.state_group = state_group
 
         self.prev_group = prev_group
@@ -275,11 +279,12 @@ class StateHandler(object):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
+        Args:
+            room_id (str):
+            event_ids (list[str]):
+
         Returns:
-            a Deferred tuple of (`state_group`, `state`, `prev_state`).
-            `state_group` is the name of a state group if one and only one is
-            involved. `state` is a map from (type, state_key) to event, and
-            `prev_state` is a list of event ids.
+            Deferred[_StateCacheEntry]: resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
@@ -482,8 +487,8 @@ def resolve_events_with_state_map(state_sets, state_map):
             state_sets.
 
     Returns
-        dict[(str, str), synapse.events.FrozenEvent]:
-            a map from (type, state_key) to event.
+        dict[(str, str), str]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -505,6 +510,21 @@ def _seperate(state_sets):
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
+
+    Args:
+        state_sets(list[dict[(str, str), str]]):
+            List of dicts of (type, state_key) -> event_id, which are the
+            different state groups to resolve.
+
+    Returns:
+        (dict[(str, str), str], dict[(str, str), set[str]]):
+            A tuple of (unconflicted_state, conflicted_state), where:
+
+            unconflicted_state is a dict mapping (type, state_key)->event_id
+            for unconflicted state keys.
+
+            conflicted_state is a dict mapping (type, state_key) to a set of
+            event ids for conflicted state keys.
     """
     unconflicted_state = dict(state_sets[0])
     conflicted_state = {}
@@ -545,8 +565,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
             a Deferred of dict of event_id to event.
 
     Returns
-        Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
-            a map from (type, state_key) to event.
+        Deferred[dict[(str, str), str]]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         defer.returnValue(state_sets[0])
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 33fccfa7a8..dd28c2efe3 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -386,11 +386,18 @@ class EventsStore(SQLBaseStore):
                                 if all_single_prev_not_state:
                                     continue
 
-                            state = yield self._calculate_state_delta(
-                                room_id, ev_ctx_rm, new_latest_event_ids
+                            logger.info(
+                                "Calculating state delta for room %s", room_id,
                             )
-                            if state:
-                                current_state_for_room[room_id] = state
+                            current_state = yield self._get_new_state_after_events(
+                                ev_ctx_rm, new_latest_event_ids,
+                            )
+                            if current_state is not None:
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state,
+                                )
+                                if delta is not None:
+                                    current_state_for_room[room_id] = delta
 
                 yield self.runInteraction(
                     "persist_events",
@@ -467,20 +474,22 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
-        """Calculate the new state deltas for a room.
+    def _get_new_state_after_events(self, events_context, new_latest_event_ids):
+        """Calculate the current state dict after adding some new events to
+        a room
 
-        Assumes that we are only persisting events for one room at a time.
+        Args:
+            events_context (list[(EventBase, EventContext)]):
+                events and contexts which are being added to the room
+
+            new_latest_event_ids (iterable[str]):
+                the new forward extremities for the room.
 
         Returns:
-            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
-            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
-            first be deleted from current_state_events, `to_insert` are entries
-            to insert. `new_state` is the full set of state.
-            May return None if there are no changes to be applied.
+            Deferred[dict[(str,str), str]|None]:
+                None if there are no changes to the room state, or
+                a dict of (type, state_key) -> event_id].
         """
-        # Now we need to work out the different state sets for
-        # each state extremities
         state_sets = []
         state_groups = set()
         missing_event_ids = []
@@ -523,12 +532,12 @@ class EventsStore(SQLBaseStore):
                 state_sets.extend(group_to_state.itervalues())
 
         if not new_latest_event_ids:
-            current_state = {}
+            defer.returnValue({})
         elif was_updated:
             if len(state_sets) == 1:
                 # If there is only one state set, then we know what the current
                 # state is.
-                current_state = state_sets[0]
+                defer.returnValue(state_sets[0])
             else:
                 # We work out the current state by passing the state sets to the
                 # state resolution algorithm. It may ask for some events, including
@@ -537,8 +546,7 @@ class EventsStore(SQLBaseStore):
                 # up in the db.
 
                 logger.info(
-                    "Resolving state for %s with %i state sets",
-                    room_id, len(state_sets),
+                    "Resolving state with %i state sets", len(state_sets),
                 )
 
                 events_map = {ev.event_id: ev for ev, _ in events_context}
@@ -567,9 +575,22 @@ class EventsStore(SQLBaseStore):
                     state_sets,
                     state_map_factory=get_events,
                 )
+                defer.returnValue(current_state)
         else:
             return
 
+    @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, current_state):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
+            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert. `new_state` is the full set of state.
+        """
         existing_state = yield self.get_current_state_ids(room_id)
 
         existing_events = set(existing_state.itervalues())
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 23688430b7..cf2c4dae39 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -533,73 +533,114 @@ class RoomStore(SQLBaseStore):
         )
         self.is_room_blocked.invalidate((room_id,))
 
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            local_media_mxcs = []
+            remote_media_mxcs = []
+
+            # Convert the IDs to MXC URIs
+            for media_id in local_mxcs:
+                local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+            for hostname, media_id in remote_mxcs:
+                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+            return local_media_mxcs, remote_media_mxcs
+        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
         the associated media
         """
-        def _get_media_ids_in_room(txn):
-            mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
 
-            next_token = self.get_current_events_token() + 1
+            # Now update all the tables to set the quarantined_by flag
 
-            total_media_quarantined = 0
+            txn.executemany("""
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """, ((quarantined_by, media_id) for media_id in local_mxcs))
 
-            while next_token:
-                sql = """
-                    SELECT stream_ordering, content FROM events
-                    WHERE room_id = ?
-                        AND stream_ordering < ?
-                        AND contains_url = ? AND outlier = ?
-                    ORDER BY stream_ordering DESC
-                    LIMIT ?
+            txn.executemany(
                 """
-                txn.execute(sql, (room_id, next_token, True, False, 100))
-
-                next_token = None
-                local_media_mxcs = []
-                remote_media_mxcs = []
-                for stream_ordering, content_json in txn:
-                    next_token = stream_ordering
-                    content = json.loads(content_json)
-
-                    content_url = content.get("url")
-                    thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
-                    for url in (content_url, thumbnail_url):
-                        if not url:
-                            continue
-                        matches = mxc_re.match(url)
-                        if matches:
-                            hostname = matches.group(1)
-                            media_id = matches.group(2)
-                            if hostname == self.hostname:
-                                local_media_mxcs.append(media_id)
-                            else:
-                                remote_media_mxcs.append((hostname, media_id))
-
-                # Now update all the tables to set the quarantined_by flag
-
-                txn.executemany("""
-                    UPDATE local_media_repository
+                    UPDATE remote_media_cache
                     SET quarantined_by = ?
-                    WHERE media_id = ?
-                """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
-                txn.executemany(
-                    """
-                        UPDATE remote_media_cache
-                        SET quarantined_by = ?
-                        WHERE media_origin AND media_id = ?
-                    """,
-                    (
-                        (quarantined_by, origin, media_id)
-                        for origin, media_id in remote_media_mxcs
-                    )
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
                 )
+            )
 
-                total_media_quarantined += len(local_media_mxcs)
-                total_media_quarantined += len(remote_media_mxcs)
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
 
             return total_media_quarantined
 
-        return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+        return self.runInteraction(
+            "quarantine_media_in_room",
+            _quarantine_media_in_room_txn,
+        )
+
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        next_token = self.get_current_events_token() + 1
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
+        while next_token:
+            sql = """
+                SELECT stream_ordering, content FROM events
+                WHERE room_id = ?
+                    AND stream_ordering < ?
+                    AND contains_url = ? AND outlier = ?
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+            txn.execute(sql, (room_id, next_token, True, False, 100))
+
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                content = json.loads(content_json)
+
+                content_url = content.get("url")
+                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+        return local_media_mxcs, remote_media_mxcs