summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-02-23 11:46:24 +0000
committerErik Johnston <erik@matrix.org>2018-02-23 11:46:24 +0000
commitd62ce972f8170ae8d1708fb4b4a47a6a1938f4d9 (patch)
tree8cea5b32d1569652a1baf68347d6f353124dcc21 /synapse
parentUpdate copyright (diff)
parentMerge pull request #2900 from matrix-org/erikj/split_event_push_actions (diff)
downloadsynapse-d62ce972f8170ae8d1708fb4b4a47a6a1938f4d9.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/split_roommember_store
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py26
-rw-r--r--synapse/replication/slave/storage/account_data.py43
-rw-r--r--synapse/replication/slave/storage/events.py28
-rw-r--r--synapse/replication/slave/storage/push_rule.py24
-rw-r--r--synapse/replication/slave/storage/pushers.py12
-rw-r--r--synapse/storage/__init__.py21
-rw-r--r--synapse/storage/account_data.py76
-rw-r--r--synapse/storage/event_push_actions.py145
-rw-r--r--synapse/storage/events.py372
-rw-r--r--synapse/storage/events_worker.py395
-rw-r--r--synapse/storage/push_rule.py72
-rw-r--r--synapse/storage/pusher.py11
-rw-r--r--synapse/storage/tags.py16
13 files changed, 647 insertions, 594 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 46bcf8b081..8832ba58bc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1447,16 +1447,24 @@ class FederationHandler(BaseHandler):
             auth_events=auth_events,
         )
 
-        if not event.internal_metadata.is_outlier() and not backfilled:
-            yield self.action_generator.handle_push_actions_for_event(
-                event, context
-            )
+        try:
+            if not event.internal_metadata.is_outlier() and not backfilled:
+                yield self.action_generator.handle_push_actions_for_event(
+                    event, context
+                )
 
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-            backfilled=backfilled,
-        )
+            event_stream_id, max_stream_id = yield self.store.persist_event(
+                event,
+                context=context,
+                backfilled=backfilled,
+            )
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            # Ensure that we actually remove the entries in the push actions
+            # staging area
+            logcontext.preserve_fn(
+                self.store.remove_push_actions_from_staging
+            )(event.event_id)
+            raise
 
         if not backfilled:
             # this intentionally does not yield: we don't care about the result
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index efbd87918e..6c8d2954d7 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,50 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.account_data import AccountDataStore
-from synapse.storage.tags import TagsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.tags import TagsWorkerStore
 
 
-class SlavedAccountDataStore(BaseSlavedStore):
+class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
-        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id",
         )
-        self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache",
-            self._account_data_id_gen.get_current_token(),
-        )
-
-    get_account_data_for_user = (
-        AccountDataStore.__dict__["get_account_data_for_user"]
-    )
-
-    get_global_account_data_by_type_for_users = (
-        AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
-    )
 
-    get_global_account_data_by_type_for_user = (
-        AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
-    )
-
-    get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
-    get_tags_for_room = (
-        DataStore.get_tags_for_room.__func__
-    )
-    get_account_data_for_room = (
-        DataStore.get_account_data_for_room.__func__
-    )
-
-    get_updated_tags = DataStore.get_updated_tags.__func__
-    get_updated_account_data_for_user = (
-        DataStore.get_updated_account_data_for_user.__func__
-    )
+        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index ef7a42d801..de0b26f437 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -18,8 +18,8 @@ import logging
 from synapse.api.constants import EventTypes
 from synapse.storage import DataStore
 from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.event_push_actions import EventPushActionsStore
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.event_push_actions import EventPushActionsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.storage.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateGroupWorkerStore
 from synapse.storage.stream import StreamStore
