summary refs log tree commit diff
path: root/synapse/storage/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r--synapse/storage/events.py330
1 files changed, 172 insertions, 158 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index b486ca50eb..ac876287fc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -223,7 +223,7 @@ def _retry_on_integrity_error(func):
         except self.database_engine.module.IntegrityError:
             logger.exception("IntegrityError, retrying.")
             res = yield func(self, *args, delete_existing=True, **kwargs)
-        defer.returnValue(res)
+        return res
 
     return f
 
@@ -309,7 +309,7 @@ class EventsStore(
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
 
-        defer.returnValue(max_persisted_id)
+        return max_persisted_id
 
     @defer.inlineCallbacks
     @log_function
@@ -334,7 +334,7 @@ class EventsStore(
         yield make_deferred_yieldable(deferred)
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
-        defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
+        return (event.internal_metadata.stream_ordering, max_persisted_id)
 
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
@@ -364,147 +364,161 @@ class EventsStore(
         if not events_and_contexts:
             return
 
-        if backfilled:
-            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
-                len(events_and_contexts)
-            )
-        else:
-            stream_ordering_manager = self._stream_id_gen.get_next_mult(
-                len(events_and_contexts)
-            )
-
-        with stream_ordering_manager as stream_orderings:
-            for (event, context), stream in zip(events_and_contexts, stream_orderings):
-                event.internal_metadata.stream_ordering = stream
-
-            chunks = [
-                events_and_contexts[x : x + 100]
-                for x in range(0, len(events_and_contexts), 100)
-            ]
-
-            for chunk in chunks:
-                # We can't easily parallelize these since different chunks
-                # might contain the same event. :(
+        chunks = [
+            events_and_contexts[x : x + 100]
+            for x in range(0, len(events_and_contexts), 100)
+        ]
 
-                # NB: Assumes that we are only persisting events for one room
-                # at a time.
+        for chunk in chunks:
+            # We can't easily parallelize these since different chunks
+            # might contain the same event. :(
 
-                # map room_id->list[event_ids] giving the new forward
-                # extremities in each room
-                new_forward_extremeties = {}
+            # NB: Assumes that we are only persisting events for one room
+            # at a time.
 
-                # map room_id->(type,state_key)->event_id tracking the full
-                # state in each room after adding these events.
-                # This is simply used to prefill the get_current_state_ids
-                # cache
-                current_state_for_room = {}
+            # map room_id->list[event_ids] giving the new forward
+            # extremities in each room
+            new_forward_extremeties = {}
 
-                # map room_id->(to_delete, to_insert) where to_delete is a list
-                # of type/state keys to remove from current state, and to_insert
-                # is a map (type,key)->event_id giving the state delta in each
-                # room
-                state_delta_for_room = {}
+            # map room_id->(type,state_key)->event_id tracking the full
+            # state in each room after adding these events.
+            # This is simply used to prefill the get_current_state_ids
+            # cache
+            current_state_for_room = {}
 
-                if not backfilled:
-                    with Measure(self._clock, "_calculate_state_and_extrem"):
-                        # Work out the new "current state" for each room.
-                        # We do this by working out what the new extremities are and then
-                        # calculating the state from that.
-                        events_by_room = {}
-                        for event, context in chunk:
-                            events_by_room.setdefault(event.room_id, []).append(
-                                (event, context)
-                            )
+            # map room_id->(to_delete, to_insert) where to_delete is a list
+            # of type/state keys to remove from current state, and to_insert
+            # is a map (type,key)->event_id giving the state delta in each
+            # room
+            state_delta_for_room = {}
 
-                        for room_id, ev_ctx_rm in iteritems(events_by_room):
-                            latest_event_ids = yield self.get_latest_event_ids_in_room(
-                                room_id
-                            )
-                            new_latest_event_ids = yield self._calculate_new_extremities(
-                                room_id, ev_ctx_rm, latest_event_ids
+            if not backfilled:
+                with Measure(self._clock, "_calculate_state_and_extrem"):
+                    # Work out the new "current state" for each room.
+                    # We do this by working out what the new extremities are and then
+                    # calculating the state from that.
+                    events_by_room = {}
+                    for event, context in chunk:
+                        events_by_room.setdefault(event.room_id, []).append(
+                            (event, context)
+                        )
+
+                    for room_id, ev_ctx_rm in iteritems(events_by_room):
+                        latest_event_ids = yield self.get_latest_event_ids_in_room(
+                            room_id
+                        )
+                        new_latest_event_ids = yield self._calculate_new_extremities(
+                            room_id, ev_ctx_rm, latest_event_ids
+                        )
+
+                        latest_event_ids = set(latest_event_ids)
+                        if new_latest_event_ids == latest_event_ids:
+                            # No change in extremities, so no change in state
+                            continue
+
+                        # there should always be at least one forward extremity.
+                        # (except during the initial persistence of the send_join
+                        # results, in which case there will be no existing
+                        # extremities, so we'll `continue` above and skip this bit.)
+                        assert new_latest_event_ids, "No forward extremities left!"
+
+                        new_forward_extremeties[room_id] = new_latest_event_ids
+
+                        len_1 = (
+                            len(latest_event_ids) == 1
+                            and len(new_latest_event_ids) == 1
+                        )
+                        if len_1:
+                            all_single_prev_not_state = all(
+                                len(event.prev_event_ids()) == 1
+                                and not event.is_state()
+                                for event, ctx in ev_ctx_rm
                             )
-
-                            latest_event_ids = set(latest_event_ids)
-                            if new_latest_event_ids == latest_event_ids:
-                                # No change in extremities, so no change in state
+                            # Don't bother calculating state if they're just
+                            # a long chain of single ancestor non-state events.
+                            if all_single_prev_not_state:
                                 continue
 
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremeties[room_id] = new_latest_event_ids
-
-                            len_1 = (
-                                len(latest_event_ids) == 1
-                                and len(new_latest_event_ids) == 1
+                        state_delta_counter.inc()
+                        if len(new_latest_event_ids) == 1:
+                            state_delta_single_event_counter.inc()
+
+                            # This is a fairly handwavey check to see if we could
+                            # have guessed what the delta would have been when
+                            # processing one of these events.
+                            # What we're interested in is if the latest extremities
+                            # were the same when we created the event as they are
+                            # now. When this server creates a new event (as opposed
+                            # to receiving it over federation) it will use the
+                            # forward extremities as the prev_events, so we can
+                            # guess this by looking at the prev_events and checking
+                            # if they match the current forward extremities.
+                            for ev, _ in ev_ctx_rm:
+                                prev_event_ids = set(ev.prev_event_ids())
+                                if latest_event_ids == prev_event_ids:
+                                    state_delta_reuse_delta_counter.inc()
+                                    break
+
+                        logger.info("Calculating state delta for room %s", room_id)
+                        with Measure(
+                            self._clock, "persist_events.get_new_state_after_events"
+                        ):
+                            res = yield self._get_new_state_after_events(
+                                room_id,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
                             )
-                            if len_1:
-                                all_single_prev_not_state = all(
-                                    len(event.prev_event_ids()) == 1
-                                    and not event.is_state()
-                                    for event, ctx in ev_ctx_rm
-                                )
-                                # Don't bother calculating state if they're just
-                                # a long chain of single ancestor non-state events.
-                                if all_single_prev_not_state:
-                                    continue
-
-                            state_delta_counter.inc()
-                            if len(new_latest_event_ids) == 1:
-                                state_delta_single_event_counter.inc()
-
-                                # This is a fairly handwavey check to see if we could
-                                # have guessed what the delta would have been when
-                                # processing one of these events.
-                                # What we're interested in is if the latest extremities
-                                # were the same when we created the event as they are
-                                # now. When this server creates a new event (as opposed
-                                # to receiving it over federation) it will use the
-                                # forward extremities as the prev_events, so we can
-                                # guess this by looking at the prev_events and checking
-                                # if they match the current forward extremities.
-                                for ev, _ in ev_ctx_rm:
-                                    prev_event_ids = set(ev.prev_event_ids())
-                                    if latest_event_ids == prev_event_ids:
-                                        state_delta_reuse_delta_counter.inc()
-                                        break
-
-                            logger.info("Calculating state delta for room %s", room_id)
+                            current_state, delta_ids = res
+
+                        # If either are not None then there has been a change,
+                        # and we need to work out the delta (or use that
+                        # given)
+                        if delta_ids is not None:
+                            # If there is a delta we know that we've
+                            # only added or replaced state, never
+                            # removed keys entirely.
+                            state_delta_for_room[room_id] = ([], delta_ids)
+                        elif current_state is not None:
                             with Measure(
-                                self._clock, "persist_events.get_new_state_after_events"
+                                self._clock, "persist_events.calculate_state_delta"
                             ):
-                                res = yield self._get_new_state_after_events(
-                                    room_id,
-                                    ev_ctx_rm,
-                                    latest_event_ids,
-                                    new_latest_event_ids,
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state
                                 )
-                                current_state, delta_ids = res
-
-                            # If either are not None then there has been a change,
-                            # and we need to work out the delta (or use that
-                            # given)
-                            if delta_ids is not None:
-                                # If there is a delta we know that we've
-                                # only added or replaced state, never
-                                # removed keys entirely.
-                                state_delta_for_room[room_id] = ([], delta_ids)
-                            elif current_state is not None:
-                                with Measure(
-                                    self._clock, "persist_events.calculate_state_delta"
-                                ):
-                                    delta = yield self._calculate_state_delta(
-                                        room_id, current_state
-                                    )
-                                state_delta_for_room[room_id] = delta
-
-                            # If we have the current_state then lets prefill
-                            # the cache with it.
-                            if current_state is not None:
-                                current_state_for_room[room_id] = current_state
+                            state_delta_for_room[room_id] = delta
+
+                        # If we have the current_state then lets prefill
+                        # the cache with it.
+                        if current_state is not None:
+                            current_state_for_room[room_id] = current_state
+
+            # We want to calculate the stream orderings as late as possible, as
+            # we only notify after all events with a lesser stream ordering have
+            # been persisted. I.e. if we spend 10s inside the with block then
+            # that will delay all subsequent events from being notified about.
+            # Hence why we do it down here rather than wrapping the entire
+            # function.
+            #
+            # Its safe to do this after calculating the state deltas etc as we
+            # only need to protect the *persistence* of the events. This is to
+            # ensure that queries of the form "fetch events since X" don't
+            # return events and stream positions after events that are still in
+            # flight, as otherwise subsequent requests "fetch event since Y"
+            # will not return those events.
+            #
+            # Note: Multiple instances of this function cannot be in flight at
+            # the same time for the same room.
+            if backfilled:
+                stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                    len(chunk)
+                )
+            else:
+                stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk))
+
+            with stream_ordering_manager as stream_orderings:
+                for (event, context), stream in zip(chunk, stream_orderings):
+                    event.internal_metadata.stream_ordering = stream
 
                 yield self.runInteraction(
                     "persist_events",
@@ -595,7 +609,7 @@ class EventsStore(
             stale = latest_event_ids & result
             stale_forward_extremities_counter.observe(len(stale))
 
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def _get_events_which_are_prevs(self, event_ids):
@@ -633,7 +647,7 @@ class EventsStore(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
-        defer.returnValue(results)
+        return results
 
     @defer.inlineCallbacks
     def _get_prevs_before_rejected(self, event_ids):
@@ -695,7 +709,7 @@ class EventsStore(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )
 
-        defer.returnValue(existing_prevs)
+        return existing_prevs
 
     @defer.inlineCallbacks
     def _get_new_state_after_events(
@@ -796,7 +810,7 @@ class EventsStore(
         # If they old and new groups are the same then we don't need to do
         # anything.
         if old_state_groups == new_state_groups:
-            defer.returnValue((None, None))
+            return (None, None)
 
         if len(new_state_groups) == 1 and len(old_state_groups) == 1:
             # If we're going from one state group to another, lets check if
@@ -813,7 +827,7 @@ class EventsStore(
                 # the current state in memory then lets also return that,
                 # but it doesn't matter if we don't.
                 new_state = state_groups_map.get(new_state_group)
-                defer.returnValue((new_state, delta_ids))
+                return (new_state, delta_ids)
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
@@ -825,7 +839,7 @@ class EventsStore(
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            defer.returnValue((state_groups_map[new_state_groups.pop()], None))
+            return (state_groups_map[new_state_groups.pop()], None)
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
@@ -854,7 +868,7 @@ class EventsStore(
             state_res_store=StateResolutionStore(self),
         )
 
-        defer.returnValue((res.state, None))
+        return (res.state, None)
 
     @defer.inlineCallbacks
     def _calculate_state_delta(self, room_id, current_state):
@@ -877,7 +891,7 @@ class EventsStore(
             if ev_id != existing_state.get(key)
         }
 
-        defer.returnValue((to_delete, to_insert))
+        return (to_delete, to_insert)
 
     @log_function
     def _persist_events_txn(
@@ -918,8 +932,6 @@ class EventsStore(
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
-
         self._update_forward_extremities_txn(
             txn,
             new_forward_extremities=new_forward_extremeties,
@@ -993,6 +1005,10 @@ class EventsStore(
             backfilled=backfilled,
         )
 
+        # We call this last as it assumes we've inserted the events into
+        # room_memberships, where applicable.
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
     def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
         for room_id, current_state_tuple in iteritems(state_delta_by_room):
             to_delete, to_insert = current_state_tuple
@@ -1062,16 +1078,16 @@ class EventsStore(
                 ),
             )
 
-            self._simple_insert_many_txn(
-                txn,
-                table="current_state_events",
-                values=[
-                    {
-                        "event_id": ev_id,
-                        "room_id": room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                    }
+            # We include the membership in the current state table, hence we do
+            # a lookup when we insert. This assumes that all events have already
+            # been inserted into room_memberships.
+            txn.executemany(
+                """INSERT INTO current_state_events
+                    (room_id, type, state_key, event_id, membership)
+                VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+                """,
+                [
+                    (room_id, key[0], key[1], ev_id, ev_id)
                     for key, ev_id in iteritems(to_insert)
                 ],
             )
@@ -1562,7 +1578,7 @@ class EventsStore(
             return count
 
         ret = yield self.runInteraction("count_messages", _count_messages)
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def count_daily_sent_messages(self):
@@ -1583,7 +1599,7 @@ class EventsStore(
             return count
 
         ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def count_daily_active_rooms(self):
@@ -1598,7 +1614,7 @@ class EventsStore(
             return count
 
         ret = yield self.runInteraction("count_daily_active_rooms", _count)
-        defer.returnValue(ret)
+        return ret
 
     def get_current_backfill_token(self):
         """The current minimum token that backfilled events have reached"""
@@ -2181,7 +2197,7 @@ class EventsStore(
         """
         to_1, so_1 = yield self._get_event_ordering(event_id1)
         to_2, so_2 = yield self._get_event_ordering(event_id2)
-        defer.returnValue((to_1, so_1) > (to_2, so_2))
+        return (to_1, so_1) > (to_2, so_2)
 
     @cachedInlineCallbacks(max_entries=5000)
     def _get_event_ordering(self, event_id):
@@ -2195,9 +2211,7 @@ class EventsStore(
         if not res:
             raise SynapseError(404, "Could not find event %s" % (event_id,))
 
-        defer.returnValue(
-            (int(res["topological_ordering"]), int(res["stream_ordering"]))
-        )
+        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
 
     def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
         def get_all_updated_current_state_deltas_txn(txn):