diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 90f96209f8..e83adc8339 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -88,9 +88,13 @@ class BaseHandler(object):
current_state = yield self.store.get_events(
context.current_state_ids.values()
)
- current_state = current_state.values()
else:
- current_state = yield self.store.get_current_state(event.room_id)
+ current_state = yield self.state_handler.get_current_state(
+ event.room_id
+ )
+
+ current_state = current_state.values()
+
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 64f18bbb3e..b3f3bf7488 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore):
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
- _get_current_state_for_key = StateStore.__dict__[
- "_get_current_state_for_key"
- ]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
@@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore):
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
- get_current_state = DataStore.get_current_state.__func__
- get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
@@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore):
def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
- self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
@@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore):
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
-
- self._get_current_state_for_key.invalidate((
- event.room_id, event.type, event.state_key
- ))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6160949f32..9f57760ab0 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -476,37 +476,63 @@ class EventsStore(SQLBaseStore):
"""
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
for room_id, current_state in current_state_for_room.iteritems():
- txn.call_after(self._get_current_state_for_key.invalidate_all)
- txn.call_after(self.get_rooms_for_user.invalidate_all)
- txn.call_after(self.get_users_in_room.invalidate, (room_id,))
-
- # Add an entry to the current_state_resets table to record the point
- # where we clobbered the current state
- self._simple_insert_txn(
- txn,
- table="current_state_resets",
- values={"event_stream_ordering": max_stream_order}
- )
-
- self._simple_delete_txn(
+ existing_state_rows = self._simple_select_list_txn(
txn,
table="current_state_events",
keyvalues={"room_id": room_id},
+ retcols=["event_id", "type", "state_key"],
)
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
- for key, ev_id in current_state.iteritems()
- ],
- )
+ existing_events = set(row["event_id"] for row in existing_state_rows)
+ new_events = set(ev_id for ev_id in current_state.itervalues())
+ changed_events = existing_events ^ new_events
+ if changed_events:
+ txn.executemany(
+ "DELETE FROM current_state_events WHERE event_id = ?",
+ [(ev_id,) for ev_id in changed_events],
+ )
+
+ # Add an entry to the current_state_resets table to record the point
+ # where we clobbered the current state
+ self._simple_insert_txn(
+ txn,
+ table="current_state_resets",
+ values={"event_stream_ordering": max_stream_order}
+ )
+
+ events_to_insert = (new_events - existing_events)
+ to_insert = [
+ (key, ev_id) for key, ev_id in current_state.iteritems()
+ if ev_id in events_to_insert
+ ]
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_events",
+ values=[
+ {
+ "event_id": ev_id,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ }
+ for key, ev_id in to_insert
+ ],
+ )
+
+ members_changed = set(
+ row["state_key"] for row in existing_state_rows
+ if row["event_id"] in changed_events
+ and row["type"] == EventTypes.Member
+ )
+ members_changed.update(
+ key[1] for key, event_id in to_insert
+ if key[0] == EventTypes.Member
+ )
+
+ for member in members_changed:
+ txn.call_after(self.get_rooms_for_user.invalidate, (member,))
+
+ txn.call_after(self.get_users_in_room.invalidate, (room_id,))
for room_id, new_extrem in new_forward_extremeties.items():
self._simple_delete_txn(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7d34dd03bf..d1d653327c 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -232,58 +232,6 @@ class StateStore(SQLBaseStore):
return count
- @defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key=""):
- if event_type and state_key is not None:
- result = yield self.get_current_state_for_key(
- room_id, event_type, state_key
- )
- defer.returnValue(result)
-
- def f(txn):
- sql = (
- "SELECT event_id FROM current_state_events"
- " WHERE room_id = ? "
- )
-
- if event_type and state_key is not None:
- sql += " AND type = ? AND state_key = ? "
- args = (room_id, event_type, state_key)
- elif event_type:
- sql += " AND type = ?"
- args = (room_id, event_type)
- else:
- args = (room_id, )
-
- txn.execute(sql, args)
- results = txn.fetchall()
-
- return [r[0] for r in results]
-
- event_ids = yield self.runInteraction("get_current_state", f)
- events = yield self._get_events(event_ids, get_prev_content=False)
- defer.returnValue(events)
-
- @defer.inlineCallbacks
- def get_current_state_for_key(self, room_id, event_type, state_key):
- event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
- events = yield self._get_events(event_ids, get_prev_content=False)
- defer.returnValue(events)
-
- @cached(num_args=3)
- def _get_current_state_for_key(self, room_id, event_type, state_key):
- def f(txn):
- sql = (
- "SELECT event_id FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?"
- )
-
- args = (room_id, event_type, state_key)
- txn.execute(sql, args)
- results = txn.fetchall()
- return [r[0] for r in results]
- return self.runInteraction("get_current_state_for_key", f)
-
@cached(num_args=2, max_entries=100000, iterable=True)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
|