@@ -40,8 +40,9 @@ logger = logging.getLogger(__name__)
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedEventStore(RoomMemberWorkerStore, EventsWorkerStore,
-                       StateGroupWorkerStore, BaseSlavedStore):
+class SlavedEventStore(RoomMemberWorkerStore, EventPushActionsWorkerStore,
+                       EventsWorkerStore, StateGroupWorkerStore,
+                       BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
         super(SlavedEventStore, self).__init__(db_conn, hs)
@@ -74,29 +75,12 @@ class SlavedEventStore(RoomMemberWorkerStore, EventsWorkerStore,
     get_latest_event_ids_in_room = EventFederationStore.__dict__[
         "get_latest_event_ids_in_room"
     ]
-    get_unread_event_push_actions_by_room_for_user = (
-        EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
-    )
-    _get_unread_counts_by_receipt_txn = (
-        DataStore._get_unread_counts_by_receipt_txn.__func__
-    )
-    _get_unread_counts_by_pos_txn = (
-        DataStore._get_unread_counts_by_pos_txn.__func__
-    )
+
     get_recent_event_ids_for_room = (
         StreamStore.__dict__["get_recent_event_ids_for_room"]
     )
     has_room_changed_since = DataStore.has_room_changed_since.__func__
 
-    get_unread_push_actions_for_user_in_range_for_http = (
-        DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
-    )
-    get_unread_push_actions_for_user_in_range_for_email = (
-        DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
-    )
-    get_push_action_users_in_range = (
-        DataStore.get_push_action_users_in_range.__func__
-    )
     get_membership_changes_for_user = (
         DataStore.get_membership_changes_for_user.__func__
     )
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 83e880fdd2..bb2c40b6e3 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,29 +16,15 @@
 
 from .events import SlavedEventStore
 from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.push_rule import PushRuleStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.push_rule import PushRulesWorkerStore
 
 
-class SlavedPushRuleStore(SlavedEventStore):
+class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
     def __init__(self, db_conn, hs):
-        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
         self._push_rules_stream_id_gen = SlavedIdTracker(
             db_conn, "push_rules_stream", "stream_id",
         )
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache",
-            self._push_rules_stream_id_gen.get_current_token(),
-        )
-
-    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
-    get_push_rules_enabled_for_user = (
-        PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
-    )
-    have_push_rules_changed_for_user = (
-        DataStore.have_push_rules_changed_for_user.__func__
-    )
+        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
 
     def get_push_rules_stream_token(self):
         return (
@@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore):
             self._stream_id_gen.get_current_token(),
         )
 
+    def get_max_push_rules_stream_id(self):
+        return self._push_rules_stream_id_gen.get_current_token()
+
     def stream_positions(self):
         result = super(SlavedPushRuleStore, self).stream_positions()
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 4e8d68ece9..a7cd5a7291 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,10 +17,10 @@
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
-from synapse.storage import DataStore
+from synapse.storage.pusher import PusherWorkerStore
 
 
-class SlavedPusherStore(BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
         super(SlavedPusherStore, self).__init__(db_conn, hs)
@@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore):
             extra_tables=[("deleted_pushers", "stream_id")],
         )
 
-    get_all_pushers = DataStore.get_all_pushers.__func__
-    get_pushers_by = DataStore.get_pushers_by.__func__
-    get_pushers_by_app_id_and_pushkey = (
-        DataStore.get_pushers_by_app_id_and_pushkey.__func__
-    )
-    _decode_pushers_rows = DataStore._decode_pushers_rows.__func__
-
     def stream_positions(self):
         result = super(SlavedPusherStore, self).stream_positions()
         result["pushers"] = self._pushers_id_gen.get_current_token()
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e1c4fe086e..0f136f8a06 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -104,9 +105,6 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "events", "stream_ordering", step=-1,
             extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
         )
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn, "account_data_max_stream_id", "stream_id"
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -159,11 +157,6 @@ class DataStore(RoomMemberStore, RoomStore,
             "MembershipStreamChangeCache", events_max,
         )
 
-        account_max = self._account_data_id_gen.get_current_token()
-        self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache", account_max,
-        )
-
         self._presence_on_startup = self._get_active_presence(db_conn)
 
         presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -177,18 +170,6 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=presence_cache_prefill
         )
 
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
-            db_conn, "push_rules_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
-        )
-
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache", push_rules_id,
-            prefilled_cache=push_rules_prefill,
-        )
-
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
             db_conn, "device_inbox",
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 56a0bde549..466194e96f 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 
+import abc
 import ujson as json
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_account_data_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        account_max = self.get_max_account_data_stream_id()
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache", account_max,
+        )
+
+        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+    @abc.abstractmethod
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream ID for account data stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
 
     @cached()
     def get_account_data_for_user(self, user_id):
