summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/_base.py8
-rw-r--r--synapse/replication/slave/storage/events.py10
-rw-r--r--synapse/storage/events.py246
-rw-r--r--synapse/storage/state.py52
-rw-r--r--tests/replication/slave/storage/test_events.py29
5 files changed, 163 insertions, 182 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 90f96209f8..e83adc8339 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -88,9 +88,13 @@ class BaseHandler(object):
                     current_state = yield self.store.get_events(
                         context.current_state_ids.values()
                     )
-                    current_state = current_state.values()
                 else:
-                    current_state = yield self.store.get_current_state(event.room_id)
+                    current_state = yield self.state_handler.get_current_state(
+                        event.room_id
+                    )
+
+                current_state = current_state.values()
+
                 logger.info("maybe_kick_guest_users %r", current_state)
                 yield self.kick_guest_users(current_state)
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 64f18bbb3e..b3f3bf7488 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore):
     get_latest_event_ids_in_room = EventFederationStore.__dict__[
         "get_latest_event_ids_in_room"
     ]
-    _get_current_state_for_key = StateStore.__dict__[
-        "_get_current_state_for_key"
-    ]
     get_invited_rooms_for_user = RoomMemberStore.__dict__[
         "get_invited_rooms_for_user"
     ]
@@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore):
     )
     get_event = DataStore.get_event.__func__
     get_events = DataStore.get_events.__func__
-    get_current_state = DataStore.get_current_state.__func__
-    get_current_state_for_key = DataStore.get_current_state_for_key.__func__
     get_rooms_for_user_where_membership_is = (
         DataStore.get_rooms_for_user_where_membership_is.__func__
     )
@@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore):
 
     def invalidate_caches_for_event(self, event, backfilled, reset_state):
         if reset_state:
-            self._get_current_state_for_key.invalidate_all()
             self.get_rooms_for_user.invalidate_all()
             self.get_users_in_room.invalidate((event.room_id,))
 
@@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore):
         if (not event.internal_metadata.is_invite_from_remote()
                 and event.internal_metadata.is_outlier()):
             return
-
-        self._get_current_state_for_key.invalidate((
-            event.room_id, event.type, event.state_key
-        ))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6160949f32..599db4c9f0 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -284,71 +284,37 @@ class EventsStore(SQLBaseStore):
                 new_forward_extremeties = {}
                 current_state_for_room = {}
                 if not backfilled:
-                    # Work out the new "current state" for each room.
-                    # We do this by working out what the new extremities are and then
-                    # calculating the state from that.
-                    events_by_room = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in events_by_room.items():
-                        # Work out new extremities by recursively adding and removing
-                        # the new events.
-                        latest_event_ids = yield self.get_latest_event_ids_in_room(
-                            room_id
-                        )
-                        new_latest_event_ids = yield self._calculate_new_extremeties(
-                            room_id, [ev for ev, _ in ev_ctx_rm]
-                        )
-
-                        if new_latest_event_ids == set(latest_event_ids):
-                            # No change in extremities, so no change in state
-                            continue
+                    with Measure(self._clock, "_calculate_state_and_extrem"):
+                        # Work out the new "current state" for each room.
+                        # We do this by working out what the new extremities are and then
+                        # calculating the state from that.
+                        events_by_room = {}
+                        for event, context in chunk:
+                            events_by_room.setdefault(event.room_id, []).append(
+                                (event, context)
+                            )
 
