diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
new file mode 100644
index 0000000000..abdbc1ea86
--- /dev/null
+++ b/docs/admin_api/media_admin_api.md
@@ -0,0 +1,23 @@
+# List all media in a room
+
+This API gets a list of known media in a room.
+
+The API is:
+```
+GET /_matrix/client/r0/admin/room/<room_id>/media
+```
+including an `access_token` of a server admin.
+
+It returns a JSON body like the following:
+```
+{
+ "local": [
+ "mxc://localhost/xwvutsrqponmlkjihgfedcba",
+ "mxc://localhost/abcdefghijklmnopqrstuvwx"
+ ],
+ "remote": [
+ "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
+ "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
+ ]
+}
+```
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 677532c87b..8ee9434c9b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -808,13 +808,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
+ resolve = logcontext.preserve_fn(
+ self.state_handler.resolve_state_groups_for_events
+ )
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
- room_id, [e]
- )
- for e in event_ids
- ], consumeErrors=True,
+ [resolve(room_id, [e]) for e in event_ids],
+ consumeErrors=True,
))
states = dict(zip(event_ids, [s.state for s in states]))
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 5022808ea9..0615e5d807 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -289,6 +289,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined}))
+class ListMediaInRoom(ClientV1RestServlet):
+ """Lists all of the media in a given room.
+ """
+ PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+ def __init__(self, hs):
+ super(ListMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+ defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
@@ -487,3 +508,4 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/server.py b/synapse/server.py
index ff8a8fbc46..3173aed1d0 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -66,7 +66,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
)
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
@@ -102,6 +102,7 @@ class HomeServer(object):
'v1auth',
'auth',
'state_handler',
+ 'state_resolution_handler',
'presence_handler',
'sync_handler',
'typing_handler',
@@ -224,6 +225,9 @@ class HomeServer(object):
def build_state_handler(self):
return StateHandler(self)
+ def build_state_resolution_handler(self):
+ return StateResolutionHandler(self)
+
def build_presence_handler(self):
return PresenceHandler(self)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 41416ef252..c3a9a3847b 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -34,6 +34,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler:
pass
+ def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+ pass
+
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
diff --git a/synapse/state.py b/synapse/state.py
index 4c8247e7c2..273f9911ca 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -58,7 +58,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+ # dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state)
+
+ # the ID of a state group if one and only one is involved.
+ # otherwise, None otherwise?
self.state_group = state_group
self.prev_group = prev_group
@@ -81,31 +85,19 @@ class _StateCacheEntry(object):
class StateHandler(object):
- """ Responsible for doing state conflict resolution.
+ """Fetches bits of state from the stores, and does state resolution
+ where necessary
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.hs = hs
-
- # dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = None
- self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+ self._state_resolution_handler = hs.get_state_resolution_handler()
def start_caching(self):
- logger.debug("start_caching")
-
- self._state_cache = ExpiringCache(
- cache_name="state_cache",
- clock=self.clock,
- max_len=SIZE_OF_CACHE,
- expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
- iterable=True,
- reset_expiry_on_get=True,
- )
-
- self._state_cache.start()
+ # TODO: remove this shim
+ self._state_resolution_handler.start_caching()
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="",
@@ -127,7 +119,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
@@ -164,7 +156,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
defer.returnValue(state)
@@ -174,7 +166,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
- entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@@ -183,7 +175,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
- entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts)
@@ -241,7 +233,7 @@ class StateHandler(object):
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
- entry = yield self.resolve_state_groups(
+ entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events],
)
@@ -283,16 +275,16 @@ class StateHandler(object):
defer.returnValue(context)
@defer.inlineCallbacks
- @log_function
- def resolve_state_groups(self, room_id, event_ids):
+ def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
+ Args:
+ room_id (str):
+ event_ids (list[str]):
+
Returns:
- a Deferred tuple of (`state_group`, `state`, `prev_state`).
- `state_group` is the name of a state group if one and only one is
- involved. `state` is a map from (type, state_key) to event, and
- `prev_state` is a list of event ids.
+ Deferred[_StateCacheEntry]: resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -303,13 +295,7 @@ class StateHandler(object):
room_id, event_ids
)
- logger.debug(
- "resolve_state_groups state_groups %s",
- state_groups_ids.keys()
- )
-
- group_names = frozenset(state_groups_ids.keys())
- if len(group_names) == 1:
+ if len(state_groups_ids) == 1:
name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -321,6 +307,92 @@ class StateHandler(object):
delta_ids=delta_ids,
))
+ result = yield self._state_resolution_handler.resolve_state_groups(
+ room_id, state_groups_ids, self._state_map_factory,
+ )
+ defer.returnValue(result)
+
+ def _state_map_factory(self, ev_ids):
+ return self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
+
+ def resolve_events(self, state_sets, event):
+ logger.info(
+ "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+ )
+ state_set_ids = [{
+ (ev.type, ev.state_key): ev.event_id
+ for ev in st
+ } for st in state_sets]
+
+ state_map = {
+ ev.event_id: ev
+ for st in state_sets
+ for ev in st
+ }
+
+ with Measure(self.clock, "state._resolve_events"):
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+ new_state = {
+ key: state_map[ev_id] for key, ev_id in new_state.items()
+ }
+
+ return new_state
+
+
+class StateResolutionHandler(object):
+ """Responsible for doing state conflict resolution.
+
+ Note that the storage layer depends on this handler, so all functions must
+ be storage-independent.
+ """
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+
+ # dict of set of event_ids -> _StateCacheEntry.
+ self._state_cache = None
+ self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+ def start_caching(self):
+ logger.debug("start_caching")
+
+ self._state_cache = ExpiringCache(
+ cache_name="state_cache",
+ clock=self.clock,
+ max_len=SIZE_OF_CACHE,
+ expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+ iterable=True,
+ reset_expiry_on_get=True,
+ )
+
+ self._state_cache.start()
+
+ @defer.inlineCallbacks
+ @log_function
+ def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
+ """Resolves conflicts between a set of state groups
+
+ Always generates a new state group (unless we hit the cache), so should
+ not be called for a single state group
+
+ Args:
+ room_id (str): room we are resolving for (used for logging)
+ state_groups_ids (dict[int, dict[(str, str), str]]):
+ map from state group id to the state in that state group
+ (where 'state' is a map from state key to event id)
+
+ Returns:
+ Deferred[_StateCacheEntry]: resolved state
+ """
+ logger.debug(
+ "resolve_state_groups state_groups %s",
+ state_groups_ids.keys()
+ )
+
+ group_names = frozenset(state_groups_ids.keys())
+
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
@@ -351,15 +423,17 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
- state_map_factory=lambda ev_ids: self.store.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
+ state_map_factory=state_map_factory,
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
}
+ # if the new state matches any of the input state groups, we can
+ # use that state group again. Otherwise we will generate a state_id
+ # which will be used as a cache key for future resolutions, but
+ # not get persisted.
state_group = None
new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items():
@@ -396,30 +470,6 @@ class StateHandler(object):
defer.returnValue(cache)
- def resolve_events(self, state_sets, event):
- logger.info(
- "Resolving state for %s with %d groups", event.room_id, len(state_sets)
- )
- state_set_ids = [{
- (ev.type, ev.state_key): ev.event_id
- for ev in st
- } for st in state_sets]
-
- state_map = {
- ev.event_id: ev
- for st in state_sets
- for ev in st
- }
-
- with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events_with_state_map(state_set_ids, state_map)
-
- new_state = {
- key: state_map[ev_id] for key, ev_id in new_state.items()
- }
-
- return new_state
-
def _ordered_events(events):
def key_func(e):
@@ -437,8 +487,8 @@ def resolve_events_with_state_map(state_sets, state_map):
state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent]:
- a map from (type, state_key) to event.
+ dict[(str, str), str]:
+ a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -460,6 +510,21 @@ def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
+
+ Args:
+ state_sets(list[dict[(str, str), str]]):
+ List of dicts of (type, state_key) -> event_id, which are the
+ different state groups to resolve.
+
+ Returns:
+ (dict[(str, str), str], dict[(str, str), set[str]]):
+ A tuple of (unconflicted_state, conflicted_state), where:
+
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
+
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
@@ -500,8 +565,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
a Deferred of dict of event_id to event.
Returns
- Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
- a map from (type, state_key) to event.
+ Deferred[dict[(str, str), str]]:
+ a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 33fccfa7a8..dd28c2efe3 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -386,11 +386,18 @@ class EventsStore(SQLBaseStore):
if all_single_prev_not_state:
continue
- state = yield self._calculate_state_delta(
- room_id, ev_ctx_rm, new_latest_event_ids
+ logger.info(
+ "Calculating state delta for room %s", room_id,
)
- if state:
- current_state_for_room[room_id] = state
+ current_state = yield self._get_new_state_after_events(
+ ev_ctx_rm, new_latest_event_ids,
+ )
+ if current_state is not None:
+ delta = yield self._calculate_state_delta(
+ room_id, current_state,
+ )
+ if delta is not None:
+ current_state_for_room[room_id] = delta
yield self.runInteraction(
"persist_events",
@@ -467,20 +474,22 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
- """Calculate the new state deltas for a room.
+ def _get_new_state_after_events(self, events_context, new_latest_event_ids):
+ """Calculate the current state dict after adding some new events to
+ a room
- Assumes that we are only persisting events for one room at a time.
+ Args:
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
Returns:
- 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
- i.e. (type, state_key) -> event_id. `to_delete` are the entries to
- first be deleted from current_state_events, `to_insert` are entries
- to insert. `new_state` is the full set of state.
- May return None if there are no changes to be applied.
+ Deferred[dict[(str,str), str]|None]:
+ None if there are no changes to the room state, or
+ a dict of (type, state_key) -> event_id].
"""
- # Now we need to work out the different state sets for
- # each state extremities
state_sets = []
state_groups = set()
missing_event_ids = []
@@ -523,12 +532,12 @@ class EventsStore(SQLBaseStore):
state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids:
- current_state = {}
+ defer.returnValue({})
elif was_updated:
if len(state_sets) == 1:
# If there is only one state set, then we know what the current
# state is.
- current_state = state_sets[0]
+ defer.returnValue(state_sets[0])
else:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
@@ -537,8 +546,7 @@ class EventsStore(SQLBaseStore):
# up in the db.
logger.info(
- "Resolving state for %s with %i state sets",
- room_id, len(state_sets),
+ "Resolving state with %i state sets", len(state_sets),
)
events_map = {ev.event_id: ev for ev, _ in events_context}
@@ -567,9 +575,22 @@ class EventsStore(SQLBaseStore):
state_sets,
state_map_factory=get_events,
)
+ defer.returnValue(current_state)
else:
return
+ @defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
+ i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+ first be deleted from current_state_events, `to_insert` are entries
+ to insert. `new_state` is the full set of state.
+ """
existing_state = yield self.get_current_state_ids(room_id)
existing_events = set(existing_state.itervalues())
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index d91c853070..cf2c4dae39 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -533,73 +533,114 @@ class RoomStore(SQLBaseStore):
)
self.is_room_blocked.invalidate((room_id,))
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+ return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
- def _get_media_ids_in_room(txn):
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
- next_token = self.get_current_events_token() + 1
+ # Now update all the tables to set the quarantined_by flag
- total_media_quarantined = 0
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_mxcs))
- while next_token:
- sql = """
- SELECT stream_ordering, content FROM events
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
+ txn.executemany(
"""
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- local_media_mxcs = []
- remote_media_mxcs = []
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- content = json.loads(content_json)
-
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany("""
- UPDATE local_media_repository
+ UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_id = ?
- """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_media_mxcs
- )
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
)
+ )
- total_media_quarantined += len(local_media_mxcs)
- total_media_quarantined += len(remote_media_mxcs)
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
- return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+ return self.runInteraction(
+ "quarantine_media_in_room",
+ _quarantine_media_in_room_txn,
+ )
+
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ next_token = self.get_current_events_token() + 1
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, content FROM events
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ content = json.loads(content_json)
+
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
diff --git a/tests/test_state.py b/tests/test_state.py
index feb84f3d48..d16e1b3b8b 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
from .utils import MockClock
@@ -148,11 +148,13 @@ class StateTestCase(unittest.TestCase):
)
hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock",
+ "get_state_resolution_handler",
])
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
+ hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.store.get_next_state_group.side_effect = Mock
self.store.get_state_group_delta.return_value = (None, None)
|