summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/federation.py1
-rw-r--r--synapse/state.py3
-rw-r--r--synapse/storage/events.py188
-rw-r--r--tests/replication/slave/storage/test_events.py45
4 files changed, 138 insertions, 99 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d3f5892376..996bfd0e23 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler):
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event, new_event_context,
-            current_state=state,
         )
 
         defer.returnValue((event_stream_id, max_stream_id))
diff --git a/synapse/state.py b/synapse/state.py
index 20aaacf40f..383d32b163 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -429,6 +429,9 @@ def resolve_events(state_sets, state_map_factory):
         dict[(str, str), synapse.events.FrozenEvent] is a map from
         (type, state_key) to event.
     """
+    if len(state_sets) == 1:
+        return state_sets[0]
+
     unconflicted_state, conflicted_state = _seperate(
         state_sets,
     )
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ca501932f3..0d6519f30d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import SQLBaseStore, _RollbackButIsFineException
+from ._base import SQLBaseStore
 
 from twisted.internet import defer, reactor
 
@@ -27,6 +27,7 @@ from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
+from synapse.state import resolve_events
 
 from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
@@ -71,22 +72,19 @@ class _EventPeristenceQueue(object):
     """
 
     _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
-        "events_and_contexts", "current_state", "backfilled", "deferred",
+        "events_and_contexts", "backfilled", "deferred",
     ))
 
     def __init__(self):
         self._event_persist_queues = {}
         self._currently_persisting_rooms = set()
 
-    def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
+    def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
             end_item = queue[-1]
-            if end_item.current_state or current_state:
-                # We perist events with current_state set to True one at a time
-                pass
             if end_item.backfilled == backfilled:
                 end_item.events_and_contexts.extend(events_and_contexts)
                 return end_item.deferred.observe()
@@ -96,7 +94,6 @@ class _EventPeristenceQueue(object):
         queue.append(self._EventPersistQueueItem(
             events_and_contexts=events_and_contexts,
             backfilled=backfilled,
-            current_state=current_state,
             deferred=deferred,
         ))
 
@@ -216,7 +213,6 @@ class EventsStore(SQLBaseStore):
             d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
-                current_state=None,
             )
             deferreds.append(d)
 
@@ -229,11 +225,10 @@ class EventsStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, context, current_state=None, backfilled=False):
+    def persist_event(self, event, context, backfilled=False):
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
-            current_state=current_state,
         )
 
         self._maybe_start_persisting(event.room_id)
@@ -246,21 +241,10 @@ class EventsStore(SQLBaseStore):
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
         def persisting_queue(item):
-            if item.current_state:
-                for event, context in item.events_and_contexts:
-                    # There should only ever be one item in
-                    # events_and_contexts when current_state is
-                    # not None
-                    yield self._persist_event(
-                        event, context,
-                        current_state=item.current_state,
-                        backfilled=item.backfilled,
-                    )
-            else:
-                yield self._persist_events(
-                    item.events_and_contexts,
-                    backfilled=item.backfilled,
-                )
+            yield self._persist_events(
+                item.events_and_contexts,
+                backfilled=item.backfilled,
+            )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
@@ -294,36 +278,89 @@ class EventsStore(SQLBaseStore):
             for chunk in chunks:
                 # We can't easily parallelize these since different chunks
                 # might contain the same event. :(
+
+                current_state_for_room = {}
+                if not backfilled:
+                    # Work out the new "current state" for each room.
+                    # We do this by working out what the new extremities are and then
+                    # calculating the state from that.
+                    events_by_room = {}
+                    for event, context in chunk:
+                        events_by_room.setdefault(event.room_id, []).append(
+                            (event, context)
+                        )
+
+                    for room_id, ev_ctx_rm in events_by_room.items():
+                        # Work out new extremities by recursively adding and removing
+                        # the new events.
+                        latest_event_ids = yield self.get_latest_event_ids_in_room(
+                            room_id
+                        )
+                        new_latest_event_ids = set(latest_event_ids)
+                        for event, ctx in ev_ctx_rm:
+                            if event.internal_metadata.is_outlier():
+                                continue
+
+                            new_latest_event_ids.difference_update(
+                                e_id for e_id, _ in event.prev_events
+                            )
+                            new_latest_event_ids.add(event.event_id)
+
+                        if new_latest_event_ids == set(latest_event_ids):
+                            # No change in extremities, so no change in state
+                            continue
+
+                        # Now we need to work out the different state sets for
+                        # each state extremities
+                        state_sets = []
+                        missing_event_ids = []
+                        was_updated = False
+                        for event_id in new_latest_event_ids:
+                            # First search in the list of new events we're adding,
+                            # and then use the current state from that
+                            for ev, ctx in ev_ctx_rm:
+                                if event_id == ev.event_id:
+                                    if ctx.current_state_ids is None:
+                                        raise Exception("Unknown current state")
+                                    state_sets.append(ctx.current_state_ids)
+                                    if ctx.delta_ids or hasattr(ev, "state_key"):
+                                        was_updated = True
+                                    break
+                            else:
+                                # If we couldn't find it, then we'll need to pull
+                                # the state from the database
+                                was_updated = True
+                                missing_event_ids.append(event_id)
+
+                        if missing_event_ids:
+                            # Now pull out the state for any missing events from DB
+                            event_to_groups = yield self._get_state_group_for_events(
+                                missing_event_ids,
+                            )
+
+                            groups = set(event_to_groups.values())
+                            group_to_state = yield self._get_state_for_groups(groups)
+
+                            state_sets.extend(group_to_state.values())
+
+                        if not new_latest_event_ids or was_updated:
+                            current_state_for_room[room_id] = yield resolve_events(
+                                state_sets,
+                                state_map_factory=lambda ev_ids: self.get_events(
+                                    ev_ids, get_prev_content=False, check_redacted=False,
+                                ),
+                            )
+
                 yield self.runInteraction(
                     "persist_events",
                     self._persist_events_txn,
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
+                    current_state_for_room=current_state_for_room,
                 )
                 persist_event_counter.inc_by(len(chunk))
 