@@ -209,6 +238,36 @@ class AccountDataStore(SQLBaseStore):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
+    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+            "m.ignored_user_list", ignorer_user_id,
+            on_invalidate=cache_context.invalidate,
+        )
+        if not ignored_account_data:
+            defer.returnValue(False)
+
+        defer.returnValue(
+            ignored_user_id in ignored_account_data.get("ignored_users", {})
+        )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    def __init__(self, db_conn, hs):
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        super(AccountDataStore, self).__init__(db_conn, hs)
+
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream id for the private user data stream
+
+        Returns:
+            A deferred int.
+        """
+        return self._account_data_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
         """Add some account_data to a room for a user.
@@ -321,16 +380,3 @@ class AccountDataStore(SQLBaseStore):
             "update_account_data_max_stream_id",
             _update,
         )
-
-    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
-    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
-        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list", ignorer_user_id,
-            on_invalidate=cache_context.invalidate,
-        )
-        if not ignored_account_data:
-            defer.returnValue(False)
-
-        defer.returnValue(
-            ignored_user_id in ignored_account_data.get("ignored_users", {})
-        )
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index f787431b7a..4cabf70ad0 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -62,77 +63,7 @@ def _deserialize_action(actions, is_highlight):
         return DEFAULT_NOTIF_ACTION
 
 
-class EventPushActionsStore(SQLBaseStore):
-    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(db_conn, hs)
-
-        self.register_background_index_update(
-            self.EPA_HIGHLIGHT_INDEX,
-            index_name="event_push_actions_u_highlight",
-            table="event_push_actions",
-            columns=["user_id", "stream_ordering"],
-        )
-
-        self.register_background_index_update(
-            "event_push_actions_highlights_index",
-            index_name="event_push_actions_highlights_index",
-            table="event_push_actions",
-            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1"
-        )
-
-        self._doing_notif_rotation = False
-        self._rotate_notif_loop = self._clock.looping_call(
-            self._rotate_notifs, 30 * 60 * 1000
-        )
-
-    def _set_push_actions_for_event_and_users_txn(self, txn, event):
-        """
-        Args:
-            event: the event set actions for
-            tuples: list of tuples of (user_id, actions)
-        """
-
-        sql = """
-            INSERT INTO event_push_actions (
-                room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
-            )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
-            FROM event_push_actions_staging
-            WHERE event_id = ?
-        """
-
-        txn.execute(sql, (
-            event.room_id, event.internal_metadata.stream_ordering,
-            event.depth, event.event_id,
-        ))
-
-        user_ids = self._simple_select_onecol_txn(
-            txn,
-            table="event_push_actions_staging",
-            keyvalues={
-                "event_id": event.event_id,
-            },
-            retcol="user_id",
-        )
-
-        self._simple_delete_txn(
-            txn,
-            table="event_push_actions_staging",
-            keyvalues={
-                "event_id": event.event_id,
-            },
-        )
-
-        for uid in user_ids:
-            txn.call_after(
-                self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                (event.room_id, uid,)
-            )
-
+class EventPushActionsWorkerStore(SQLBaseStore):
     @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
             self, room_id, user_id, last_read_event_id
@@ -449,6 +380,78 @@ class EventPushActionsStore(SQLBaseStore):
         # Now return the first `limit`
         defer.returnValue(notifs[:limit])
 
