summary refs log tree commit diff
path: root/synapse/storage/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r--synapse/storage/events.py307
1 files changed, 215 insertions, 92 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 9fc65229fd..b96104ccae 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -14,52 +14,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.events_worker import EventsWorkerStore
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
+import itertools
+import logging
 
+import simplejson as json
 from twisted.internet import defer
 
-from synapse.events import USE_FROZEN_DICTS
-
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.util.async import ObservableDeferred
+from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.logcontext import (
-    PreserveLoggingContext, make_deferred_yieldable
+    PreserveLoggingContext, make_deferred_yieldable,
 )
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.types import get_domain_from_id
-
-from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple, OrderedDict
-from functools import wraps
-
+from synapse.types import get_domain_from_id, RoomStreamToken
 import synapse.metrics
 
-import logging
-import simplejson as json
-
 # these are only included to make the type annotations work
 from synapse.events import EventBase    # noqa: F401
 from synapse.events.snapshot import EventContext   # noqa: F401
 
+from prometheus_client import Counter
+
 logger = logging.getLogger(__name__)
 
+persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
+event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
+                        ["type", "origin_type", "origin_entity"])
 
-metrics = synapse.metrics.get_metrics_for(__name__)
-persist_event_counter = metrics.register_counter("persisted_events")
-event_counter = metrics.register_counter(
-    "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
-)
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+    "synapse_storage_events_state_delta_single_event", "")
+
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+    "synapse_storage_events_state_delta_reuse_delta", "")
 
 
 def encode_json(json_object):
-    if USE_FROZEN_DICTS:
-        # ujson doesn't like frozen_dicts
-        return encode_canonical_json(json_object)
-    else:
-        return json.dumps(json_object, ensure_ascii=False)
+    return frozendict_json_encoder.encode(json_object)
 
 
 class _EventPeristenceQueue(object):
@@ -369,7 +374,8 @@ class EventsStore(EventsWorkerStore):
                                 room_id, ev_ctx_rm, latest_event_ids
                             )
 
-                            if new_latest_event_ids == set(latest_event_ids):
+                            latest_event_ids = set(latest_event_ids)
+                            if new_latest_event_ids == latest_event_ids:
                                 # No change in extremities, so no change in state
                                 continue
 
@@ -390,12 +396,34 @@ class EventsStore(EventsWorkerStore):
                                 if all_single_prev_not_state:
                                     continue
 
+                            state_delta_counter.inc()
+                            if len(new_latest_event_ids) == 1:
+                                state_delta_single_event_counter.inc()
+
+                                # This is a fairly handwavey check to see if we could
+                                # have guessed what the delta would have been when
+                                # processing one of these events.
+                                # What we're interested in is if the latest extremities
+                                # were the same when we created the event as they are
+                                # now. When this server creates a new event (as opposed
+                                # to receiving it over federation) it will use the
+                                # forward extremities as the prev_events, so we can
+                                # guess this by looking at the prev_events and checking
+                                # if they match the current forward extremities.
+                                for ev, _ in ev_ctx_rm:
+                                    prev_event_ids = set(e for e, _ in ev.prev_events)
+                                    if latest_event_ids == prev_event_ids:
+                                        state_delta_reuse_delta_counter.inc()
+                                        break
+
                             logger.info(
                                 "Calculating state delta for room %s", room_id,
                             )
                             current_state = yield self._get_new_state_after_events(
                                 room_id,
-                                ev_ctx_rm, new_latest_event_ids,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
                             )
                             if current_state is not None:
                                 current_state_for_room[room_id] = current_state
@@ -414,7 +442,10 @@ class EventsStore(EventsWorkerStore):
                     state_delta_for_room=state_delta_for_room,
                     new_forward_extremeties=new_forward_extremeties,
                 )
-                persist_event_counter.inc_by(len(chunk))
+                persist_event_counter.inc(len(chunk))
+                synapse.metrics.event_persisted_position.set(
+                    chunk[-1][0].internal_metadata.stream_ordering,
+                )
                 for event, context in chunk:
                     if context.app_service:
                         origin_type = "local"
@@ -426,7 +457,7 @@ class EventsStore(EventsWorkerStore):
                         origin_type = "remote"
                         origin_entity = get_domain_from_id(event.sender)
 