-    @_retry_on_integrity_error
-    @defer.inlineCallbacks
-    @log_function
-    def _persist_event(self, event, context, current_state=None, backfilled=False,
-                       delete_existing=False):
-        try:
-            with self._stream_id_gen.get_next() as stream_ordering:
-                event.internal_metadata.stream_ordering = stream_ordering
-                yield self.runInteraction(
-                    "persist_event",
-                    self._persist_event_txn,
-                    event=event,
-                    context=context,
-                    current_state=current_state,
-                    backfilled=backfilled,
-                    delete_existing=delete_existing,
-                )
-                persist_event_counter.inc()
-        except _RollbackButIsFineException:
-            pass
-
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
                   get_prev_content=False, allow_rejected=False,
@@ -426,7 +463,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False):
+                            delete_existing=False, current_state_for_room={}):
         """Insert some number of room events into the necessary database tables.
 
         Rejected events are only inserted into the events table, the events_json table,
@@ -436,6 +473,40 @@ class EventsStore(SQLBaseStore):
         If delete_existing is True then existing events will be purged from the
         database before insertion. This is useful when retrying due to IntegrityError.
         """
+        for room_id, current_state in current_state_for_room.iteritems():
+            txn.call_after(self._get_current_state_for_key.invalidate_all)
+            txn.call_after(self.get_rooms_for_user.invalidate_all)
+            txn.call_after(self.get_users_in_room.invalidate, (room_id,))
+
+            # Add an entry to the current_state_resets table to record the point
+            # where we clobbered the current state
+            stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+            self._simple_insert_txn(
+                txn,
+                table="current_state_resets",
+                values={"event_stream_ordering": stream_order}
+            )
+
+            self._simple_delete_txn(
+                txn,
+                table="current_state_events",
+                keyvalues={"room_id": room_id},
+            )
+
+            self._simple_insert_many_txn(
+                txn,
+                table="current_state_events",
+                values=[
+                    {
+                        "event_id": ev_id,
+                        "room_id": room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                    }
+                    for key, ev_id in current_state.iteritems()
+                ],
+            )
+
         # Ensure that we don't have the same event twice.
         # Pick the earliest non-outlier if there is one, else the earliest one.
         new_events_and_contexts = OrderedDict()
@@ -798,29 +869,6 @@ class EventsStore(SQLBaseStore):
             # to update the current state table
             return
 
-        for event, _ in state_events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                # Outlier events shouldn't clobber the current state.
-                continue
-
-            txn.call_after(
-                self._get_current_state_for_key.invalidate,
-                (event.room_id, event.type, event.state_key,)
-            )
-
-            self._simple_upsert_txn(
-                txn,
-                "current_state_events",
-                keyvalues={
-                    "room_id": event.room_id,
-                    "type": event.type,
-                    "state_key": event.state_key,
-                },
-                values={
-                    "event_id": event.event_id,
-                }
-            )
-
         return
 
     def _add_to_cache(self, txn, events_and_contexts):
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 44e859b5d1..38fedfe690 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -60,7 +60,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
     @defer.inlineCallbacks
     def test_room_members(self):
-        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.replicate()
         yield self.check("get_rooms_for_user", (USER_ID,), [])
         yield self.check("get_users_in_room", (ROOM_ID,), [])
@@ -95,15 +95,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         )])
         yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
 
-        # Join the room clobbering the state.
-        # This should remove any evidence of the other user being in the room.
         yield self.persist(
             type="m.room.member", key=USER_ID, membership="join",
-            reset_state=[create]
         )
         yield self.replicate()
-        yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
-        yield self.check("get_rooms_for_user", (USER_ID_2,), [])
+        yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID])
 
     @defer.inlineCallbacks
     def test_get_latest_event_ids_in_room(self):
@@ -125,7 +121,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
     @defer.inlineCallbacks
     def test_get_current_state(self):
         # Create the room.
-        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.replicate()
         yield self.check(
             "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
@@ -151,22 +147,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             [join2]
         )
 
-        # Leave the room, then rejoin the room clobbering state.
-        yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
-        join3 = yield self.persist(
-            type="m.room.member", key=USER_ID, membership="join",
-            reset_state=[create]
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
-            []
-        )
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
-            [join3]
-        )
-
     @defer.inlineCallbacks
     def test_redactions(self):
         yield self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -283,6 +263,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         if depth is None:
             depth = self.event_id
 
+        if not prev_events:
+            latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
+                room_id
+            )
+            prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
+
         event_dict = {
             "sender": sender,
             "type": type,
@@ -309,12 +295,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             state_ids = {
                 key: e.event_id for key, e in state.items()
             }
+            context = EventContext()
+            context.current_state_ids = state_ids
+            context.prev_state_ids = state_ids
+        elif not backfill:
+            state_handler = self.hs.get_state_handler()
+            context = yield state_handler.compute_event_context(event)
         else:
-            state_ids = None
+            context = EventContext()
 
-        context = EventContext()
-        context.current_state_ids = state_ids
-        context.prev_state_ids = state_ids
         context.push_actions = push_actions
 
         ordering = None
@@ -324,7 +313,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             )
         else:
             ordering, _ = yield self.master_store.persist_event(
-                event, context, current_state=reset_state
+                event, context,
             )
 
         if ordering: