summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/snapshot.py45
-rw-r--r--synapse/handlers/message.py12
-rw-r--r--synapse/push/action_generator.py6
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py15
-rw-r--r--synapse/python_dependencies.py2
-rw-r--r--synapse/replication/http/send_event.py15
-rw-r--r--synapse/storage/background_updates.py19
-rw-r--r--synapse/storage/event_push_actions.py99
-rw-r--r--synapse/storage/events.py126
-rw-r--r--synapse/storage/registration.py7
-rw-r--r--synapse/storage/schema/delta/38/postgres_fts_gist.sql6
-rw-r--r--synapse/storage/schema/delta/47/postgres_fts_gin.sql17
-rw-r--r--synapse/storage/schema/delta/47/push_actions_staging.sql28
-rw-r--r--synapse/storage/search.py81
-rw-r--r--synapse/storage/state.py96
-rw-r--r--tests/replication/slave/storage/test_events.py5
-rw-r--r--tests/storage/test_event_push_actions.py10
17 files changed, 436 insertions, 153 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7b80444f73..8e684d91b5 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
 
 from frozendict import frozendict
 
@@ -51,7 +52,6 @@ class EventContext(object):
         "prev_state_ids",
         "state_group",
         "rejected",
-        "push_actions",
         "prev_group",
         "delta_ids",
         "prev_state_events",
@@ -66,7 +66,6 @@ class EventContext(object):
         self.state_group = None
 
         self.rejected = False
-        self.push_actions = []
 
         # A previously persisted state group and a delta between that
         # and this state.
@@ -77,19 +76,32 @@ class EventContext(object):
 
         self.app_service = None
 
-    def serialize(self):
+    def serialize(self, event):
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
 
+        Args:
+            event (FrozenEvent): The event that this context relates to
+
         Returns:
             dict
         """
+
+        # We don't serialize the full state dicts, instead they get pulled out
+        # of the DB on the other side. However, the other side can't figure out
+        # the prev_state_ids, so if we're a state event we include the event
+        # id that we replaced in the state.
+        if event.is_state():
+            prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
+        else:
+            prev_state_id = None
+
         return {
-            "current_state_ids": _encode_state_dict(self.current_state_ids),
-            "prev_state_ids": _encode_state_dict(self.prev_state_ids),
+            "prev_state_id": prev_state_id,
+            "event_type": event.type,
+            "event_state_key": event.state_key if event.is_state() else None,
             "state_group": self.state_group,
             "rejected": self.rejected,
-            "push_actions": self.push_actions,
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
             "prev_state_events": self.prev_state_events,
@@ -97,6 +109,7 @@ class EventContext(object):
         }
 
     @staticmethod
+    @defer.inlineCallbacks
     def deserialize(store, input):
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
@@ -109,20 +122,32 @@ class EventContext(object):
             EventContext
         """
         context = EventContext()
-        context.current_state_ids = _decode_state_dict(input["current_state_ids"])
-        context.prev_state_ids = _decode_state_dict(input["prev_state_ids"])
         context.state_group = input["state_group"]
         context.rejected = input["rejected"]
-        context.push_actions = input["push_actions"]
         context.prev_group = input["prev_group"]
         context.delta_ids = _decode_state_dict(input["delta_ids"])
         context.prev_state_events = input["prev_state_events"]
 
+        # We use the state_group and prev_state_id stuff to pull the
+        # current_state_ids out of the DB and construct prev_state_ids.
+        prev_state_id = input["prev_state_id"]
+        event_type = input["event_type"]
+        event_state_key = input["event_state_key"]
+
+        context.current_state_ids = yield store.get_state_ids_for_group(
+            context.state_group,
+        )
+        if prev_state_id and event_state_key:
+            context.prev_state_ids = dict(context.current_state_ids)
+            context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
+        else:
+            context.prev_state_ids = context.current_state_ids
+
         app_service_id = input["app_service_id"]
         if app_service_id:
             context.app_service = store.get_app_service_by_id(app_service_id)
 