-                    event_counter.inc(event.type, origin_type, origin_entity)
+                    event_counter.labels(event.type, origin_type, origin_entity).inc()
 
                 for room_id, new_state in current_state_for_room.iteritems():
                     self.get_current_state_ids.prefill(
@@ -480,7 +511,8 @@ class EventsStore(EventsWorkerStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+                                    new_latest_event_ids):
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -491,6 +523,9 @@ class EventsStore(EventsWorkerStore):
             events_context (list[(EventBase, EventContext)]):
                 events and contexts which are being added to the room
 
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
             new_latest_event_ids (iterable[str]):
                 the new forward extremities for the room.
 
@@ -501,64 +536,89 @@ class EventsStore(EventsWorkerStore):
         """
 
         if not new_latest_event_ids:
-            defer.returnValue({})
+            return
 
         # map from state_group to ((type, key) -> event_id) state map
-        state_groups = {}
-        missing_event_ids = []
-        was_updated = False
+        state_groups_map = {}
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # I don't think this can happen, but let's double-check
+                raise Exception(
+                    "Context for new extremity event %s has no state "
+                    "group" % (ev.event_id, ),
+                )
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
         for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding,
-            # and then use the current state from that
+            # First search in the list of new events we're adding.
             for ev, ctx in events_context:
                 if event_id == ev.event_id:
-                    if ctx.current_state_ids is None:
-                        raise Exception("Unknown current state")
-
-                    if ctx.state_group is None:
-                        # I don't think this can happen, but let's double-check
-                        raise Exception(
-                            "Context for new extremity event %s has no state "
-                            "group" % (event_id, ),
-                        )
-
-                    # If we've already seen the state group don't bother adding
-                    # it to the state sets again
-                    if ctx.state_group not in state_groups:
-                        state_groups[ctx.state_group] = ctx.current_state_ids
-                        if ctx.delta_ids or hasattr(ev, "state_key"):
-                            was_updated = True
+                    event_id_to_state_group[event_id] = ctx.state_group
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
                 # the state from the database
-                was_updated = True
-                missing_event_ids.append(event_id)
-
-        if not was_updated:
-            return
+                missing_event_ids.add(event_id)
 
         if missing_event_ids:
-            # Now pull out the state for any missing events from DB
+            # Now pull out the state groups for any missing events from DB
             event_to_groups = yield self._get_state_group_for_events(
                 missing_event_ids,
             )
+            event_id_to_state_group.update(event_to_groups)
 
-            groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+        # State groups of old_latest_event_ids
+        old_state_groups = set(
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        )
 
-            if groups:
-                group_to_state = yield self._get_state_for_groups(groups)
-                state_groups.update(group_to_state)
+        # State groups of new_latest_event_ids
+        new_state_groups = set(
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        )
 
-        if len(state_groups) == 1:
+        # If they old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
+            return
+
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = yield self._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
+
+        if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            defer.returnValue(state_groups.values()[0])
+            defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
 
         def get_events(ev_ids):
             return self.get_events(
                 ev_ids, get_prev_content=False, check_redacted=False,
             )
+
+        state_groups = {
+            sg: state_groups_map[sg] for sg in new_state_groups
+        }
+
         events_map = {ev.event_id: ev for ev, _ in events_context}
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
@@ -1288,13 +1348,49 @@ class EventsStore(EventsWorkerStore):
 
         defer.returnValue(set(r["event_id"] for r in rows))
 
-    def have_events(self, event_ids):
+    @defer.inlineCallbacks
+    def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
+        Args:
+            event_ids (iterable[str]):
+
+        Returns:
+            Deferred[set[str]]: The events we have already seen.
+        """
+        results = set()
+
+        def have_seen_events_txn(txn, chunk):
+            sql = (
+                "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+                % (",".join("?" * len(chunk)), )
+            )
+            txn.execute(sql, chunk)
+            for (event_id, ) in txn:
+                results.add(event_id)
+
+        # break the input up into chunks of 100
+        input_iterator = iter(event_ids)
+        for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+                          []):
+            yield self.runInteraction(
+                "have_seen_events",
+                have_seen_events_txn,
+                chunk,
+            )
+        defer.returnValue(results)
+
+    def get_seen_events_with_rejections(self, event_ids):
+        """Given a list of event ids, check if we rejected them.
+
+        Args:
+            event_ids (list[str])
+
         Returns:
-            dict: Has an entry for each event id we already have seen. Maps to
-            the rejected reason string if we rejected the event, else maps to
-            None.
+            Deferred[dict[str, str|None):
+                Has an entry for each event id we already have seen. Maps to
+                the rejected reason string if we rejected the event, else maps
+                to None.
         """
         if not event_ids:
             return defer.succeed({})
@@ -1316,9 +1412,7 @@ class EventsStore(EventsWorkerStore):
 
             return res
 
-        return self.runInteraction(
-            "have_events", f,
-        )
+        return self.runInteraction("get_rejection_reasons", f)
 
     @defer.inlineCallbacks
     def count_daily_messages(self):
@@ -1706,15 +1800,14 @@ class EventsStore(EventsWorkerStore):
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
     def purge_history(
-        self, room_id, topological_ordering, delete_local_events,
+        self, room_id, token, delete_local_events,
     ):
         """Deletes room history before a certain point
 
         Args:
             room_id (str):
 
-            topological_ordering (int):
-                minimum topo ordering to preserve
+            token (str): A topological token to delete events before
 
             delete_local_events (bool):
                 if True, we will delete local events as well as remote ones
@@ -1724,13 +1817,15 @@ class EventsStore(EventsWorkerStore):
 
         return self.runInteraction(
             "purge_history",
-            self._purge_history_txn, room_id, topological_ordering,
+            self._purge_history_txn, room_id, token,
             delete_local_events,
         )
 
     def _purge_history_txn(
-        self, txn, room_id, topological_ordering, delete_local_events,
+        self, txn, room_id, token_str, delete_local_events,
     ):
+        token = RoomStreamToken.parse(token_str)
+
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -1775,6 +1870,13 @@ class EventsStore(EventsWorkerStore):
             " ON events_to_purge(should_delete)",
         )
 
+        # We do joins against events_to_purge for e.g. calculating state
+        # groups to purge, etc., so lets make an index.
+        txn.execute(
+            "CREATE INDEX events_to_purge_id"
+            " ON events_to_purge(event_id)",
+        )
+
         # First ensure that we're not about to delete all the forward extremeties
         txn.execute(
             "SELECT e.event_id, e.depth FROM events as e "
@@ -1787,7 +1889,7 @@ class EventsStore(EventsWorkerStore):
         rows = txn.fetchall()
         max_depth = max(row[0] for row in rows)
 
-        if max_depth <= topological_ordering:
+        if max_depth <= token.topological:
             # We need to ensure we don't delete all the events from the datanase
             # otherwise we wouldn't be able to send any events (due to not
             # having any backwards extremeties)
@@ -1803,7 +1905,7 @@ class EventsStore(EventsWorkerStore):
             should_delete_expr += " AND event_id NOT LIKE ?"
             should_delete_params += ("%:" + self.hs.hostname, )
 
-        should_delete_params += (room_id, topological_ordering)
+        should_delete_params += (room_id, token.topological)
 
         txn.execute(
             "INSERT INTO events_to_purge"
@@ -1826,13 +1928,13 @@ class EventsStore(EventsWorkerStore):
         logger.info("[purge] Finding new backward extremities")
 
         # We calculate the new entries for the backward extremeties by finding
-        # all events that point to events that are to be purged
+        # events to be purged that are pointed to by events we're not going to
+        # purge.
         txn.execute(
             "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
             " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
-            " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
-            " WHERE e2.topological_ordering >= ?",
-            (topological_ordering, )
+            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+            " WHERE ep2.event_id IS NULL",
         )
         new_backwards_extrems = txn.fetchall()
 
@@ -1856,16 +1958,22 @@ class EventsStore(EventsWorkerStore):
 
         # Get all state groups that are only referenced by events that are
         # to be deleted.
-        txn.execute(
-            "SELECT state_group FROM event_to_state_groups"
-            " INNER JOIN events USING (event_id)"
-            " WHERE state_group IN ("
-            "   SELECT DISTINCT state_group FROM events_to_purge"
-            "   INNER JOIN event_to_state_groups USING (event_id)"
-            " )"
-            " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
-            (topological_ordering, )
-        )
+        # This works by first getting state groups that we may want to delete,
+        # joining against event_to_state_groups to get events that use that
+        # state group, then left joining against events_to_purge again. Any
+        # state group where the left join produce *no nulls* are referenced
+        # only by events that are going to be purged.
+        txn.execute("""
+            SELECT state_group FROM
+            (
+                SELECT DISTINCT state_group FROM events_to_purge
+                INNER JOIN event_to_state_groups USING (event_id)
+            ) AS sp
+            INNER JOIN event_to_state_groups USING (state_group)
+            LEFT JOIN events_to_purge AS ep USING (event_id)
+            GROUP BY state_group
+            HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0
+        """)
 
         state_rows = txn.fetchall()
         logger.info("[purge] found %i redundant state groups", len(state_rows))
@@ -2012,10 +2120,25 @@ class EventsStore(EventsWorkerStore):
         #
         # So, let's stick it at the end so that we don't block event
         # persistence.
-        logger.info("[purge] updating room_depth")
+        #
+        # We do this by calculating the minimum depth of the backwards
+        # extremities. However, the events in event_backward_extremities
+        # are ones we don't have yet so we need to look at the events that
+        # point to it via event_edges table.
+        txn.execute("""
+            SELECT COALESCE(MIN(depth), 0)
+            FROM event_backward_extremities AS eb
+            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+            INNER JOIN events AS e ON e.event_id = eg.event_id
+            WHERE eb.room_id = ?
+        """, (room_id,))
+        min_depth, = txn.fetchone()
+
+        logger.info("[purge] updating room_depth to %d", min_depth)
+
         txn.execute(
             "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (topological_ordering, room_id,)
+            (min_depth, room_id,)
         )
 
         # finally, drop the temp table. this will commit the txn in sqlite,