diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 33fccfa7a8..dd28c2efe3 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -386,11 +386,18 @@ class EventsStore(SQLBaseStore):
if all_single_prev_not_state:
continue
- state = yield self._calculate_state_delta(
- room_id, ev_ctx_rm, new_latest_event_ids
+ logger.info(
+ "Calculating state delta for room %s", room_id,
)
- if state:
- current_state_for_room[room_id] = state
+ current_state = yield self._get_new_state_after_events(
+ ev_ctx_rm, new_latest_event_ids,
+ )
+ if current_state is not None:
+ delta = yield self._calculate_state_delta(
+ room_id, current_state,
+ )
+ if delta is not None:
+ current_state_for_room[room_id] = delta
yield self.runInteraction(
"persist_events",
@@ -467,20 +474,22 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
- """Calculate the new state deltas for a room.
+ def _get_new_state_after_events(self, events_context, new_latest_event_ids):
+ """Calculate the current state dict after adding some new events to
+ a room
- Assumes that we are only persisting events for one room at a time.
+ Args:
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
Returns:
- 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
- i.e. (type, state_key) -> event_id. `to_delete` are the entries to
- first be deleted from current_state_events, `to_insert` are entries
- to insert. `new_state` is the full set of state.
- May return None if there are no changes to be applied.
+ Deferred[dict[(str,str), str]|None]:
+ None if there are no changes to the room state, or
+                a dict of (type, state_key) -> event_id.
"""
- # Now we need to work out the different state sets for
- # each state extremities
state_sets = []
state_groups = set()
missing_event_ids = []
@@ -523,12 +532,12 @@ class EventsStore(SQLBaseStore):
state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids:
- current_state = {}
+ defer.returnValue({})
elif was_updated:
if len(state_sets) == 1:
# If there is only one state set, then we know what the current
# state is.
- current_state = state_sets[0]
+ defer.returnValue(state_sets[0])
else:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
@@ -537,8 +546,7 @@ class EventsStore(SQLBaseStore):
# up in the db.
logger.info(
- "Resolving state for %s with %i state sets",
- room_id, len(state_sets),
+ "Resolving state with %i state sets", len(state_sets),
)
events_map = {ev.event_id: ev for ev, _ in events_context}
@@ -567,9 +575,22 @@ class EventsStore(SQLBaseStore):
state_sets,
state_map_factory=get_events,
)
+ defer.returnValue(current_state)
else:
return
+ @defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+
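+        Args:
+            room_id (str)
+            current_state (dict[(str, str), str]):
+                the new current state of the room, as returned by
+                _get_new_state_after_events
+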
+ Returns:
+            3-tuple (to_delete, to_insert, new_state) where to_delete and
+            to_insert are state dicts, i.e. (type, state_key) -> event_id.
+            `to_delete` are the entries to first be deleted from
+            current_state_events, `to_insert` are entries to insert.
+            `new_state` is the full set of state.
+            May return None if there are no changes to be applied.
+ """
existing_state = yield self.get_current_state_ids(room_id)
existing_events = set(existing_state.itervalues())
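
Both helpers work on state dicts keyed by (event type, state key). As a rough sketch of the shapes involved (the event IDs are hypothetical, and the delta computation below is an approximation of what the docstring describes, not the code in the remainder of _calculate_state_delta, which is outside this hunk):

    # Illustrative only: state dicts of the shape (type, state_key) -> event_id.
    existing_state = {
        ("m.room.member", "@alice:example.com"): "$member_a",
        ("m.room.topic", ""): "$topic_old",
    }
    current_state = {
        ("m.room.member", "@alice:example.com"): "$member_a",
        ("m.room.topic", ""): "$topic_new",
    }

    # One way to express the delta the docstring describes: entries whose
    # event ID changed get deleted and re-inserted; unchanged entries are
    # left alone. This approximates, but is not, the real implementation.
    to_delete = {k: v for k, v in existing_state.items() if current_state.get(k) != v}
    to_insert = {k: v for k, v in current_state.items() if existing_state.get(k) != v}
    new_state = current_state

    print(to_delete)  # {('m.room.topic', ''): '$topic_old'}
    print(to_insert)  # {('m.room.topic', ''): '$topic_new'}
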
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index d91c853070..cf2c4dae39 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -533,73 +533,114 @@ class RoomStore(SQLBaseStore):
)
self.is_room_blocked.invalidate((room_id,))
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+            The local and remote media in the room, as two lists of MXC URI
+            strings (local media first, then remote media).
+ """
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+ return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
- def _get_media_ids_in_room(txn):
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
- next_token = self.get_current_events_token() + 1
+ # Now update all the tables to set the quarantined_by flag
- total_media_quarantined = 0
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_mxcs))
- while next_token:
- sql = """
- SELECT stream_ordering, content FROM events
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
+ txn.executemany(
"""
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- local_media_mxcs = []
- remote_media_mxcs = []
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- content = json.loads(content_json)
-
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany("""
- UPDATE local_media_repository
+ UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_id = ?
- """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_media_mxcs
- )
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
)
+ )
- total_media_quarantined += len(local_media_mxcs)
- total_media_quarantined += len(remote_media_mxcs)
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
- return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+ return self.runInteraction(
+ "quarantine_media_in_room",
+ _quarantine_media_in_room_txn,
+ )
+
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves the IDs of all the local and remote media in a given room
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+            The local media as a list of media IDs, and the remote media as
+            a list of (hostname, media ID) tuples.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ next_token = self.get_current_events_token() + 1
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, content FROM events
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ content = json.loads(content_json)
+
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
|