summary refs log tree commit diff
path: root/synapse/storage/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r--synapse/storage/events.py402
1 files changed, 293 insertions, 109 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index db01eb6d14..3f6833fad2 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
 from functools import wraps
 
-import synapse
 import synapse.metrics
 
-
 import logging
 import math
 import ujson as json
 
+# these are only included to make the type annotations work
+from synapse.events import EventBase    # noqa: F401
+from synapse.events.snapshot import EventContext   # noqa: F401
+
 logger = logging.getLogger(__name__)
 
 
@@ -82,6 +84,11 @@ class _EventPeristenceQueue(object):
 
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
+
+        Args:
+            room_id (str):
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
@@ -210,14 +217,14 @@ class EventsStore(SQLBaseStore):
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
         deferreds = []
-        for room_id, evs_ctxs in partitioned.items():
+        for room_id, evs_ctxs in partitioned.iteritems():
             d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
             deferreds.append(d)
 
-        for room_id in partitioned.keys():
+        for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
         return preserve_context_over_deferred(
@@ -227,6 +234,17 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context, backfilled=False):
+        """
+
+        Args:
+            event (EventBase):
+            context (EventContext):
+            backfilled (bool):
+
+        Returns:
+            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            and the stream ordering of the latest persisted event
+        """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
@@ -253,6 +271,16 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     def _persist_events(self, events_and_contexts, backfilled=False,
                         delete_existing=False):
+        """Persist events to db
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+            delete_existing (bool):
+
+        Returns:
+            Deferred: resolves when the events have been persisted
+        """
         if not events_and_contexts:
             return
 
@@ -295,7 +323,7 @@ class EventsStore(SQLBaseStore):
                                 (event, context)
                             )
 
-                        for room_id, ev_ctx_rm in events_by_room.items():
+                        for room_id, ev_ctx_rm in events_by_room.iteritems():
                             # Work out new extremities by recursively adding and removing
                             # the new events.
                             latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -400,6 +428,7 @@ class EventsStore(SQLBaseStore):
         # Now we need to work out the different state sets for
         # each state extremities
         state_sets = []
+        state_groups = set()
         missing_event_ids = []
         was_updated = False
         for event_id in new_latest_event_ids:
@@ -409,9 +438,17 @@ class EventsStore(SQLBaseStore):
                 if event_id == ev.event_id:
                     if ctx.current_state_ids is None:
                         raise Exception("Unknown current state")
-                    state_sets.append(ctx.current_state_ids)
-                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                        was_updated = True
+
+                    # If we've already seen the state group don't bother adding
+                    # it to the state sets again
+                    if ctx.state_group not in state_groups:
+                        state_sets.append(ctx.current_state_ids)
+                        if ctx.delta_ids or hasattr(ev, "state_key"):
+                            was_updated = True
+                        if ctx.state_group:
+                            # Add this as a seen state group (if it has a state
+                            # group)
+                            state_groups.add(ctx.state_group)
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
@@ -425,31 +462,57 @@ class EventsStore(SQLBaseStore):
                 missing_event_ids,
             )
 
-            groups = set(event_to_groups.values())
-            group_to_state = yield self._get_state_for_groups(groups)
+            groups = set(event_to_groups.itervalues()) - state_groups
 
-            state_sets.extend(group_to_state.values())
+            if groups:
+                group_to_state = yield self._get_state_for_groups(groups)
+                state_sets.extend(group_to_state.itervalues())
 
         if not new_latest_event_ids:
             current_state = {}
         elif was_updated:
-            current_state = yield resolve_events(
-                state_sets,
-                state_map_factory=lambda ev_ids: self.get_events(
-                    ev_ids, get_prev_content=False, check_redacted=False,
-                ),
-            )
+            if len(state_sets) == 1:
+                # If there is only one state set, then we know what the current
+                # state is.
+                current_state = state_sets[0]
+            else:
+                # We work out the current state by passing the state sets to the
+                # state resolution algorithm. It may ask for some events, including
+                # the events we have yet to persist, so we need a slightly more
+                # complicated event lookup function than simply looking the events
+                # up in the db.
+                events_map = {ev.event_id: ev for ev, _ in events_context}
+
+                @defer.inlineCallbacks
+                def get_events(ev_ids):
+                    # We get the events by first looking at the list of events we
+                    # are trying to persist, and then fetching the rest from the DB.
+                    db = []
+                    to_return = {}
+                    for ev_id in ev_ids:
+                        ev = events_map.get(ev_id, None)
+                        if ev:
+                            to_return[ev_id] = ev
+                        else:
+                            db.append(ev_id)
+
+                    if db:
+                        evs = yield self.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
+                        )
+                        to_return.update(evs)
+                    defer.returnValue(to_return)
+
+                current_state = yield resolve_events(
+                    state_sets,
+                    state_map_factory=get_events,
+                )
         else:
             return
 
-        existing_state_rows = yield self._simple_select_list(
-            table="current_state_events",
-            keyvalues={"room_id": room_id},
-            retcols=["event_id", "type", "state_key"],
-            desc="_calculate_state_delta",
-        )
+        existing_state = yield self.get_current_state_ids(room_id)
 
-        existing_events = set(row["event_id"] for row in existing_state_rows)
+        existing_events = set(existing_state.itervalues())
         new_events = set(ev_id for ev_id in current_state.itervalues())
         changed_events = existing_events ^ new_events
 
@@ -457,9 +520,8 @@ class EventsStore(SQLBaseStore):
             return
 
         to_delete = {
-            (row["type"], row["state_key"]): row["event_id"]
-            for row in existing_state_rows
-            if row["event_id"] in changed_events
+            key: ev_id for key, ev_id in existing_state.iteritems()
+            if ev_id in changed_events
         }
         events_to_insert = (new_events - existing_events)
         to_insert = {
@@ -535,11 +597,91 @@ class EventsStore(SQLBaseStore):
         and the rejections table. Things reading from those table will need to check
         whether the event was rejected.
 
-        If delete_existing is True then existing events will be purged from the
-        database before insertion. This is useful when retrying due to IntegrityError.
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]):
+                events to persist
+            backfilled (bool): True if the events were backfilled
+            delete_existing (bool): True to purge existing table rows for the
+                events from the database. This is useful when retrying due to
+                IntegrityError.
+            current_state_for_room (dict[str, (list[str], list[str])]):
+                The current-state delta for each room. For each room, a tuple
+                (to_delete, to_insert), being a list of event ids to be removed
+                from the current state, and a list of event ids to be added to
+                the current state.
+            new_forward_extremeties (dict[str, list[str]]):
+                The new forward extremities for each room. For each room, a
+                list of the event ids which are the forward extremities.
+
         """
+        self._update_current_state_txn(txn, current_state_for_room)
+
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-        for room_id, current_state_tuple in current_state_for_room.iteritems():
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremeties,
+            max_stream_order=max_stream_order,
+        )
+
+        # Ensure that we don't have the same event twice.
+        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+            events_and_contexts,
+        )
+
+        self._update_room_depths_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # _update_outliers_txn filters out any events which have already been
+        # persisted, and returns the filtered list.
+        events_and_contexts = self._update_outliers_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only events that we haven't
+        # seen before.
+
+        if delete_existing:
+            # For paranoia reasons, we go and delete all the existing entries
+            # for these events so we can reinsert them.
+            # This gets around any problems with some tables already having
+            # entries.
+            self._delete_existing_rows_txn(
+                txn,
+                events_and_contexts=events_and_contexts,
+            )
+
+        self._store_event_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # Insert into the state_groups, state_groups_state, and
+        # event_to_state_groups tables.
+        self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+    def _update_current_state_txn(self, txn, state_delta_by_room):
+        for room_id, current_state_tuple in state_delta_by_room.iteritems():
                 to_delete, to_insert = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
@@ -585,7 +727,13 @@ class EventsStore(SQLBaseStore):
                     txn, self.get_users_in_room, (room_id,)
                 )
 
-        for room_id, new_extrem in new_forward_extremeties.items():
+                self._invalidate_cache_and_stream(
+                    txn, self.get_current_state_ids, (room_id,)
+                )
+
+    def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+                                        max_stream_order):
+        for room_id, new_extrem in new_forward_extremities.iteritems():
             self._simple_delete_txn(
                 txn,
                 table="event_forward_extremities",
@@ -603,7 +751,7 @@ class EventsStore(SQLBaseStore):
                     "event_id": ev_id,
                     "room_id": room_id,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for ev_id in new_extrem
             ],
         )
@@ -620,13 +768,22 @@ class EventsStore(SQLBaseStore):
                     "event_id": event_id,
                     "stream_ordering": max_stream_order,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for event_id in new_extrem
             ]
         )
 
-        # Ensure that we don't have the same event twice.
-        # Pick the earliest non-outlier if there is one, else the earliest one.
+    @classmethod
+    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+        """Ensure that we don't have the same event twice.
+
+        Pick the earliest non-outlier if there is one, else the earliest one.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+        Returns:
+            list[(EventBase, EventContext)]: filtered list
+        """
         new_events_and_contexts = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -639,9 +796,17 @@ class EventsStore(SQLBaseStore):
                         new_events_and_contexts[event.event_id] = (event, context)
             else:
                 new_events_and_contexts[event.event_id] = (event, context)
+        return new_events_and_contexts.values()
 
-        events_and_contexts = new_events_and_contexts.values()
+    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+        """Update min_depth for each room
 
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -657,9 +822,24 @@ class EventsStore(SQLBaseStore):
                     event.depth, depth_updates.get(event.room_id, event.depth)
                 )
 
-        for room_id, depth in depth_updates.items():
+        for room_id, depth in depth_updates.iteritems():
             self._update_min_depth_for_room_txn(txn, room_id, depth)
 
+    def _update_outliers_txn(self, txn, events_and_contexts):
+        """Update any outliers with new event info.
+
+        This turns outliers into ex-outliers (unless the new event was
+        rejected).
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without events which
+            are already in the events table.
+        """
         txn.execute(
             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
                 ",".join(["?"] * len(events_and_contexts)),
@@ -669,24 +849,21 @@ class EventsStore(SQLBaseStore):
 
         have_persisted = {
             event_id: outlier
-            for event_id, outlier in txn.fetchall()
+            for event_id, outlier in txn
         }
 
         to_remove = set()
         for event, context in events_and_contexts:
-            if context.rejected:
-                # If the event is rejected then we don't care if the event
-                # was an outlier or not.
-                if event.event_id in have_persisted:
-                    # If we have already seen the event then ignore it.
-                    to_remove.add(event)
-                continue
-
             if event.event_id not in have_persisted:
                 continue
 
             to_remove.add(event)
 
+            if context.rejected:
+                # If the event is rejected then we don't care if the event
+                # was an outlier or not.
+                continue
+
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
@@ -741,37 +918,19 @@ class EventsStore(SQLBaseStore):
                 # event isn't an outlier any more.
                 self._update_backward_extremeties(txn, [event])
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    @classmethod
+    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only events that we haven't
-        # seen before.
-
-        def event_dict(event):
-            return {
-                k: v
-                for k, v in event.get_dict().items()
-                if k not in [
-                    "redacted",
-                    "redacted_because",
-                ]
-            }
-
-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
+        logger.info("Deleting existing")
 
-            logger.info("Deleting existing")
-
-            for table in (
+        for table in (
                 "events",
                 "event_auth",
                 "event_json",
@@ -794,11 +953,30 @@ class EventsStore(SQLBaseStore):
                 "redactions",
                 "room_memberships",
                 "topics"
-            ):
-                txn.executemany(
-                    "DELETE FROM %s WHERE event_id = ?" % (table,),
-                    [(ev.event_id,) for ev, _ in events_and_contexts]
-                )
+        ):
+            txn.executemany(
+                "DELETE FROM %s WHERE event_id = ?" % (table,),
+                [(ev.event_id,) for ev, _ in events_and_contexts]
+            )
+
+    def _store_event_txn(self, txn, events_and_contexts):
+        """Insert new events into the event and event_json tables
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+        """
+
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        def event_dict(event):
+            d = event.get_dict()
+            d.pop("redacted", None)
+            d.pop("redacted_because", None)
+            return d
 
         self._simple_insert_many_txn(
             txn,
@@ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore):
             ],
         )
 
+    def _store_rejected_events_txn(self, txn, events_and_contexts):
+        """Add rows to the 'rejections' table for received events which were
+        rejected
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without the rejected
+                events.
+        """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
         to_remove = set()
@@ -853,17 +1044,24 @@ class EventsStore(SQLBaseStore):
                 )
                 to_remove.add(event)
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+        """Update all the miscellaneous tables for new events
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
+
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only ones that weren't rejected.
-
         for event, context in events_and_contexts:
             # Insert all the push actions into the event_push_actions table.
             if context.push_actions:
@@ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore):
             ],
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
-
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
@@ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore):
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-        if backfilled:
-            # Backfilled events come before the current state so we don't need
-            # to update the current state table
-            return
-
-        return
-
     def _add_to_cache(self, txn, events_and_contexts):
         to_prefill = []
 
@@ -1597,14 +1784,13 @@ class EventsStore(SQLBaseStore):
 
         def get_all_new_events_txn(txn):
             sql = (
-                "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
             if have_forward_events:
@@ -1630,15 +1816,13 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = []
 
             sql = (
-                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
-                " eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
-                " ORDER BY e.stream_ordering DESC"
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
                 " LIMIT ?"
             )
             if have_backfill_events:
@@ -1825,7 +2009,7 @@ class EventsStore(SQLBaseStore):
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in curr_state.items()
+                    for key, state_id in curr_state.iteritems()
                 ],
             )