diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 90f96209f8..e83adc8339 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -88,9 +88,13 @@ class BaseHandler(object):
current_state = yield self.store.get_events(
context.current_state_ids.values()
)
- current_state = current_state.values()
else:
- current_state = yield self.store.get_current_state(event.room_id)
+ current_state = yield self.state_handler.get_current_state(
+ event.room_id
+ )
+
+ current_state = current_state.values()
+
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 64f18bbb3e..b3f3bf7488 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore):
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
- _get_current_state_for_key = StateStore.__dict__[
- "_get_current_state_for_key"
- ]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
@@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore):
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
- get_current_state = DataStore.get_current_state.__func__
- get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
@@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore):
def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
- self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
@@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore):
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
-
- self._get_current_state_for_key.invalidate((
- event.room_id, event.type, event.state_key
- ))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6160949f32..599db4c9f0 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -284,71 +284,37 @@ class EventsStore(SQLBaseStore):
new_forward_extremeties = {}
current_state_for_room = {}
if not backfilled:
- # Work out the new "current state" for each room.
- # We do this by working out what the new extremities are and then
- # calculating the state from that.
- events_by_room = {}
- for event, context in chunk:
- events_by_room.setdefault(event.room_id, []).append(
- (event, context)
- )
-
- for room_id, ev_ctx_rm in events_by_room.items():
- # Work out new extremities by recursively adding and removing
- # the new events.
- latest_event_ids = yield self.get_latest_event_ids_in_room(
- room_id
- )
- new_latest_event_ids = yield self._calculate_new_extremeties(
- room_id, [ev for ev, _ in ev_ctx_rm]
- )
-
- if new_latest_event_ids == set(latest_event_ids):
- # No change in extremities, so no change in state
- continue
+ with Measure(self._clock, "_calculate_state_and_extrem"):
+ # Work out the new "current state" for each room.
+ # We do this by working out what the new extremities are and then
+ # calculating the state from that.
+ events_by_room = {}
+ for event, context in chunk:
+ events_by_room.setdefault(event.room_id, []).append(
+ (event, context)
+ )
- new_forward_extremeties[room_id] = new_latest_event_ids
-
- # Now we need to work out the different state sets for
- # each state extremities
- state_sets = []
- missing_event_ids = []
- was_updated = False
- for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding,
- # and then use the current state from that
- for ev, ctx in ev_ctx_rm:
- if event_id == ev.event_id:
- if ctx.current_state_ids is None:
- raise Exception("Unknown current state")
- state_sets.append(ctx.current_state_ids)
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
- break
- else:
- # If we couldn't find it, then we'll need to pull
- # the state from the database
- was_updated = True
- missing_event_ids.append(event_id)
-
- if missing_event_ids:
- # Now pull out the state for any missing events from DB
- event_to_groups = yield self._get_state_group_for_events(
- missing_event_ids,
+ for room_id, ev_ctx_rm in events_by_room.items():
+ # Work out new extremities by recursively adding and removing
+ # the new events.
+ latest_event_ids = yield self.get_latest_event_ids_in_room(
+ room_id
+ )
+ new_latest_event_ids = yield self._calculate_new_extremeties(
+ room_id, [ev for ev, _ in ev_ctx_rm]
)
- groups = set(event_to_groups.values())
- group_to_state = yield self._get_state_for_groups(groups)
+ if new_latest_event_ids == set(latest_event_ids):
+ # No change in extremities, so no change in state
+ continue
- state_sets.extend(group_to_state.values())
+ new_forward_extremeties[room_id] = new_latest_event_ids
- if not new_latest_event_ids or was_updated:
- current_state_for_room[room_id] = yield resolve_events(
- state_sets,
- state_map_factory=lambda ev_ids: self.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
+ state = yield self._calculate_state_delta(
+ room_id, ev_ctx_rm, new_latest_event_ids
)
+ if state:
+ current_state_for_room[room_id] = state
yield self.runInteraction(
"persist_events",
@@ -406,6 +372,91 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ 2-tuple (to_delete, to_insert) where both are state dicts, i.e.
+ (type, state_key) -> event_id. `to_delete` are the entreis to
+ first be deleted from current_state_events, `to_insert` are entries
+ to insert.
+ May return None if there are no changes to be applied.
+ """
+ # Now we need to work out the different state sets for
+ # each state extremities
+ state_sets = []
+ missing_event_ids = []
+ was_updated = False
+ for event_id in new_latest_event_ids:
+ # First search in the list of new events we're adding,
+ # and then use the current state from that
+ for ev, ctx in events_context:
+ if event_id == ev.event_id:
+ if ctx.current_state_ids is None:
+ raise Exception("Unknown current state")
+ state_sets.append(ctx.current_state_ids)
+ if ctx.delta_ids or hasattr(ev, "state_key"):
+ was_updated = True
+ break
+ else:
+ # If we couldn't find it, then we'll need to pull
+ # the state from the database
+ was_updated = True
+ missing_event_ids.append(event_id)
+
+ if missing_event_ids:
+ # Now pull out the state for any missing events from DB
+ event_to_groups = yield self._get_state_group_for_events(
+ missing_event_ids,
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = yield self._get_state_for_groups(groups)
+
+ state_sets.extend(group_to_state.values())
+
+ if not new_latest_event_ids:
+ current_state = {}
+ elif was_updated:
+ current_state = yield resolve_events(
+ state_sets,
+ state_map_factory=lambda ev_ids: self.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ ),
+ )
+ else:
+ return
+
+ existing_state_rows = yield self._simple_select_list(
+ table="current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=["event_id", "type", "state_key"],
+ desc="_calculate_state_delta",
+ )
+
+ existing_events = set(row["event_id"] for row in existing_state_rows)
+ new_events = set(ev_id for ev_id in current_state.itervalues())
+ changed_events = existing_events ^ new_events
+
+ if not changed_events:
+ return
+
+ to_delete = {
+ (row["type"], row["state_key"]): row["event_id"]
+ for row in existing_state_rows
+ if row["event_id"] in changed_events
+ }
+ events_to_insert = (new_events - existing_events)
+ to_insert = {
+ key: ev_id for key, ev_id in current_state.iteritems()
+ if ev_id in events_to_insert
+ }
+
+ defer.returnValue((to_delete, to_insert))
+
+ @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False):
@@ -475,38 +526,55 @@ class EventsStore(SQLBaseStore):
database before insertion. This is useful when retrying due to IntegrityError.
"""
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- for room_id, current_state in current_state_for_room.iteritems():
- txn.call_after(self._get_current_state_for_key.invalidate_all)
- txn.call_after(self.get_rooms_for_user.invalidate_all)
- txn.call_after(self.get_users_in_room.invalidate, (room_id,))
-
- # Add an entry to the current_state_resets table to record the point
- # where we clobbered the current state
- self._simple_insert_txn(
- txn,
- table="current_state_resets",
- values={"event_stream_ordering": max_stream_order}
- )
+ for room_id, current_state_tuple in current_state_for_room.iteritems():
+ to_delete, to_insert = current_state_tuple
+ txn.executemany(
+ "DELETE FROM current_state_events WHERE event_id = ?",
+ [(ev_id,) for ev_id in to_delete.itervalues()],
+ )
- self._simple_delete_txn(
- txn,
- table="current_state_events",
- keyvalues={"room_id": room_id},
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_events",
+ values=[
+ {
+ "event_id": ev_id,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ }
+ for key, ev_id in to_insert.iteritems()
+ ],
+ )
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
- for key, ev_id in current_state.iteritems()
- ],
- )
+ # Invalidate the various caches
+
+ # Figure out the changes of membership to invalidate the
+ # `get_rooms_for_user` cache.
+ # We find out which membership events we may have deleted
+ # and which we have added, then we invlidate the caches for all
+ # those users.
+ members_changed = set(
+ state_key for ev_type, state_key in to_delete.iterkeys()
+ if ev_type == EventTypes.Member
+ )
+ members_changed.update(
+ state_key for ev_type, state_key in to_insert.iterkeys()
+ if ev_type == EventTypes.Member
+ )
+
+ for member in members_changed:
+ txn.call_after(self.get_rooms_for_user.invalidate, (member,))
+
+ txn.call_after(self.get_users_in_room.invalidate, (room_id,))
+
+ # Add an entry to the current_state_resets table to record the point
+ # where we clobbered the current state
+ self._simple_insert_txn(
+ txn,
+ table="current_state_resets",
+ values={"event_stream_ordering": max_stream_order}
+ )
for room_id, new_extrem in new_forward_extremeties.items():
self._simple_delete_txn(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7d34dd03bf..d1d653327c 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -232,58 +232,6 @@ class StateStore(SQLBaseStore):
return count
- @defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key=""):
- if event_type and state_key is not None:
- result = yield self.get_current_state_for_key(
- room_id, event_type, state_key
- )
- defer.returnValue(result)
-
- def f(txn):
- sql = (
- "SELECT event_id FROM current_state_events"
- " WHERE room_id = ? "
- )
-
- if event_type and state_key is not None:
- sql += " AND type = ? AND state_key = ? "
- args = (room_id, event_type, state_key)
- elif event_type:
- sql += " AND type = ?"
- args = (room_id, event_type)
- else:
- args = (room_id, )
-
- txn.execute(sql, args)
- results = txn.fetchall()
-
- return [r[0] for r in results]
-
- event_ids = yield self.runInteraction("get_current_state", f)
- events = yield self._get_events(event_ids, get_prev_content=False)
- defer.returnValue(events)
-
- @defer.inlineCallbacks
- def get_current_state_for_key(self, room_id, event_type, state_key):
- event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
- events = yield self._get_events(event_ids, get_prev_content=False)
- defer.returnValue(events)
-
- @cached(num_args=3)
- def _get_current_state_for_key(self, room_id, event_type, state_key):
- def f(txn):
- sql = (
- "SELECT event_id FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?"
- )
-
- args = (room_id, event_type, state_key)
- txn.execute(sql, args)
- results = txn.fetchall()
- return [r[0] for r in results]
- return self.runInteraction("get_current_state_for_key", f)
-
@cached(num_args=2, max_entries=100000, iterable=True)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 38fedfe690..6acb8ab758 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -119,35 +119,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
@defer.inlineCallbacks
- def test_get_current_state(self):
- # Create the room.
- yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.replicate()
- yield self.check(
- "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
- )
-
- # Join the room.
- join1 = yield self.persist(
- type="m.room.member", key=USER_ID, membership="join",
- )
- yield self.replicate()
- yield self.check(
- "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
- [join1]
- )
-
- # Add some other user to the room.
- join2 = yield self.persist(
- type="m.room.member", key=USER_ID_2, membership="join",
- )
- yield self.replicate()
- yield self.check(
- "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
- [join2]
- )
-
- @defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
|