-        return context
+        defer.returnValue(context)
 
 
 def _encode_state_dict(state_dict):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1c3ac03f20..d99d8049b3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -683,9 +683,15 @@ class EventCreationHandler(object):
             event, context
         )
 
-        (event_stream_id, max_stream_id) = yield self.store.persist_event(
-            event, context=context
-        )
+        try:
+            (event_stream_id, max_stream_id) = yield self.store.persist_event(
+                event, context=context
+            )
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            # Ensure that we actually remove the entries in the push actions
+            # staging area
+            preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
+            raise
 
         # this intentionally does not yield: we don't care about the result
         # and don't need to wait for it.
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index fe09d50d55..8f619a7a1b 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,10 +40,6 @@ class ActionGenerator(object):
     @defer.inlineCallbacks
     def handle_push_actions_for_event(self, event, context):
         with Measure(self.clock, "action_for_event_by_user"):
-            actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
+            yield self.bulk_evaluator.action_for_event_by_user(
                 event, context
             )
-
-        context.push_actions = [
-            (uid, actions) for uid, actions in actions_by_user.iteritems()
-        ]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 425a017bdf..bf4f1c5836 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -137,14 +137,13 @@ class BulkPushRuleEvaluator(object):
 
     @defer.inlineCallbacks
     def action_for_event_by_user(self, event, context):
-        """Given an event and context, evaluate the push rules and return
-        the results
+        """Given an event and context, evaluate the push rules and insert the
+        results into the event_push_actions_staging table.
 
         Returns:
-            dict of user_id -> action
+            Deferred
         """
         rules_by_user = yield self._get_rules_for_event(event, context)
-        actions_by_user = {}
 
         room_members = yield self.store.get_joined_users_from_context(
             event, context
@@ -190,9 +189,13 @@ class BulkPushRuleEvaluator(object):
                 if matches:
                     actions = [x for x in rule['actions'] if x != 'dont_notify']
                     if actions and 'notify' in actions:
-                        actions_by_user[uid] = actions
+                        # Push rules say we should notify the user of this event,
+                        # so we mark it in the DB in the staging area. (This
+                        # will then get handled when we persist the event)
+                        yield self.store.add_push_actions_to_staging(
+                            event.event_id, uid, actions,
+                        )
                     break
-        defer.returnValue(actions_by_user)
 
 
 def _condition_checker(evaluator, conditions, uid, display_name, cache):
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 97b631e60d..5d65b5fd6e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -24,7 +24,7 @@ REQUIREMENTS = {
     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
     "signedjson>=1.0.0": ["signedjson>=1.0.0"],
-    "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
+    "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "Twisted>=16.0.0": ["twisted>=16.0.0"],
     "pyopenssl>=0.14": ["OpenSSL>=0.14"],
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index ff9b9d2f10..468f4b68f4 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -15,6 +15,7 @@
 
 from twisted.internet import defer
 
+from synapse.api.errors import SynapseError, MatrixCodeMessageException
 from synapse.events import FrozenEvent
 from synapse.events.snapshot import EventContext
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -27,6 +28,7 @@ import re
 logger = logging.getLogger(__name__)
 
 
+@defer.inlineCallbacks
 def send_event_to_master(client, host, port, requester, event, context):
     """Send event to be handled on the master
 
@@ -44,11 +46,18 @@ def send_event_to_master(client, host, port, requester, event, context):
         "event": event.get_pdu_json(),
         "internal_metadata": event.internal_metadata.get_dict(),
         "rejected_reason": event.rejected_reason,
-        "context": context.serialize(),
+        "context": context.serialize(event),
         "requester": requester.serialize(),
     }
 
-    return client.post_json_get_json(uri, payload)
+    try:
+        result = yield client.post_json_get_json(uri, payload)
+    except MatrixCodeMessageException as e:
+        # We convert to SynapseError as we know that it was a SynapseError
+        # on the master process that we should send to the client. (And
+        # importantly, not stack traces everywhere)
+        raise SynapseError(e.code, e.msg, e.errcode)
+    defer.returnValue(result)
 
 
 class ReplicationSendEventRestServlet(RestServlet):
@@ -87,7 +96,7 @@ class ReplicationSendEventRestServlet(RestServlet):
             event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
 
             requester = Requester.deserialize(self.store, content["requester"])
-            context = EventContext.deserialize(self.store, content["context"])
+            context = yield EventContext.deserialize(self.store, content["context"])
 
         if requester.user:
             request.authenticated_entity = requester.user.to_string()
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 11a1b942f1..c88759bf2c 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -242,6 +242,25 @@ class BackgroundUpdateStore(SQLBaseStore):
         """
         self._background_update_handlers[update_name] = update_handler
 
+    def register_noop_background_update(self, update_name):
+        """Register a noop handler for a background update.
+
+        This is useful when we previously did a background update, but no
+        longer wish to do the update. In this case the background update should
+        be removed from the schema delta files, but there may still be some
+        users who have the background update queued, so this method should
+        also be called to clear the update.
+
+        Args:
+            update_name (str): Name of update
+        """
+        @defer.inlineCallbacks
+        def noop_update(progress, batch_size):
+            yield self._end_background_update(update_name)
+            defer.returnValue(1)
+
+        self.register_background_update_handler(update_name, noop_update)
+
     def register_background_index_update(self, update_name, index_name,
                                          table, columns, where_clause=None,
                                          unique=False,
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 8efe2fd4bb..f787431b7a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -88,33 +88,50 @@ class EventPushActionsStore(SQLBaseStore):
             self._rotate_notifs, 30 * 60 * 1000
         )
 
-    def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
+    def _set_push_actions_for_event_and_users_txn(self, txn, event):
         """
         Args:
             event: the event set actions for
             tuples: list of tuples of (user_id, actions)
         """
-        values = []
-        for uid, actions in tuples:
-            is_highlight = 1 if _action_has_highlight(actions) else 0
-
-            values.append({
-                'room_id': event.room_id,
-                'event_id': event.event_id,
-                'user_id': uid,
-                'actions': _serialize_action(actions, is_highlight),
-                'stream_ordering': event.internal_metadata.stream_ordering,
-                'topological_ordering': event.depth,
-                'notif': 1,
-                'highlight': is_highlight,
-            })
-
-        for uid, __ in tuples:
+
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
+            )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
+
+        txn.execute(sql, (
+            event.room_id, event.internal_metadata.stream_ordering,
+            event.depth, event.event_id,
+        ))
+
+        user_ids = self._simple_select_onecol_txn(
+            txn,
+            table="event_push_actions_staging",
+            keyvalues={
+                "event_id": event.event_id,
+            },
+            retcol="user_id",
+        )
+
+        self._simple_delete_txn(
+            txn,
+            table="event_push_actions_staging",
+            keyvalues={
+                "event_id": event.event_id,
+            },
+        )
+
+        for uid in user_ids:
             txn.call_after(
                 self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                (event.room_id, uid)
+                (event.room_id, uid,)
             )
-        self._simple_insert_many_txn(txn, "event_push_actions", values)
 
     @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
@@ -738,6 +755,50 @@ class EventPushActionsStore(SQLBaseStore):
             (rotate_to_stream_ordering,)
         )
 
+    def add_push_actions_to_staging(self, event_id, user_id, actions):
+        """Add the push actions for the user and event to the push
+        action staging area.
+
+        Args:
+            event_id (str)
+            user_id (str)
+            actions (list[dict|str]): An action can either be a string or
+                dict.
+
+        Returns:
+            Deferred
+        """
+
+        is_highlight = 1 if _action_has_highlight(actions) else 0
+
+        return self._simple_insert(
+            table="event_push_actions_staging",
+            values={
+                "event_id": event_id,
+                "user_id": user_id,
+                "actions": _serialize_action(actions, is_highlight),
+                "notif": 1,
+                "highlight": is_highlight,
+            },
+            desc="add_push_actions_to_staging",
+        )
+
+    def remove_push_actions_from_staging(self, event_id):
+        """Called if we failed to persist the event to ensure that stale push
+        actions don't build up in the DB
+
+        Args:
+            event_id (str)
+        """
+
+        return self._simple_delete(
+            table="event_push_actions_staging",
+            keyvalues={
+                "event_id": event_id,
+            },
+            desc="remove_push_actions_from_staging",
+        )
+
 
 def _action_has_highlight(actions):
     for action in actions:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index bbb6aa992c..73177e0bc2 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1168,10 +1168,9 @@ class EventsStore(SQLBaseStore):
 
         for event, context in events_and_contexts:
             # Insert all the push actions into the event_push_actions table.
-            if context.push_actions:
-                self._set_push_actions_for_event_and_users_txn(
-                    txn, event, context.push_actions
-                )
+            self._set_push_actions_for_event_and_users_txn(
+                txn, event,
+            )
 
             if event.type == EventTypes.Redaction and event.redacts is not None:
                 # Remove the entries in the event_push_actions table for the
@@ -2093,6 +2092,30 @@ class EventsStore(SQLBaseStore):
         #     state_groups
         #     state_groups_state
 
+        # we will build a temporary table listing the events so that we don't
+        # have to keep shovelling the list back and forth across the
+        # connection. Annoyingly the python sqlite driver commits the
+        # transaction on CREATE, so let's do this first.
+        #
+        # furthermore, we might already have the table from a previous (failed)
+        # purge attempt, so let's drop the table first.
+
+        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+        txn.execute(
+            "CREATE TEMPORARY TABLE events_to_purge ("
+            "    event_id TEXT NOT NULL,"
+            "    should_delete BOOLEAN NOT NULL"
+            ")"
+        )
+
+        # create an index on should_delete because later we'll be looking for
+        # the should_delete / shouldn't_delete subsets
+        txn.execute(
+            "CREATE INDEX events_to_purge_should_delete"
+            " ON events_to_purge(should_delete)",
+        )
+
         # First ensure that we're not about to delete all the forward extremeties
         txn.execute(
             "SELECT e.event_id, e.depth FROM events as e "
@@ -2115,23 +2138,30 @@ class EventsStore(SQLBaseStore):
 
         logger.info("[purge] looking for events to delete")
 
+        should_delete_expr = "state_key IS NULL"
+        should_delete_params = ()
+        if not delete_local_events:
+            should_delete_expr += " AND event_id NOT LIKE ?"
+            should_delete_params += ("%:" + self.hs.hostname, )
+
+        should_delete_params += (room_id, topological_ordering)
+
+        txn.execute(
+            "INSERT INTO events_to_purge"
+            " SELECT event_id, %s"
+            " FROM events AS e LEFT JOIN state_events USING (event_id)"
+            " WHERE e.room_id = ? AND topological_ordering < ?" % (
+                should_delete_expr,
+            ),
+            should_delete_params,
+        )
         txn.execute(
-            "SELECT event_id, state_key FROM events"
-            " LEFT JOIN state_events USING (room_id, event_id)"
-            " WHERE room_id = ? AND topological_ordering < ?",
-            (room_id, topological_ordering,)
+            "SELECT event_id, should_delete FROM events_to_purge"
         )
         event_rows = txn.fetchall()
-
-        to_delete = [
-            (event_id,) for event_id, state_key in event_rows
-            if state_key is None and (
-                delete_local_events or not self.hs.is_mine_id(event_id)
-            )
-        ]
         logger.info(
             "[purge] found %i events before cutoff, of which %i can be deleted",
-            len(event_rows), len(to_delete),
+            len(event_rows), sum(1 for e in event_rows if e[1]),
         )
 
         logger.info("[purge] Finding new backward extremities")
@@ -2139,12 +2169,11 @@ class EventsStore(SQLBaseStore):
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
         txn.execute(
-            "SELECT DISTINCT e.event_id FROM events as e"
-            " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
-            " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
-            " WHERE e.room_id = ? AND e.topological_ordering < ?"
-            " AND e2.topological_ordering >= ?",
-            (room_id, topological_ordering, topological_ordering)
+            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+            " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
+            " WHERE e2.topological_ordering >= ?",
+            (topological_ordering, )
         )
         new_backwards_extrems = txn.fetchall()
 
@@ -2172,12 +2201,11 @@ class EventsStore(SQLBaseStore):
             "SELECT state_group FROM event_to_state_groups"
             " INNER JOIN events USING (event_id)"
             " WHERE state_group IN ("
-            "   SELECT DISTINCT state_group FROM events"
+            "   SELECT DISTINCT state_group FROM events_to_purge"
             "   INNER JOIN event_to_state_groups USING (event_id)"
-            "   WHERE room_id = ? AND topological_ordering < ?"
             " )"
             " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
-            (room_id, topological_ordering, topological_ordering)
+            (topological_ordering, )
         )
 
         state_rows = txn.fetchall()
@@ -2262,9 +2290,9 @@ class EventsStore(SQLBaseStore):
         )
 
         logger.info("[purge] removing events from event_to_state_groups")
-        txn.executemany(
-            "DELETE FROM event_to_state_groups WHERE event_id = ?",
-            [(event_id,) for event_id, _ in event_rows]
+        txn.execute(
+            "DELETE FROM event_to_state_groups "
+            "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
         for event_id, _ in event_rows:
             txn.call_after(self._get_state_group_for_event.invalidate, (
@@ -2281,7 +2309,6 @@ class EventsStore(SQLBaseStore):
             "event_edge_hashes",
             "event_edges",
             "event_forward_extremities",
-            "event_push_actions",
             "event_reference_hashes",
             "event_search",
             "event_signatures",
@@ -2289,22 +2316,35 @@ class EventsStore(SQLBaseStore):
         ):
             logger.info("[purge] removing events from %s", table)
 
-            txn.executemany(
-                "DELETE FROM %s WHERE event_id = ?" % (table,),
-                to_delete
+            txn.execute(
+                "DELETE FROM %s WHERE event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+            )
+
+        # event_push_actions lacks an index on event_id, and has one on
+        # (room_id, event_id) instead.
+        for table in (
+            "event_push_actions",
+        ):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+                (room_id, )
             )
 
         # Mark all state and own events as outliers
         logger.info("[purge] marking remaining events as outliers")
-        txn.executemany(
+        txn.execute(
             "UPDATE events SET outlier = ?"
-            " WHERE event_id = ?",
-            [
-                (True, event_id,) for event_id, state_key in event_rows
-                if state_key is not None or (
-                    not delete_local_events and self.hs.is_mine_id(event_id)
-                )
-            ]
+            " WHERE event_id IN ("
+            "    SELECT event_id FROM events_to_purge "
+            "    WHERE NOT should_delete"
+            ")",
+            (True,),
         )
 
         # synapse tries to take out an exclusive lock on room_depth whenever it
@@ -2319,6 +2359,12 @@ class EventsStore(SQLBaseStore):
             (topological_ordering, room_id,)
         )
 
+        # finally, drop the temp table. this will commit the txn in sqlite,
+        # so make sure to keep this actually last.
+        txn.execute(
+            "DROP TABLE events_to_purge"
+        )
+
         logger.info("[purge] done")
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 3aa810981f..95f75d6df1 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -39,12 +39,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         # we no longer use refresh tokens, but it's possible that some people
         # might have a background update queued to build this index. Just
         # clear the background update.
-        @defer.inlineCallbacks
-        def noop_update(progress, batch_size):
-            yield self._end_background_update("refresh_tokens_device_index")
-            defer.returnValue(1)
-        self.register_background_update_handler(
-            "refresh_tokens_device_index", noop_update)
+        self.register_noop_background_update("refresh_tokens_device_index")
 
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
index f090a7b75a..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
@@ -13,5 +13,7 @@
  * limitations under the License.
  */
 
- INSERT into background_updates (update_name, progress_json)
-     VALUES ('event_search_postgres_gist', '{}');
+-- We no longer do this given we back it out again in schema 47
+
+-- INSERT into background_updates (update_name, progress_json)
+--     VALUES ('event_search_postgres_gist', '{}');
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
new file mode 100644
index 0000000000..31d7a817eb
--- /dev/null
+++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('event_search_postgres_gin', '{}');
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql
new file mode 100644
index 0000000000..edccf4a96f
--- /dev/null
+++ b/synapse/storage/schema/delta/47/push_actions_staging.sql
@@ -0,0 +1,28 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Temporary staging area for push actions that have been calculated for an
+-- event, but the event hasn't yet been persisted.
+-- When the event is persisted the rows are moved over to the
+-- event_push_actions table.
+CREATE TABLE event_push_actions_staging (
+    event_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    actions TEXT NOT NULL,
+    notif SMALLINT NOT NULL,
+    highlight SMALLINT NOT NULL
+);
+
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index f1ac9ba0fd..2755acff40 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -38,6 +38,7 @@ class SearchStore(BackgroundUpdateStore):
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
+    EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
     def __init__(self, db_conn, hs):
         super(SearchStore, self).__init__(db_conn, hs)
@@ -48,9 +49,19 @@ class SearchStore(BackgroundUpdateStore):
             self.EVENT_SEARCH_ORDER_UPDATE_NAME,
             self._background_reindex_search_order
         )
-        self.register_background_update_handler(
+
+        # we used to have a background update to turn the GIN index into a
+        # GIST one; we no longer do that (obviously) because we actually want
+        # a GIN index. However, it's possible that some people might still have
+        # the background update queued, so we register a handler to clear the
+        # background update.
+        self.register_noop_background_update(
             self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
-            self._background_reindex_gist_search
+        )
+
+        self.register_background_update_handler(
+            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
+            self._background_reindex_gin_search
         )
 
     @defer.inlineCallbacks
@@ -151,25 +162,48 @@ class SearchStore(BackgroundUpdateStore):
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _background_reindex_gist_search(self, progress, batch_size):
+    def _background_reindex_gin_search(self, progress, batch_size):
+        """This handles old synapses which used GIST indexes, if any;
+        converting them back to be GIN as per the actual schema.
+        """
+
         def create_index(conn):
             conn.rollback()
-            conn.set_session(autocommit=True)
-            c = conn.cursor()
 
-            c.execute(
-                "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
-                " ON event_search USING GIST (vector)"
-            )
+            # we have to set autocommit, because postgres refuses to
+            # CREATE INDEX CONCURRENTLY without it.
+            conn.set_session(autocommit=True)
 
-            c.execute("DROP INDEX event_search_fts_idx")
+            try:
+                c = conn.cursor()
 
-            conn.set_session(autocommit=False)
+                # if we skipped the conversion to GIST, we may already/still
+                # have an event_search_fts_idx; unfortunately postgres 9.4
+                # doesn't support CREATE INDEX IF EXISTS so we just catch the
+                # exception and ignore it.
+                import psycopg2
+                try:
+                    c.execute(
+                        "CREATE INDEX CONCURRENTLY event_search_fts_idx"
+                        " ON event_search USING GIN (vector)"
+                    )
+                except psycopg2.ProgrammingError as e:
+                    logger.warn(
+                        "Ignoring error %r when trying to switch from GIST to GIN",
+                        e
+                    )
+
+                # we should now be able to delete the GIST index.
+                c.execute(
+                    "DROP INDEX IF EXISTS event_search_fts_idx_gist"
+                )
+            finally:
+                conn.set_session(autocommit=False)
 
         if isinstance(self.database_engine, PostgresEngine):
             yield self.runWithConnection(create_index)
 
-        yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+        yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
         defer.returnValue(1)
 
     @defer.inlineCallbacks
@@ -289,7 +323,30 @@ class SearchStore(BackgroundUpdateStore):
                 entry.stream_ordering, entry.origin_server_ts,
             ) for entry in entries)
 
+            # inserts to a GIN index are normally batched up into a pending
+            # list, and then all committed together once the list gets to a
+            # certain size. The trouble with that is that postgres (pre-9.5)
+            # uses work_mem to determine the length of the list, and work_mem
+            # is typically very large.
+            #
+            # We therefore reduce work_mem while we do the insert.
+            #
+            # (postgres 9.5 uses the separate gin_pending_list_limit setting,
+            # so doesn't suffer the same problem, but changing work_mem will
+            # be harmless)
+            #
+            # Note that we don't need to worry about restoring it on
+            # exception, because exceptions will cause the transaction to be
+            # rolled back, including the effects of the SET command.
+            #
+            # Also: we use SET rather than SET LOCAL because there's lots of
+            # other stuff going on in this transaction, which want to have the
+            # normal work_mem setting.
+
+            txn.execute("SET work_mem='256kB'")
             txn.executemany(sql, args)
+            txn.execute("RESET work_mem")
+
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
                 "INSERT INTO event_search (event_id, room_id, key, value)"
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index adb48df73e..2b325e1c1f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -140,6 +140,20 @@ class StateGroupWorkerStore(SQLBaseStore):
         defer.returnValue(group_to_state)
 
     @defer.inlineCallbacks
+    def get_state_ids_for_group(self, state_group):
+        """Get the state IDs for the given state group
+
+        Args:
+            state_group (int)
+
+        Returns:
+            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+        """
+        group_to_state = yield self._get_state_for_groups((state_group,))
+
+        defer.returnValue(group_to_state[state_group])
+
+    @defer.inlineCallbacks
     def get_state_groups(self, room_id, event_ids):
         """ Get the state groups for the given list of event_ids
 
@@ -655,6 +669,47 @@ class StateGroupWorkerStore(SQLBaseStore):
 
         return self.runInteraction("store_state_group", _store_state_group_txn)
 
+    def _count_state_group_hops_txn(self, txn, state_group):
+        """Given a state group, count how many hops there are in the tree.
+
+        This is used to ensure the delta chains don't get too long.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = ("""
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """)
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
 
 class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
     """ Keeps track of the state at a given event.
@@ -729,47 +784,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                 (event_id,), state_group_id
             )
 
-    def _count_state_group_hops_txn(self, txn, state_group):
-        """Given a state group, count how many hops there are in the tree.
-
-        This is used to ensure the delta chains don't get too long.
-        """
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = ("""
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT count(*) FROM state;
-            """)
-
-            txn.execute(sql, (state_group,))
-            row = txn.fetchone()
-            if row and row[0]:
-                return row[0]
-            else:
-                return 0
-        else:
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
-            count = 0
-
-            while next_group:
-                next_group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_group_edges",
-                    keyvalues={"state_group": next_group},
-                    retcol="prev_state_group",
-                    allow_none=True,
-                )
-                if next_group:
-                    count += 1
-
-            return count
-
     @defer.inlineCallbacks
     def _background_deduplicate_state(self, progress, batch_size):
         """This background update will slowly deduplicate state by reencoding
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f430cce931..4780f2ab72 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -230,7 +230,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
 
-        context.push_actions = push_actions
+        for user_id, actions in push_actions:
+            yield self.master_store.add_push_actions_to_staging(
+                event.event_id, user_id, actions,
+            )
 
         ordering = None
         if backfill:
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 3135488353..d483e7cf9e 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -62,6 +62,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                 {"notify_count": noitf_count, "highlight_count": highlight_count}
             )
 
+        @defer.inlineCallbacks
         def _inject_actions(stream, action):
             event = Mock()
             event.room_id = room_id
@@ -69,11 +70,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            tuples = [(user_id, action)]
-
-            return self.store.runInteraction(
+            yield self.store.add_push_actions_to_staging(
+                event.event_id, user_id, action,
+            )
+            yield self.store.runInteraction(
                 "", self.store._set_push_actions_for_event_and_users_txn,
-                event, tuples
+                event,
             )
 
         def _rotate(stream):