diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6160949f32..9f57760ab0 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -476,37 +476,63 @@ class EventsStore(SQLBaseStore):
"""
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
for room_id, current_state in current_state_for_room.iteritems():
- txn.call_after(self._get_current_state_for_key.invalidate_all)
- txn.call_after(self.get_rooms_for_user.invalidate_all)
- txn.call_after(self.get_users_in_room.invalidate, (room_id,))
-
- # Add an entry to the current_state_resets table to record the point
- # where we clobbered the current state
- self._simple_insert_txn(
- txn,
- table="current_state_resets",
- values={"event_stream_ordering": max_stream_order}
- )
-
- self._simple_delete_txn(
+ existing_state_rows = self._simple_select_list_txn(
txn,
table="current_state_events",
keyvalues={"room_id": room_id},
+ retcols=["event_id", "type", "state_key"],
)
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
- for key, ev_id in current_state.iteritems()
- ],
- )
+ existing_events = set(row["event_id"] for row in existing_state_rows)
+ new_events = set(ev_id for ev_id in current_state.itervalues())
+ changed_events = existing_events ^ new_events
+ if changed_events:
+ txn.executemany(
+ "DELETE FROM current_state_events WHERE event_id = ?",
+ [(ev_id,) for ev_id in changed_events],
+ )
+
+ # Add an entry to the current_state_resets table to record the point
+ # where we clobbered the current state
+ self._simple_insert_txn(
+ txn,
+ table="current_state_resets",
+ values={"event_stream_ordering": max_stream_order}
+ )
+
+ events_to_insert = (new_events - existing_events)
+ to_insert = [
+ (key, ev_id) for key, ev_id in current_state.iteritems()
+ if ev_id in events_to_insert
+ ]
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_events",
+ values=[
+ {
+ "event_id": ev_id,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ }
+ for key, ev_id in to_insert
+ ],
+ )
+
+ members_changed = set(
+ row["state_key"] for row in existing_state_rows
+ if row["event_id"] in changed_events
+ and row["type"] == EventTypes.Member
+ )
+ members_changed.update(
+ key[1] for key, event_id in to_insert
+ if key[0] == EventTypes.Member
+ )
+
+ for member in members_changed:
+ txn.call_after(self.get_rooms_for_user.invalidate, (member,))
+
+ txn.call_after(self.get_users_in_room.invalidate, (room_id,))
for room_id, new_extrem in new_forward_extremeties.items():
self._simple_delete_txn(
|