-                        new_forward_extremeties[room_id] = new_latest_event_ids
-
-                        # Now we need to work out the different state sets for
-                        # each state extremities
-                        state_sets = []
-                        missing_event_ids = []
-                        was_updated = False
-                        for event_id in new_latest_event_ids:
-                            # First search in the list of new events we're adding,
-                            # and then use the current state from that
-                            for ev, ctx in ev_ctx_rm:
-                                if event_id == ev.event_id:
-                                    if ctx.current_state_ids is None:
-                                        raise Exception("Unknown current state")
-                                    state_sets.append(ctx.current_state_ids)
-                                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                                        was_updated = True
-                                    break
-                            else:
-                                # If we couldn't find it, then we'll need to pull
-                                # the state from the database
-                                was_updated = True
-                                missing_event_ids.append(event_id)
-
-                        if missing_event_ids:
-                            # Now pull out the state for any missing events from DB
-                            event_to_groups = yield self._get_state_group_for_events(
-                                missing_event_ids,
+                        for room_id, ev_ctx_rm in events_by_room.items():
+                            # Work out new extremities by recursively adding and removing
+                            # the new events.
+                            latest_event_ids = yield self.get_latest_event_ids_in_room(
+                                room_id
+                            )
+                            new_latest_event_ids = yield self._calculate_new_extremeties(
+                                room_id, [ev for ev, _ in ev_ctx_rm]
                             )
 
-                            groups = set(event_to_groups.values())
-                            group_to_state = yield self._get_state_for_groups(groups)
+                            if new_latest_event_ids == set(latest_event_ids):
+                                # No change in extremities, so no change in state
+                                continue
 
-                            state_sets.extend(group_to_state.values())
+                            new_forward_extremeties[room_id] = new_latest_event_ids
 
-                        if not new_latest_event_ids or was_updated:
-                            current_state_for_room[room_id] = yield resolve_events(
-                                state_sets,
-                                state_map_factory=lambda ev_ids: self.get_events(
-                                    ev_ids, get_prev_content=False, check_redacted=False,
-                                ),
+                            state = yield self._calculate_state_delta(
+                                room_id, ev_ctx_rm, new_latest_event_ids
                             )
+                            if state:
+                                current_state_for_room[room_id] = state
 
                 yield self.runInteraction(
                     "persist_events",
@@ -406,6 +372,91 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
+            (type, state_key) -> event_id. `to_delete` are the entreis to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert.
+            May return None if there are no changes to be applied.
+        """
+        # Now we need to work out the different state sets for
+        # each state extremities
+        state_sets = []
+        missing_event_ids = []
+        was_updated = False
+        for event_id in new_latest_event_ids:
+            # First search in the list of new events we're adding,
+            # and then use the current state from that
+            for ev, ctx in events_context:
+                if event_id == ev.event_id:
+                    if ctx.current_state_ids is None:
+                        raise Exception("Unknown current state")
+                    state_sets.append(ctx.current_state_ids)
+                    if ctx.delta_ids or hasattr(ev, "state_key"):
+                        was_updated = True
+                    break
+            else:
+                # If we couldn't find it, then we'll need to pull
+                # the state from the database
+                was_updated = True
+                missing_event_ids.append(event_id)
+
+        if missing_event_ids:
+            # Now pull out the state for any missing events from DB
+            event_to_groups = yield self._get_state_group_for_events(
+                missing_event_ids,
+            )
+
+            groups = set(event_to_groups.values())
+            group_to_state = yield self._get_state_for_groups(groups)
+
+            state_sets.extend(group_to_state.values())
+
+        if not new_latest_event_ids:
+            current_state = {}
+        elif was_updated:
+            current_state = yield resolve_events(
+                state_sets,
+                state_map_factory=lambda ev_ids: self.get_events(
+                    ev_ids, get_prev_content=False, check_redacted=False,
+                ),
+            )
+        else:
+            return
+
+        existing_state_rows = yield self._simple_select_list(
+            table="current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=["event_id", "type", "state_key"],
+            desc="_calculate_state_delta",
+        )
+
+        existing_events = set(row["event_id"] for row in existing_state_rows)
+        new_events = set(ev_id for ev_id in current_state.itervalues())
+        changed_events = existing_events ^ new_events
+
+        if not changed_events:
+            return
+
+        to_delete = {
+            (row["type"], row["state_key"]): row["event_id"]
+            for row in existing_state_rows
+            if row["event_id"] in changed_events
+        }
+        events_to_insert = (new_events - existing_events)
+        to_insert = {
+            key: ev_id for key, ev_id in current_state.iteritems()
+            if ev_id in events_to_insert
+        }
+
+        defer.returnValue((to_delete, to_insert))
+
+    @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
                   get_prev_content=False, allow_rejected=False,
                   allow_none=False):
@@ -475,38 +526,55 @@ class EventsStore(SQLBaseStore):
         database before insertion. This is useful when retrying due to IntegrityError.
         """
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-        for room_id, current_state in current_state_for_room.iteritems():
-            txn.call_after(self._get_current_state_for_key.invalidate_all)
-            txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, (room_id,))
-
-            # Add an entry to the current_state_resets table to record the point
-            # where we clobbered the current state
-            self._simple_insert_txn(
-                txn,
-                table="current_state_resets",
-                values={"event_stream_ordering": max_stream_order}
-            )
+        for room_id, current_state_tuple in current_state_for_room.iteritems():
+                to_delete, to_insert = current_state_tuple
+                txn.executemany(
+                    "DELETE FROM current_state_events WHERE event_id = ?",
+                    [(ev_id,) for ev_id in to_delete.itervalues()],
+                )
 
-            self._simple_delete_txn(
-                txn,
-                table="current_state_events",
-                keyvalues={"room_id": room_id},
-            )
+                self._simple_insert_many_txn(
+                    txn,
+                    table="current_state_events",
+                    values=[
+                        {
+                            "event_id": ev_id,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                        }
+                        for key, ev_id in to_insert.iteritems()
+                    ],
+                )
 
-            self._simple_insert_many_txn(
-                txn,
-                table="current_state_events",
-                values=[
-                    {
-                        "event_id": ev_id,
-                        "room_id": room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                    }
-                    for key, ev_id in current_state.iteritems()
-                ],
-            )
+                # Invalidate the various caches
+
+                # Figure out the changes of membership to invalidate the
+                # `get_rooms_for_user` cache.
+                # We find out which membership events we may have deleted
+                # and which we have added, then we invlidate the caches for all
+                # those users.
+                members_changed = set(
+                    state_key for ev_type, state_key in to_delete.iterkeys()
+                    if ev_type == EventTypes.Member
+                )
+                members_changed.update(
+                    state_key for ev_type, state_key in to_insert.iterkeys()
+                    if ev_type == EventTypes.Member
+                )
+
+                for member in members_changed:
+                    txn.call_after(self.get_rooms_for_user.invalidate, (member,))
+
+                txn.call_after(self.get_users_in_room.invalidate, (room_id,))
+
+                # Add an entry to the current_state_resets table to record the point
+                # where we clobbered the current state
+                self._simple_insert_txn(
+                    txn,
+                    table="current_state_resets",
+                    values={"event_stream_ordering": max_stream_order}
+                )
 
         for room_id, new_extrem in new_forward_extremeties.items():
             self._simple_delete_txn(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7d34dd03bf..d1d653327c 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -232,58 +232,6 @@ class StateStore(SQLBaseStore):
 
             return count
 
-    @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key=""):
-        if event_type and state_key is not None:
-            result = yield self.get_current_state_for_key(
-                room_id, event_type, state_key
-            )
-            defer.returnValue(result)
-
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? "
-            )
-
-            if event_type and state_key is not None:
-                sql += " AND type = ? AND state_key = ? "
-                args = (room_id, event_type, state_key)
-            elif event_type:
-                sql += " AND type = ?"
-                args = (room_id, event_type)
-            else:
-                args = (room_id, )
-
-            txn.execute(sql, args)
-            results = txn.fetchall()
-
-            return [r[0] for r in results]
-
-        event_ids = yield self.runInteraction("get_current_state", f)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @defer.inlineCallbacks
-    def get_current_state_for_key(self, room_id, event_type, state_key):
-        event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @cached(num_args=3)
-    def _get_current_state_for_key(self, room_id, event_type, state_key):
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? AND type = ? AND state_key = ?"
-            )
-
-            args = (room_id, event_type, state_key)
-            txn.execute(sql, args)
-            results = txn.fetchall()
-            return [r[0] for r in results]
-        return self.runInteraction("get_current_state_for_key", f)
-
     @cached(num_args=2, max_entries=100000, iterable=True)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 38fedfe690..6acb8ab758 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -119,35 +119,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         )
 
     @defer.inlineCallbacks
-    def test_get_current_state(self):
-        # Create the room.
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
-        )
-
-        # Join the room.
-        join1 = yield self.persist(
-            type="m.room.member", key=USER_ID, membership="join",
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
-            [join1]
-        )
-
-        # Add some other user to the room.
-        join2 = yield self.persist(
-            type="m.room.member", key=USER_ID_2, membership="join",
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
-            [join2]
-        )
-
-    @defer.inlineCallbacks
     def test_redactions(self):
         yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.persist(type="m.room.member", key=USER_ID, membership="join")