+
+class EventPushActionsStore(EventPushActionsWorkerStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+    def __init__(self, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
+        self.register_background_index_update(
+            "event_push_actions_highlights_index",
+            index_name="event_push_actions_highlights_index",
+            table="event_push_actions",
+            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+            where_clause="highlight=1"
+        )
+
+        self._doing_notif_rotation = False
+        self._rotate_notif_loop = self._clock.looping_call(
+            self._rotate_notifs, 30 * 60 * 1000
+        )
+
+    def _set_push_actions_for_event_and_users_txn(self, txn, event):
+        """
+        Args:
+            event: the event set actions for
+            tuples: list of tuples of (user_id, actions)
+        """
+
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
+            )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
+
+        txn.execute(sql, (
+            event.room_id, event.internal_metadata.stream_ordering,
+            event.depth, event.event_id,
+        ))
+
+        user_ids = self._simple_select_onecol_txn(
+            txn,
+            table="event_push_actions_staging",
+            keyvalues={
+                "event_id": event.event_id,
+            },
+            retcol="user_id",
+        )
+
+        self._simple_delete_txn(
+            txn,
+            table="event_push_actions_staging",
+            keyvalues={
+                "event_id": event.event_id,
+            },
+        )
+
+        for uid in user_ids:
+            txn.call_after(
+                self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                (event.room_id, uid,)
+            )
+
     @defer.inlineCallbacks
     def get_push_actions_for_user(self, user_id, before=None, limit=50,
                                   only_highlight=False):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 3e3229b408..99d6cca585 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -13,16 +13,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import SQLBaseStore
 
-from twisted.internet import defer, reactor
+from synapse.storage.events_worker import EventsWorkerStore
 
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
+from twisted.internet import defer
+
+from synapse.events import USE_FROZEN_DICTS
 
 from synapse.util.async import ObservableDeferred
 from synapse.util.logcontext import (
-    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+    PreserveLoggingContext, make_deferred_yieldable
 )
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
@@ -62,16 +62,6 @@ def encode_json(json_object):
         return json.dumps(json_object, ensure_ascii=False)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
-
-
 class _EventPeristenceQueue(object):
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
@@ -200,362 +190,12 @@ def _retry_on_integrity_error(func):
     return f
 
 
-class EventsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventsWorkerStore, self).__init__(db_conn, hs)
-
-        self._event_persist_queue = _EventPeristenceQueue()
-
-    @defer.inlineCallbacks
-    def get_event(self, event_id, check_redacted=True,
-                  get_prev_content=False, allow_rejected=False,
-                  allow_none=False):
-        """Get an event from the database by event_id.
-
-        Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
-                False throw an exception.
-
-        Returns:
-            Deferred : A FrozenEvent.
-        """
-        events = yield self._get_events(
-            [event_id],
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
-
-        if not events and not allow_none:
-            raise SynapseError(404, "Could not find event %s" % (event_id,))
-
-        defer.returnValue(events[0] if events else None)
-
-    @defer.inlineCallbacks
-    def get_events(self, event_ids, check_redacted=True,
-                   get_prev_content=False, allow_rejected=False):
-        """Get events from the database
-
-        Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-
-        Returns:
-            Deferred : Dict from event_id to event.
-        """
-        events = yield self._get_events(
-            event_ids,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
-
-        defer.returnValue({e.event_id: e for e in events})
-
-    @defer.inlineCallbacks
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False, allow_rejected=False):
-        if not event_ids:
-            defer.returnValue([])
-
-        event_id_list = event_ids
-        event_ids = set(event_ids)
-
-        event_entry_map = self._get_events_from_cache(
-            event_ids,
-            allow_rejected=allow_rejected,
-        )
-
-        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
-        if missing_events_ids:
-            missing_events = yield self._enqueue_events(
-                missing_events_ids,
-                check_redacted=check_redacted,
-                allow_rejected=allow_rejected,
-            )
-
-            event_entry_map.update(missing_events)
-
-        events = []
-        for event_id in event_id_list:
-            entry = event_entry_map.get(event_id, None)
-            if not entry:
-                continue
-
-            if allow_rejected or not entry.event.rejected_reason:
-                if check_redacted and entry.redacted_event:
-                    event = entry.redacted_event
-                else:
-                    event = entry.event
-
-                events.append(event)
-
-                if get_prev_content:
-                    if "replaces_state" in event.unsigned:
-                        prev = yield self.get_event(
-                            event.unsigned["replaces_state"],
-                            get_prev_content=False,
-                            allow_none=True,
-                        )
-                        if prev:
-                            event.unsigned = dict(event.unsigned)
-                            event.unsigned["prev_content"] = prev.content
-                            event.unsigned["prev_sender"] = prev.sender
-
-        defer.returnValue(events)
-
-    def _invalidate_get_event_cache(self, event_id):
-            self._get_event_cache.invalidate((event_id,))
-
-    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
-        """Fetch events from the caches
-
-        Args:
-            events (list(str)): list of event_ids to fetch
-            allow_rejected (bool): Whether to teturn events that were rejected
-            update_metrics (bool): Whether to update the cache hit ratio metrics
-
-        Returns:
-            dict of event_id -> _EventCacheEntry for each event_id in cache. If
-            allow_rejected is `False` then there will still be an entry but it
-            will be `None`
-        """
-        event_map = {}
-
-        for event_id in events:
-            ret = self._get_event_cache.get(
-                (event_id,), None,
-                update_metrics=update_metrics,
-            )
-            if not ret:
-                continue
-
-            if allow_rejected or not ret.event.rejected_reason:
-                event_map[event_id] = ret
-            else:
-                event_map[event_id] = None
-
-        return event_map
-
-    def _do_fetch(self, conn):
-        """Takes a database connection and waits for requests for events from
-        the _event_fetch_list queue.
-        """
-        event_list = []
-        i = 0
-        while True:
-            try:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
-                            self._event_fetch_ongoing -= 1
-                            return
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
-
-                event_id_lists = zip(*event_list)[0]
-                event_ids = [
-                    item for sublist in event_id_lists for item in sublist
-                ]
-
-                rows = self._new_transaction(
-                    conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
-                )
-
-                row_dict = {
-                    r["event_id"]: r
-                    for r in rows
-                }
-
-                # We only want to resolve deferreds from the main thread
-                def fire(lst, res):
-                    for ids, d in lst:
-                        if not d.called:
-                            try:
-                                with PreserveLoggingContext():
-                                    d.callback([
-                                        res[i]
-                                        for i in ids
-                                        if i in res
-                                    ])
-                            except Exception:
-                                logger.exception("Failed to callback")
-                with PreserveLoggingContext():
-                    reactor.callFromThread(fire, event_list, row_dict)
-            except Exception as e:
-                logger.exception("do_fetch")
-
-                # We only want to resolve deferreds from the main thread
-                def fire(evs):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(e)
-
-                if event_list:
-                    with PreserveLoggingContext():
-                        reactor.callFromThread(fire, event_list)
-
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
-        """Fetches events from the database using the _event_fetch_list. This
-        allows batch and bulk fetching of events - it allows us to fetch events
-        without having to create a new transaction for each request for events.
-        """
-        if not events:
-            defer.returnValue({})
-
-        events_d = defer.Deferred()
-        with self._event_fetch_lock:
-            self._event_fetch_list.append(
-                (events, events_d)
-            )
-
-            self._event_fetch_lock.notify()
-
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            with PreserveLoggingContext():
-                self.runWithConnection(
-                    self._do_fetch
-                )
-
-        logger.debug("Loading %d events", len(events))
-        with PreserveLoggingContext():
-            rows = yield events_d
-        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
-        if not allow_rejected:
-            rows[:] = [r for r in rows if not r["rejects"]]
-
-        res = yield make_deferred_yieldable(defer.gatherResults(
-            [
-                preserve_fn(self._get_event_from_row)(
-                    row["internal_metadata"], row["json"], row["redacts"],
-                    rejected_reason=row["rejects"],
-                )
-                for row in rows
-            ],
-            consumeErrors=True
-        ))
-
-        defer.returnValue({
-            e.event.event_id: e
-            for e in res if e
-        })
-
-    def _fetch_event_rows(self, txn, events):
-        rows = []
-        N = 200
-        for i in range(1 + len(events) / N):
-            evs = events[i * N:(i + 1) * N]
-            if not evs:
-                break
-
-            sql = (
-                "SELECT "
-                " e.event_id as event_id, "
-                " e.internal_metadata,"
-                " e.json,"
-                " r.redacts as redacts,"
-                " rej.event_id as rejects "
-                " FROM event_json as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
-                " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"] * len(evs)),)
-
-            txn.execute(sql, evs)
-            rows.extend(self.cursor_to_dict(txn))
-
-        return rows
-
-    @defer.inlineCallbacks
-    def _get_event_from_row(self, internal_metadata, js, redacted,
-                            rejected_reason=None):
-        with Measure(self._clock, "_get_event_from_row"):
-            d = json.loads(js)
-            internal_metadata = json.loads(internal_metadata)
-
-            if rejected_reason:
-                rejected_reason = yield self._simple_select_one_onecol(
-                    table="rejections",
-                    keyvalues={"event_id": rejected_reason},
-                    retcol="reason",
-                    desc="_get_event_from_row_rejected_reason",
-                )
-
-            original_ev = FrozenEvent(
-                d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = None
-            if redacted:
-                redacted_event = prune_event(original_ev)
-
-                redaction_id = yield self._simple_select_one_onecol(
-                    table="redactions",
-                    keyvalues={"redacts": redacted_event.event_id},
-                    retcol="event_id",
-                    desc="_get_event_from_row_redactions",
-                )
-
-                redacted_event.unsigned["redacted_by"] = redaction_id
-                # Get the redaction event.
-
-                because = yield self.get_event(
-                    redaction_id,
-                    check_redacted=False,
-                    allow_none=True,
-                )
-
-                if because:
-                    # It's fine to do add the event directly, since get_pdu_json
-                    # will serialise this field correctly
-                    redacted_event.unsigned["redacted_because"] = because
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev,
-                redacted_event=redacted_event,
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-        defer.returnValue(cache_entry)
-
-
 class EventsStore(EventsWorkerStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
     def __init__(self, db_conn, hs):
         super(EventsStore, self).__init__(db_conn, hs)
-        self._clock = hs.get_clock()
         self.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
@@ -584,6 +224,8 @@ class EventsStore(EventsWorkerStore):
             psql_only=True,
         )
 
+        self._event_persist_queue = _EventPeristenceQueue()
+
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
     def persist_events(self, events_and_contexts, backfilled=False):
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..86c3b48ad4
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+from twisted.internet import defer, reactor
+
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import (
+    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+)
+from synapse.util.metrics import Measure
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+import logging
+import ujson as json
+
+# these are only included to make the type annotations work
+from synapse.events import EventBase    # noqa: F401
+from synapse.events.snapshot import EventContext   # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thing air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+
+    @defer.inlineCallbacks
+    def get_event(self, event_id, check_redacted=True,
+                  get_prev_content=False, allow_rejected=False,
+                  allow_none=False):
+        """Get an event from the database by event_id.
+
+        Args:
+            event_id (str): The event_id of the event to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+            allow_none (bool): If True, return None if no event found, if
+                False throw an exception.
+
+        Returns:
+            Deferred : A FrozenEvent.
+        """
+        events = yield self._get_events(
+            [event_id],
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        if not events and not allow_none:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue(events[0] if events else None)
+
+    @defer.inlineCallbacks
+    def get_events(self, event_ids, check_redacted=True,
+                   get_prev_content=False, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred : Dict from event_id to event.
+        """
+        events = yield self._get_events(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        defer.returnValue({e.event_id: e for e in events})
+
+    @defer.inlineCallbacks
+    def _get_events(self, event_ids, check_redacted=True,
+                    get_prev_content=False, allow_rejected=False):
+        if not event_ids:
+            defer.returnValue([])
+
+        event_id_list = event_ids
+        event_ids = set(event_ids)
+
+        event_entry_map = self._get_events_from_cache(
+            event_ids,
+            allow_rejected=allow_rejected,
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+        if missing_events_ids:
+            missing_events = yield self._enqueue_events(
+                missing_events_ids,
+                check_redacted=check_redacted,
+                allow_rejected=allow_rejected,
+            )
+
+            event_entry_map.update(missing_events)
+
+        events = []
+        for event_id in event_id_list:
+            entry = event_entry_map.get(event_id, None)
+            if not entry:
+                continue
+
+            if allow_rejected or not entry.event.rejected_reason:
+                if check_redacted and entry.redacted_event:
+                    event = entry.redacted_event
+                else:
+                    event = entry.event
+
+                events.append(event)
+
+                if get_prev_content:
+                    if "replaces_state" in event.unsigned:
+                        prev = yield self.get_event(
+                            event.unsigned["replaces_state"],
+                            get_prev_content=False,
+                            allow_none=True,
+                        )
+                        if prev:
+                            event.unsigned = dict(event.unsigned)
+                            event.unsigned["prev_content"] = prev.content
+                            event.unsigned["prev_sender"] = prev.sender
+
+        defer.returnValue(events)
+
+    def _invalidate_get_event_cache(self, event_id):
+            self._get_event_cache.invalidate((event_id,))
+
+    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+        """Fetch events from the caches
+
+        Args:
+            events (list(str)): list of event_ids to fetch
+            allow_rejected (bool): Whether to teturn events that were rejected
+            update_metrics (bool): Whether to update the cache hit ratio metrics
+
+        Returns:
+            dict of event_id -> _EventCacheEntry for each event_id in cache. If
+            allow_rejected is `False` then there will still be an entry but it
+            will be `None`
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = self._get_event_cache.get(
+                (event_id,), None,
+                update_metrics=update_metrics,
+            )
+            if not ret:
+                continue
+
+            if allow_rejected or not ret.event.rejected_reason:
+                event_map[event_id] = ret
+            else:
+                event_map[event_id] = None
+
+        return event_map
+
+    def _do_fetch(self, conn):
+        """Takes a database connection and waits for requests for events from
+        the _event_fetch_list queue.
+        """
+        event_list = []
+        i = 0
+        while True:
+            try:
+                with self._event_fetch_lock:
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+
+                    if not event_list:
+                        single_threaded = self.database_engine.single_threaded
+                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                            self._event_fetch_ongoing -= 1
+                            return
+                        else:
+                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                            i += 1
+                            continue
+                    i = 0
+
+                event_id_lists = zip(*event_list)[0]
+                event_ids = [
+                    item for sublist in event_id_lists for item in sublist
+                ]
+
+                rows = self._new_transaction(
+                    conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+                )
+
+                row_dict = {
+                    r["event_id"]: r
+                    for r in rows
+                }
+
+                # We only want to resolve deferreds from the main thread
+                def fire(lst, res):
+                    for ids, d in lst:
+                        if not d.called:
+                            try:
+                                with PreserveLoggingContext():
+                                    d.callback([
+                                        res[i]
+                                        for i in ids
+                                        if i in res
+                                    ])
+                            except Exception:
+                                logger.exception("Failed to callback")
+                with PreserveLoggingContext():
+                    reactor.callFromThread(fire, event_list, row_dict)
+            except Exception as e:
+                logger.exception("do_fetch")
+
+                # We only want to resolve deferreds from the main thread
+                def fire(evs):
+                    for _, d in evs:
+                        if not d.called:
+                            with PreserveLoggingContext():
+                                d.errback(e)
+
+                if event_list:
+                    with PreserveLoggingContext():
+                        reactor.callFromThread(fire, event_list)
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+        """Fetches events from the database using the _event_fetch_list. This
+        allows batch and bulk fetching of events - it allows us to fetch events
+        without having to create a new transaction for each request for events.
+        """
+        if not events:
+            defer.returnValue({})
+
+        events_d = defer.Deferred()
+        with self._event_fetch_lock:
+            self._event_fetch_list.append(
+                (events, events_d)
+            )
+
+            self._event_fetch_lock.notify()
+
+            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+                self._event_fetch_ongoing += 1
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            with PreserveLoggingContext():
+                self.runWithConnection(
+                    self._do_fetch
+                )
+
+        logger.debug("Loading %d events", len(events))
+        with PreserveLoggingContext():
+            rows = yield events_d
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+        if not allow_rejected:
+            rows[:] = [r for r in rows if not r["rejects"]]
+
+        res = yield make_deferred_yieldable(defer.gatherResults(
+            [
+                preserve_fn(self._get_event_from_row)(
+                    row["internal_metadata"], row["json"], row["redacts"],
+                    rejected_reason=row["rejects"],
+                )
+                for row in rows
+            ],
+            consumeErrors=True
+        ))
+
+        defer.returnValue({
+            e.event.event_id: e
+            for e in res if e
+        })
+
+    def _fetch_event_rows(self, txn, events):
+        rows = []
+        N = 200
+        for i in range(1 + len(events) / N):
+            evs = events[i * N:(i + 1) * N]
+            if not evs:
+                break
+
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " e.internal_metadata,"
+                " e.json,"
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
+                " FROM event_json as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE e.event_id IN (%s)"
+            ) % (",".join(["?"] * len(evs)),)
+
+            txn.execute(sql, evs)
+            rows.extend(self.cursor_to_dict(txn))
+
+        return rows
+
+    @defer.inlineCallbacks
+    def _get_event_from_row(self, internal_metadata, js, redacted,
+                            rejected_reason=None):
+        with Measure(self._clock, "_get_event_from_row"):
+            d = json.loads(js)
+            internal_metadata = json.loads(internal_metadata)
+
+            if rejected_reason:
+                rejected_reason = yield self._simple_select_one_onecol(
+                    table="rejections",
+                    keyvalues={"event_id": rejected_reason},
+                    retcol="reason",
+                    desc="_get_event_from_row_rejected_reason",
+                )
+
+            original_ev = FrozenEvent(
+                d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
+            )
+
+            redacted_event = None
+            if redacted:
+                redacted_event = prune_event(original_ev)
+
+                redaction_id = yield self._simple_select_one_onecol(
+                    table="redactions",
+                    keyvalues={"redacts": redacted_event.event_id},
+                    retcol="event_id",
+                    desc="_get_event_from_row_redactions",
+                )
+
+                redacted_event.unsigned["redacted_by"] = redaction_id
+                # Get the redaction event.
+
+                because = yield self.get_event(
+                    redaction_id,
+                    check_redacted=False,
+                    allow_none=True,
+                )
+
+                if because:
+                    # It's fine to do add the event directly, since get_pdu_json
+                    # will serialise this field correctly
+                    redacted_event.unsigned["redacted_because"] = because
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev,
+                redacted_event=redacted_event,
+            )
+
+            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+        defer.returnValue(cache_entry)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..583efb7bdf 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,12 @@
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.push.baserules import list_with_base_rules
 from synapse.api.constants import EventTypes
 from twisted.internet import defer
 
+import abc
 import logging
 import simplejson as json
 
@@ -48,7 +51,39 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
@@ -89,6 +124,24 @@ class PushRuleStore(SQLBaseStore):
             r['rule_id']: False if r['enabled'] == 0 else True for r in results
         })
 
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
+
+class PushRuleStore(PushRulesWorkerStore):
     @cachedList(cached_method_name="get_push_rules_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules(self, user_ids):
@@ -526,21 +579,8 @@ class PushRuleStore(SQLBaseStore):
         room stream ordering it corresponds to."""
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
 
 
 class RuleNotFoundException(Exception):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 3d8b4d5d5b..f4af3e4caa 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ import types
 logger = logging.getLogger(__name__)
 
 
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
     def _decode_pushers_rows(self, rows):
         for r in rows:
             dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
         defer.returnValue(rows)
 
-    def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_current_token()
-
     def get_all_updated_pushers(self, last_id, current_id, limit):
         if last_id == current_id:
             return defer.succeed(([], []))
@@ -177,6 +175,11 @@ class PusherStore(SQLBaseStore):
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
+
+class PusherStore(PusherWorkerStore):
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     @cachedInlineCallbacks(num_args=1, max_entries=15000)
     def get_if_user_has_pusher(self, user_id):
         # This only exists for the cachedList decorator
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bff73f3f04..fc46bf7bb3 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
 from synapse.util.caches.descriptors import cached
 from twisted.internet import defer
 
@@ -23,15 +25,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class TagsStore(SQLBaseStore):
-    def get_max_account_data_stream_id(self):
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            A deferred int.
-        """
-        return self._account_data_id_gen.get_current_token()
-
+class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
     def get_tags_for_user(self, user_id):
         """Get all the tags for a user.
@@ -170,6 +164,8 @@ class TagsStore(SQLBaseStore):
             row["tag"]: json.loads(row["content"]) for row in rows
         })
 
+
+class TagsStore(TagsWorkerStore):
     @defer.inlineCallbacks
     def add_tag_to_room(self, user_id, room_id, tag, content):
         """Add a tag to a room for a user.