summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/3350.misc1
-rw-r--r--changelog.d/3586.misc1
-rw-r--r--changelog.d/3587.misc1
-rw-r--r--changelog.d/3595.misc1
-rw-r--r--changelog.d/3597.feature1
-rw-r--r--changelog.d/3601.bugfix1
-rw-r--r--changelog.d/3604.feature1
-rw-r--r--changelog.d/3605.bugfix1
-rw-r--r--changelog.d/3606.misc1
-rw-r--r--changelog.d/3607.bugfix1
-rw-r--r--synapse/app/client_reader.py2
-rwxr-xr-xsynapse/app/homeserver.py13
-rw-r--r--synapse/app/synchrotron.py5
-rw-r--r--synapse/events/snapshot.py3
-rw-r--r--synapse/federation/federation_server.py11
-rw-r--r--synapse/groups/attestations.py6
-rw-r--r--synapse/handlers/federation.py13
-rw-r--r--synapse/handlers/profile.py9
-rw-r--r--synapse/handlers/room.py12
-rw-r--r--synapse/replication/tcp/client.py2
-rw-r--r--synapse/replication/tcp/resource.py14
-rw-r--r--synapse/rest/media/v1/media_repository.py8
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py8
-rw-r--r--synapse/state.py143
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/devices.py8
-rw-r--r--synapse/storage/event_federation.py9
-rw-r--r--synapse/storage/event_push_actions.py13
-rw-r--r--synapse/storage/events.py196
-rw-r--r--synapse/storage/push_rule.py13
-rw-r--r--synapse/storage/roommember.py61
-rw-r--r--synapse/storage/transactions.py6
-rw-r--r--synapse/visibility.py19
-rw-r--r--tests/replication/slave/storage/_base.py37
34 files changed, 389 insertions, 238 deletions
diff --git a/changelog.d/3350.misc b/changelog.d/3350.misc
new file mode 100644
index 0000000000..3713cd6d63
--- /dev/null
+++ b/changelog.d/3350.misc
@@ -0,0 +1 @@
+Remove redundant checks on who_forgot_in_room
\ No newline at end of file
diff --git a/changelog.d/3586.misc b/changelog.d/3586.misc
new file mode 100644
index 0000000000..e853e2481b
--- /dev/null
+++ b/changelog.d/3586.misc
@@ -0,0 +1 @@
+Fixes and optimisations for resolve_state_groups
diff --git a/changelog.d/3587.misc b/changelog.d/3587.misc
new file mode 100644
index 0000000000..75a3479910
--- /dev/null
+++ b/changelog.d/3587.misc
@@ -0,0 +1 @@
+Improve logging for exceptions when handling PDUs
\ No newline at end of file
diff --git a/changelog.d/3595.misc b/changelog.d/3595.misc
new file mode 100644
index 0000000000..85903504cc
--- /dev/null
+++ b/changelog.d/3595.misc
@@ -0,0 +1 @@
+Attempt to reduce amount of state pulled out of DB during persist_events
diff --git a/changelog.d/3597.feature b/changelog.d/3597.feature
new file mode 100644
index 0000000000..ea4a85e0ae
--- /dev/null
+++ b/changelog.d/3597.feature
@@ -0,0 +1 @@
+Add support for client_reader to handle more APIs
diff --git a/changelog.d/3601.bugfix b/changelog.d/3601.bugfix
new file mode 100644
index 0000000000..1678b261d0
--- /dev/null
+++ b/changelog.d/3601.bugfix
@@ -0,0 +1 @@
+Fix failure to persist events over federation under load
diff --git a/changelog.d/3604.feature b/changelog.d/3604.feature
new file mode 100644
index 0000000000..77a294cb9f
--- /dev/null
+++ b/changelog.d/3604.feature
@@ -0,0 +1 @@
+Add metrics to track resource usage by background processes
diff --git a/changelog.d/3605.bugfix b/changelog.d/3605.bugfix
new file mode 100644
index 0000000000..786da546eb
--- /dev/null
+++ b/changelog.d/3605.bugfix
@@ -0,0 +1 @@
+Fix updating of cached remote profiles
diff --git a/changelog.d/3606.misc b/changelog.d/3606.misc
new file mode 100644
index 0000000000..f0137766a0
--- /dev/null
+++ b/changelog.d/3606.misc
@@ -0,0 +1 @@
+Fix some random logcontext leaks.
\ No newline at end of file
diff --git a/changelog.d/3607.bugfix b/changelog.d/3607.bugfix
new file mode 100644
index 0000000000..7ad64593b8
--- /dev/null
+++ b/changelog.d/3607.bugfix
@@ -0,0 +1 @@
+Fix 'tuple index out of range' error
\ No newline at end of file
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 398bb36602..e2c91123db 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -31,6 +31,7 @@ from synapse.http.site import SynapseSite
 from synapse.metrics import RegistryProxy
 from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.directory import DirectoryStore
@@ -58,6 +59,7 @@ logger = logging.getLogger("synapse.app.client_reader")
 
 
 class ClientReaderSlavedStore(
+    SlavedAccountDataStore,
     SlavedEventStore,
     SlavedKeyStore,
     RoomStore,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2ad1beb8d8..b7e7718290 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -49,6 +49,7 @@ from synapse.http.additional_resource import AdditionalResource
 from synapse.http.server import RootRedirect
 from synapse.http.site import SynapseSite
 from synapse.metrics import RegistryProxy
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.module_api import ModuleApi
 from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirements
@@ -427,6 +428,9 @@ def run(hs):
     # currently either 0 or 1
     stats_process = []
 
+    def start_phone_stats_home():
+        run_as_background_process("phone_stats_home", phone_stats_home)
+
     @defer.inlineCallbacks
     def phone_stats_home():
         logger.info("Gathering stats for reporting")
@@ -498,7 +502,10 @@ def run(hs):
             )
 
     def generate_user_daily_visit_stats():
-        hs.get_datastore().generate_user_daily_visits()
+        run_as_background_process(
+            "generate_user_daily_visits",
+            hs.get_datastore().generate_user_daily_visits,
+        )
 
     # Rather than update on per session basis, batch up the requests.
     # If you increase the loop period, the accuracy of user_daily_visits
@@ -507,7 +514,7 @@ def run(hs):
 
     if hs.config.report_stats:
         logger.info("Scheduling stats reporting for 3 hour intervals")
-        clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
+        clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
 
         # We need to defer this init for the cases that we daemonize
         # otherwise the process ID we get is that of the non-daemon process
@@ -515,7 +522,7 @@ def run(hs):
 
         # We wait 5 minutes to send the first set of stats as the server can
         # be quite busy the first few minutes
-        clock.call_later(5 * 60, phone_stats_home)
+        clock.call_later(5 * 60, start_phone_stats_home)
 
     if hs.config.daemonize and hs.config.print_pidfile:
         print (hs.config.pid_file)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 26b9ec85f2..e201f18efd 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -55,7 +55,6 @@ from synapse.rest.client.v2_alpha import sync
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
 from synapse.storage.presence import UserPresenceState
-from synapse.storage.roommember import RoomMemberStore
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
@@ -81,9 +80,7 @@ class SynchrotronSlavedStore(
     RoomStore,
     BaseSlavedStore,
 ):
-    did_forget = (
-        RoomMemberStore.__dict__["did_forget"]
-    )
+    pass
 
 
 UPDATE_SYNCING_USERS_MS = 10 * 1000
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 189212b0fa..368b5f6ae4 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -249,7 +249,7 @@ class EventContext(object):
 
     @defer.inlineCallbacks
     def update_state(self, state_group, prev_state_ids, current_state_ids,
-                     delta_ids):
+                     prev_group, delta_ids):
         """Replace the state in the context
         """
 
@@ -260,6 +260,7 @@ class EventContext(object):
 
         self.state_group = state_group
         self._prev_state_ids = prev_state_ids
+        self.prev_group = prev_group
         self._current_state_ids = current_state_ids
         self.delta_ids = delta_ids
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 48f26db67c..e501251b6e 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -24,6 +24,7 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
+from twisted.python import failure
 
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError
@@ -186,8 +187,12 @@ class FederationServer(FederationBase):
                     logger.warn("Error handling PDU %s: %s", event_id, e)
                     pdu_results[event_id] = {"error": str(e)}
                 except Exception as e:
+                    f = failure.Failure()
                     pdu_results[event_id] = {"error": str(e)}
-                    logger.exception("Failed to handle PDU %s", event_id)
+                    logger.error(
+                        "Failed to handle PDU %s: %s",
+                        event_id, f.getTraceback().rstrip(),
+                    )
 
         yield async.concurrently_execute(
             process_pdus_for_room, pdus_by_room.keys(),
@@ -203,8 +208,8 @@ class FederationServer(FederationBase):
                 )
 
         pdu_failures = getattr(transaction, "pdu_failures", [])
-        for failure in pdu_failures:
-            logger.info("Got failure %r", failure)
+        for fail in pdu_failures:
+            logger.info("Got failure %r", fail)
 
         response = {
             "pdus": pdu_results,
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 47452700a8..4216af0a27 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -43,6 +43,7 @@ from signedjson.sign import sign_json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
 from synapse.util.logcontext import run_in_background
 
@@ -129,7 +130,7 @@ class GroupAttestionRenewer(object):
         self.attestations = hs.get_groups_attestation_signing()
 
         self._renew_attestations_loop = self.clock.looping_call(
-            self._renew_attestations, 30 * 60 * 1000,
+            self._start_renew_attestations, 30 * 60 * 1000,
         )
 
     @defer.inlineCallbacks
@@ -151,6 +152,9 @@ class GroupAttestionRenewer(object):
 
         defer.returnValue({})
 
+    def _start_renew_attestations(self):
+        run_as_background_process("renew_attestations", self._renew_attestations)
+
     @defer.inlineCallbacks
     def _renew_attestations(self):
         """Called periodically to check if we need to update any of our attestations
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 14654d59f1..145c1a21d4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1980,10 +1980,6 @@ class FederationHandler(BaseHandler):
 
         current_state_ids.update(state_updates)
 
-        if context.delta_ids is not None:
-            delta_ids = dict(context.delta_ids)
-            delta_ids.update(state_updates)
-
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         prev_state_ids = dict(prev_state_ids)
 
@@ -1991,11 +1987,13 @@ class FederationHandler(BaseHandler):
             k: a.event_id for k, a in iteritems(auth_events)
         })
 
+        # create a new state group as a delta from the existing one.
+        prev_group = context.state_group
         state_group = yield self.store.store_state_group(
             event.event_id,
             event.room_id,
-            prev_group=context.prev_group,
-            delta_ids=delta_ids,
+            prev_group=prev_group,
+            delta_ids=state_updates,
             current_state_ids=current_state_ids,
         )
 
@@ -2003,7 +2001,8 @@ class FederationHandler(BaseHandler):
             state_group=state_group,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
-            delta_ids=delta_ids,
+            prev_group=prev_group,
+            delta_ids=state_updates,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 859f6d2b2e..43692b83a8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -18,6 +18,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.api.errors import AuthError, CodeMessageException, SynapseError
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import UserID, get_domain_from_id
 
 from ._base import BaseHandler
@@ -41,7 +42,7 @@ class ProfileHandler(BaseHandler):
 
         if hs.config.worker_app is None:
             self.clock.looping_call(
-                self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+                self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
             )
 
     @defer.inlineCallbacks
@@ -254,6 +255,12 @@ class ProfileHandler(BaseHandler):
                     room_id, str(e.message)
                 )
 
+    def _start_update_remote_profile_cache(self):
+        run_as_background_process(
+            "Update remote profile", self._update_remote_profile_cache,
+        )
+
+    @defer.inlineCallbacks
     def _update_remote_profile_cache(self):
         """Called periodically to check profiles of remote users we haven't
         checked in a while.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6150b7e226..003b848c00 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -24,7 +24,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
 from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
-from synapse.types import RoomAlias, RoomID, RoomStreamToken, UserID
+from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
 from synapse.util import stringutils
 from synapse.visibility import filter_events_for_client
 
@@ -418,8 +418,6 @@ class RoomContextHandler(object):
         before_limit = math.floor(limit / 2.)
         after_limit = limit - before_limit
 
-        now_token = yield self.hs.get_event_sources().get_current_token()
-
         users = yield self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
@@ -462,11 +460,15 @@ class RoomContextHandler(object):
         )
         results["state"] = list(state[last_event_id].values())
 
-        results["start"] = now_token.copy_and_replace(
+        # We use a dummy token here as we only care about the room portion of
+        # the token, which we replace.
+        token = StreamToken.START
+
+        results["start"] = token.copy_and_replace(
             "room_key", results["start"]
         ).to_string()
 
-        results["end"] = now_token.copy_and_replace(
+        results["end"] = token.copy_and_replace(
             "room_key", results["end"]
         ).to_string()
 
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e592ab57bf..970e94313e 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -192,7 +192,7 @@ class ReplicationClientHandler(object):
         """Returns a deferred that is resolved when we receive a SYNC command
         with given data.
 
-        Used by tests.
+        [Not currently] used by tests.
         """
         return self.awaiting_syncs.setdefault(data, defer.Deferred())
 
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 611fb66e1d..fd59f1595f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 from twisted.internet.protocol import Factory
 
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.metrics import Measure, measure_func
 
 from .protocol import ServerReplicationStreamProtocol
@@ -117,7 +118,6 @@ class ReplicationStreamer(object):
         for conn in self.connections:
             conn.send_error("server shutting down")
 
-    @defer.inlineCallbacks
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -132,14 +132,16 @@ class ReplicationStreamer(object):
                 stream.discard_updates_and_advance()
             return
 
-        # If we're in the process of checking for new updates, mark that fact
-        # and return
+        self.pending_updates = True
+
         if self.is_looping:
-            logger.debug("Noitifier poke loop already running")
-            self.pending_updates = True
+            logger.debug("Notifier poke loop already running")
             return
 
-        self.pending_updates = True
+        run_as_background_process("replication_notifier", self._run_notifier_loop)
+
+    @defer.inlineCallbacks
+    def _run_notifier_loop(self):
         self.is_looping = True
 
         try:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 30242c525a..5b13378caa 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -35,6 +35,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.async import Linearizer
 from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util.retryutils import NotRetryingDestination
@@ -100,10 +101,15 @@ class MediaRepository(object):
         )
 
         self.clock.looping_call(
-            self._update_recently_accessed,
+            self._start_update_recently_accessed,
             UPDATE_RECENTLY_ACCESSED_TS,
         )
 
+    def _start_update_recently_accessed(self):
+        run_as_background_process(
+            "update_recently_accessed_media", self._update_recently_accessed,
+        )
+
     @defer.inlineCallbacks
     def _update_recently_accessed(self):
         remote_media = self.recently_accessed_remotes
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b70b15c4c2..4efd5339a4 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -41,6 +41,7 @@ from synapse.http.server import (
     wrap_json_request_handler,
 )
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.async import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@@ -81,7 +82,7 @@ class PreviewUrlResource(Resource):
         self._cache.start()
 
         self._cleaner_loop = self.clock.looping_call(
-            self._expire_url_cache_data, 10 * 1000
+            self._start_expire_url_cache_data, 10 * 1000,
         )
 
     def render_OPTIONS(self, request):
@@ -371,6 +372,11 @@ class PreviewUrlResource(Resource):
             "etag": headers["ETag"][0] if "ETag" in headers else None,
         })
 
+    def _start_expire_url_cache_data(self):
+        run_as_background_process(
+            "expire_url_cache_data", self._expire_url_cache_data,
+        )
+
     @defer.inlineCallbacks
     def _expire_url_cache_data(self):
         """Clean up expired url cache content, media and thumbnails.
diff --git a/synapse/state.py b/synapse/state.py
index 32125c95df..033f55d967 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -471,69 +471,39 @@ class StateResolutionHandler(object):
                 "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
             )
 
-            # build a map from state key to the event_ids which set that state.
-            # dict[(str, str), set[str])
-            state = {}
+            # start by assuming we won't have any conflicted state, and build up the new
+            # state map by iterating through the state groups. If we discover a conflict,
+            # we give up and instead use `resolve_events_with_factory`.
+            #
+            # XXX: is this actually worthwhile, or should we just let
+            # resolve_events_with_factory do it?
+            new_state = {}
+            conflicted_state = False
             for st in itervalues(state_groups_ids):
                 for key, e_id in iteritems(st):
-                    state.setdefault(key, set()).add(e_id)
-
-            # build a map from state key to the event_ids which set that state,
-            # including only those where there are state keys in conflict.
-            conflicted_state = {
-                k: list(v)
-                for k, v in iteritems(state)
-                if len(v) > 1
-            }
+                    if key in new_state:
+                        conflicted_state = True
+                        break
+                    new_state[key] = e_id
+                if conflicted_state:
+                    break
 
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
-                        list(state_groups_ids.values()),
+                        list(itervalues(state_groups_ids)),
                         event_map=event_map,
                         state_map_factory=state_map_factory,
                     )
-            else:
-                new_state = {
-                    key: e_ids.pop() for key, e_ids in iteritems(state)
-                }
 
-            with Measure(self.clock, "state.create_group_ids"):
-                # if the new state matches any of the input state groups, we can
-                # use that state group again. Otherwise we will generate a state_id
-                # which will be used as a cache key for future resolutions, but
-                # not get persisted.
-                state_group = None
-                new_state_event_ids = frozenset(itervalues(new_state))
-                for sg, events in iteritems(state_groups_ids):
-                    if new_state_event_ids == frozenset(e_id for e_id in events):
-                        state_group = sg
-                        break
+            # if the new state matches any of the input state groups, we can
+            # use that state group again. Otherwise we will generate a state_id
+            # which will be used as a cache key for future resolutions, but
+            # not get persisted.
 
-                # TODO: We want to create a state group for this set of events, to
-                # increase cache hits, but we need to make sure that it doesn't
-                # end up as a prev_group without being added to the database
-
-                prev_group = None
-                delta_ids = None
-                for old_group, old_ids in iteritems(state_groups_ids):
-                    if not set(new_state) - set(old_ids):
-                        n_delta_ids = {
-                            k: v
-                            for k, v in iteritems(new_state)
-                            if old_ids.get(k) != v
-                        }
-                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
-                            prev_group = old_group
-                            delta_ids = n_delta_ids
-
-            cache = _StateCacheEntry(
-                state=new_state,
-                state_group=state_group,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-            )
+            with Measure(self.clock, "state.create_group_ids"):
+                cache = _make_state_cache_entry(new_state, state_groups_ids)
 
             if self._state_cache is not None:
                 self._state_cache[group_names] = cache
@@ -541,6 +511,70 @@ class StateResolutionHandler(object):
             defer.returnValue(cache)
 
 
+def _make_state_cache_entry(
+    new_state,
+    state_groups_ids,
+):
+    """Given a resolved state, and a set of input state groups, pick one to base
+    a new state group on (if any), and return an appropriately-constructed
+    _StateCacheEntry.
+
+    Args:
+        new_state (dict[(str, str), str]): resolved state map (mapping from
+           (type, state_key) to event_id)
+
+        state_groups_ids (dict[int, dict[(str, str), str]]):
+                 map from state group id to the state in that state group
+                (where 'state' is a map from state key to event id)
+
+    Returns:
+        _StateCacheEntry
+    """
+    # if the new state matches any of the input state groups, we can
+    # use that state group again. Otherwise we will generate a state_id
+    # which will be used as a cache key for future resolutions, but
+    # not get persisted.
+
+    # first look for exact matches
+    new_state_event_ids = set(itervalues(new_state))
+    for sg, state in iteritems(state_groups_ids):
+        if len(new_state_event_ids) != len(state):
+            continue
+
+        old_state_event_ids = set(itervalues(state))
+        if new_state_event_ids == old_state_event_ids:
+            # got an exact match.
+            return _StateCacheEntry(
+                state=new_state,
+                state_group=sg,
+            )
+
+    # TODO: We want to create a state group for this set of events, to
+    # increase cache hits, but we need to make sure that it doesn't
+    # end up as a prev_group without being added to the database
+
+    # failing that, look for the closest match.
+    prev_group = None
+    delta_ids = None
+
+    for old_group, old_state in iteritems(state_groups_ids):
+        n_delta_ids = {
+            k: v
+            for k, v in iteritems(new_state)
+            if old_state.get(k) != v
+        }
+        if not delta_ids or len(n_delta_ids) < len(delta_ids):
+            prev_group = old_group
+            delta_ids = n_delta_ids
+
+    return _StateCacheEntry(
+        state=new_state,
+        state_group=None,
+        prev_group=prev_group,
+        delta_ids=delta_ids,
+    )
+
+
 def _ordered_events(events):
     def key_func(e):
         return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
@@ -582,7 +616,7 @@ def _seperate(state_sets):
     with them in different state sets.
 
     Args:
-        state_sets(list[dict[(str, str), str]]):
+        state_sets(iterable[dict[(str, str), str]]):
             List of dicts of (type, state_key) -> event_id, which are the
             different state groups to resolve.
 
@@ -596,10 +630,11 @@ def _seperate(state_sets):
             conflicted_state is a dict mapping (type, state_key) to a set of
             event ids for conflicted state keys.
     """
-    unconflicted_state = dict(state_sets[0])
+    state_set_iterator = iter(state_sets)
+    unconflicted_state = dict(next(state_set_iterator))
     conflicted_state = {}
 
-    for state_set in state_sets[1:]:
+    for state_set in state_set_iterator:
         for key, value in iteritems(state_set):
             # Check if there is an unconflicted entry for the state key.
             unconflicted_value = unconflicted_state.get(key)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1d41d8d445..44f37b4c1e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -311,6 +311,12 @@ class SQLBaseStore(object):
         after_callbacks = []
         exception_callbacks = []
 
+        if LoggingContext.current_context() == LoggingContext.sentinel:
+            logger.warn(
+                "Starting db txn '%s' from sentinel context",
+                desc,
+            )
+
         try:
             result = yield self.runWithConnection(
                 self._new_transaction,
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index cc3cdf2ebc..52dccb1507 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -21,6 +21,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 from ._base import Cache, SQLBaseStore
@@ -711,6 +712,9 @@ class DeviceStore(SQLBaseStore):
 
             logger.info("Pruned %d device list outbound pokes", txn.rowcount)
 
-        return self.runInteraction(
-            "_prune_old_outbound_device_pokes", _prune_txn
+        run_as_background_process(
+            "prune_old_outbound_device_pokes",
+            self.runInteraction,
+            "_prune_old_outbound_device_pokes",
+            _prune_txn,
         )
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 8d366d1b91..65f2d19e20 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -23,6 +23,7 @@ from unpaddedbase64 import encode_base64
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.events import EventsWorkerStore
 from synapse.storage.signatures import SignatureWorkerStore
@@ -446,7 +447,7 @@ class EventFederationStore(EventFederationWorkerStore):
         )
 
         hs.get_clock().looping_call(
-            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+            self._delete_old_forward_extrem_cache, 60 * 60 * 1000,
         )
 
     def _update_min_depth_for_room_txn(self, txn, room_id, depth):
@@ -548,9 +549,11 @@ class EventFederationStore(EventFederationWorkerStore):
                 sql,
                 (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
             )
-        return self.runInteraction(
+        run_as_background_process(
+            "delete_old_forward_extrem_cache",
+            self.runInteraction,
             "_delete_old_forward_extrem_cache",
-            _delete_old_forward_extrem_cache_txn
+            _delete_old_forward_extrem_cache_txn,
         )
 
     def clean_room_for_join(self, room_id):
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 29b511ae5e..4f44b0ad47 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
@@ -458,11 +459,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "Error removing push actions after event persistence failure",
             )
 
-    @defer.inlineCallbacks
     def _find_stream_orderings_for_times(self):
-        yield self.runInteraction(
+        run_as_background_process(
+            "event_push_action_stream_orderings",
+            self.runInteraction,
             "_find_stream_orderings_for_times",
-            self._find_stream_orderings_for_times_txn
+            self._find_stream_orderings_for_times_txn,
         )
 
     def _find_stream_orderings_for_times_txn(self, txn):
@@ -604,7 +606,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
 
         self._doing_notif_rotation = False
         self._rotate_notif_loop = self._clock.looping_call(
-            self._rotate_notifs, 30 * 60 * 1000
+            self._start_rotate_notifs, 30 * 60 * 1000,
         )
 
     def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
@@ -787,6 +789,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
         """, (room_id, user_id, stream_ordering))
 
+    def _start_rotate_notifs(self):
+        run_as_background_process("rotate_notifs", self._rotate_notifs)
+
     @defer.inlineCallbacks
     def _rotate_notifs(self):
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 19c05dc8d6..200f5ec95f 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,7 @@ import logging
 from collections import OrderedDict, deque, namedtuple
 from functools import wraps
 
-from six import iteritems, itervalues
+from six import iteritems
 from six.moves import range
 
 from canonicaljson import json
@@ -142,15 +142,14 @@ class _EventPeristenceQueue(object):
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
-                    # handle_queue_loop runs in the sentinel logcontext, so
-                    # there is no need to preserve_fn when running the
-                    # callbacks on the deferred.
                     try:
                         ret = yield per_item_callback(item)
+                    except Exception:
+                        with PreserveLoggingContext():
+                            item.deferred.errback()
+                    else:
                         with PreserveLoggingContext():
                             item.deferred.callback(ret)
-                    except Exception:
-                        item.deferred.errback()
             finally:
                 queue = self._event_persist_queues.pop(room_id, None)
                 if queue:
@@ -344,11 +343,14 @@ class EventsStore(EventsWorkerStore):
                 new_forward_extremeties = {}
 
                 # map room_id->(type,state_key)->event_id tracking the full
-                # state in each room after adding these events
+                # state in each room after adding these events.
+                # This is simply used to prefill the get_current_state_ids
+                # cache
                 current_state_for_room = {}
 
-                # map room_id->(to_delete, to_insert) where each entry is
-                # a map (type,key)->event_id giving the state delta in each
+                # map room_id->(to_delete, to_insert) where to_delete is a list
+                # of type/state keys to remove from current state, and to_insert
+                # is a map (type,key)->event_id giving the state delta in each
                 # room
                 state_delta_for_room = {}
 
@@ -418,28 +420,40 @@ class EventsStore(EventsWorkerStore):
                             logger.info(
                                 "Calculating state delta for room %s", room_id,
                             )
-
                             with Measure(
-                                    self._clock,
-                                    "persist_events.get_new_state_after_events",
+                                self._clock,
+                                "persist_events.get_new_state_after_events",
                             ):
-                                current_state = yield self._get_new_state_after_events(
+                                res = yield self._get_new_state_after_events(
                                     room_id,
                                     ev_ctx_rm,
                                     latest_event_ids,
                                     new_latest_event_ids,
                                 )
-
-                            if current_state is not None:
-                                current_state_for_room[room_id] = current_state
+                                current_state, delta_ids = res
+
+                            # If either are not None then there has been a change,
+                            # and we need to work out the delta (or use that
+                            # given)
+                            if delta_ids is not None:
+                                # If there is a delta we know that we've
+                                # only added or replaced state, never
+                                # removed keys entirely.
+                                state_delta_for_room[room_id] = ([], delta_ids)
+                            elif current_state is not None:
                                 with Measure(
-                                        self._clock,
-                                        "persist_events.calculate_state_delta",
+                                    self._clock,
+                                    "persist_events.calculate_state_delta",
                                 ):
                                     delta = yield self._calculate_state_delta(
                                         room_id, current_state,
                                     )
-                                    state_delta_for_room[room_id] = delta
+                                state_delta_for_room[room_id] = delta
+
+                            # If we have the current_state then lets prefill
+                            # the cache with it.
+                            if current_state is not None:
+                                current_state_for_room[room_id] = current_state
 
                 yield self.runInteraction(
                     "persist_events",
@@ -538,9 +552,15 @@ class EventsStore(EventsWorkerStore):
                 the new forward extremities for the room.
 
         Returns:
-            Deferred[dict[(str,str), str]|None]:
-                None if there are no changes to the room state, or
-                a dict of (type, state_key) -> event_id].
+            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+            Returns a tuple of two state maps, the first being the full new current
+            state and the second being the delta to the existing current state.
+            If both are None then there has been no change.
+
+            If there has been a change then we only return the delta if its
+            already been calculated. Conversely if we do know the delta then
+            the new current state is only returned if we've already calculated
+            it.
         """
 
         if not new_latest_event_ids:
@@ -548,13 +568,19 @@ class EventsStore(EventsWorkerStore):
 
         # map from state_group to ((type, key) -> event_id) state map
         state_groups_map = {}
+
+        # Map from (prev state group, new state group) -> delta state dict
+        state_group_deltas = {}
+
         for ev, ctx in events_context:
             if ctx.state_group is None:
-                # I don't think this can happen, but let's double-check
-                raise Exception(
-                    "Context for new extremity event %s has no state "
-                    "group" % (ev.event_id, ),
-                )
+                # This should only happen for outlier events.
+                if not ev.internal_metadata.is_outlier():
+                    raise Exception(
+                        "Context for new event %s has no state "
+                        "group" % (ev.event_id, ),
+                    )
+                continue
 
             if ctx.state_group in state_groups_map:
                 continue
@@ -566,6 +592,9 @@ class EventsStore(EventsWorkerStore):
             if current_state_ids is not None:
                 state_groups_map[ctx.state_group] = current_state_ids
 
+            if ctx.prev_group:
+                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
         # We need to map the event_ids to their state groups. First, let's
         # check if the event is one we're persisting, in which case we can
         # pull the state group from its context.
@@ -579,7 +608,7 @@ class EventsStore(EventsWorkerStore):
         for event_id in new_latest_event_ids:
             # First search in the list of new events we're adding.
             for ev, ctx in events_context:
-                if event_id == ev.event_id:
+                if event_id == ev.event_id and ctx.state_group is not None:
                     event_id_to_state_group[event_id] = ctx.state_group
                     break
             else:
@@ -607,7 +636,26 @@ class EventsStore(EventsWorkerStore):
         # If they old and new groups are the same then we don't need to do
         # anything.
         if old_state_groups == new_state_groups:
-            return
+            defer.returnValue((None, None))
+
+        if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+            # If we're going from one state group to another, lets check if
+            # we have a delta for that transition. If we do then we can just
+            # return that.
+
+            new_state_group = next(iter(new_state_groups))
+            old_state_group = next(iter(old_state_groups))
+
+            delta_ids = state_group_deltas.get(
+                (old_state_group, new_state_group,), None
+            )
+            if delta_ids is not None:
+                # We have a delta from the existing to new current state,
+                # so lets just return that. If we happen to already have
+                # the current state in memory then lets also return that,
+                # but it doesn't matter if we don't.
+                new_state = state_groups_map.get(new_state_group)
+                defer.returnValue((new_state, delta_ids))
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
@@ -619,7 +667,7 @@ class EventsStore(EventsWorkerStore):
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            defer.returnValue(state_groups_map[new_state_groups.pop()])
+            defer.returnValue((state_groups_map[new_state_groups.pop()], None))
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
@@ -638,7 +686,7 @@ class EventsStore(EventsWorkerStore):
             room_id, state_groups, events_map, get_events
         )
 
-        defer.returnValue(res.state)
+        defer.returnValue((res.state, None))
 
     @defer.inlineCallbacks
     def _calculate_state_delta(self, room_id, current_state):
@@ -647,17 +695,16 @@ class EventsStore(EventsWorkerStore):
         Assumes that we are only persisting events for one room at a time.
 
         Returns:
-            2-tuple (to_delete, to_insert) where both are state dicts,
-            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
-            first be deleted from current_state_events, `to_insert` are entries
-            to insert.
+            tuple[list, dict] (to_delete, to_insert): where to_delete are the
+            type/state_keys to remove from current_state_events and `to_insert`
+            are the updates to current_state_events.
         """
         existing_state = yield self.get_current_state_ids(room_id)
 
-        to_delete = {
-            key: ev_id for key, ev_id in iteritems(existing_state)
-            if ev_id != current_state.get(key)
-        }
+        to_delete = [
+            key for key in existing_state
+            if key not in current_state
+        ]
 
         to_insert = {
             key: ev_id for key, ev_id in iteritems(current_state)
@@ -684,10 +731,10 @@ class EventsStore(EventsWorkerStore):
             delete_existing (bool): True to purge existing table rows for the
                 events from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room (dict[str, (list[str], list[str])]):
+            state_delta_for_room (dict[str, (list, dict)]):
                 The current-state delta for each room. For each room, a tuple
-                (to_delete, to_insert), being a list of event ids to be removed
-                from the current state, and a list of event ids to be added to
+                (to_delete, to_insert), being a list of type/state keys to be
+                removed from the current state, and a state set to be added to
                 the current state.
             new_forward_extremeties (dict[str, list[str]]):
                 The new forward extremities for each room. For each room, a
@@ -765,9 +812,46 @@ class EventsStore(EventsWorkerStore):
     def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
         for room_id, current_state_tuple in iteritems(state_delta_by_room):
                 to_delete, to_insert = current_state_tuple
+
+                # First we add entries to the current_state_delta_stream. We
+                # do this before updating the current_state_events table so
+                # that we can use it to calculate the `prev_event_id`. (This
+                # allows us to not have to pull out the existing state
+                # unnecessarily).
+                sql = """
+                    INSERT INTO current_state_delta_stream
+                    (stream_id, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, ?, ?, ?, ?, (
+                        SELECT event_id FROM current_state_events
+                        WHERE room_id = ? AND type = ? AND state_key = ?
+                    )
+                """
+                txn.executemany(sql, (
+                    (
+                        max_stream_order, room_id, etype, state_key, None,
+                        room_id, etype, state_key,
+                    )
+                    for etype, state_key in to_delete
+                    # We sanity check that we're deleting rather than updating
+                    if (etype, state_key) not in to_insert
+                ))
+                txn.executemany(sql, (
+                    (
+                        max_stream_order, room_id, etype, state_key, ev_id,
+                        room_id, etype, state_key,
+                    )
+                    for (etype, state_key), ev_id in iteritems(to_insert)
+                ))
+
+                # Now we actually update the current_state_events table
+
                 txn.executemany(
-                    "DELETE FROM current_state_events WHERE event_id = ?",
-                    [(ev_id,) for ev_id in itervalues(to_delete)],
+                    "DELETE FROM current_state_events"
+                    " WHERE room_id = ? AND type = ? AND state_key = ?",
+                    (
+                        (room_id, etype, state_key)
+                        for etype, state_key in itertools.chain(to_delete, to_insert)
+                    ),
                 )
 
                 self._simple_insert_many_txn(
@@ -784,25 +868,6 @@ class EventsStore(EventsWorkerStore):
                     ],
                 )
 
-                state_deltas = {key: None for key in to_delete}
-                state_deltas.update(to_insert)
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="current_state_delta_stream",
-                    values=[
-                        {
-                            "stream_id": max_stream_order,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": ev_id,
-                            "prev_event_id": to_delete.get(key, None),
-                        }
-                        for key, ev_id in iteritems(state_deltas)
-                    ]
-                )
-
                 txn.call_after(
                     self._curr_state_delta_stream_cache.entity_has_changed,
                     room_id, max_stream_order,
@@ -816,7 +881,8 @@ class EventsStore(EventsWorkerStore):
                 # and which we have added, then we invlidate the caches for all
                 # those users.
                 members_changed = set(
-                    state_key for ev_type, state_key in state_deltas
+                    state_key
+                    for ev_type, state_key in itertools.chain(to_delete, to_insert)
                     if ev_type == EventTypes.Member
                 )
 
@@ -1072,7 +1138,7 @@ class EventsStore(EventsWorkerStore):
         ):
             txn.executemany(
                 "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
-                [(ev.event_id,) for ev, _ in events_and_contexts]
+                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts]
             )
 
     def _store_event_txn(self, txn, events_and_contexts):
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index af564b1b4e..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -21,7 +21,6 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes
 from synapse.push.baserules import list_with_base_rules
 from synapse.storage.appservice import ApplicationServiceWorkerStore
 from synapse.storage.pusher import PusherWorkerStore
@@ -250,18 +249,6 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
             if uid in local_users_in_room:
                 user_ids.add(uid)
 
-        forgotten = yield self.who_forgot_in_room(
-            event.room_id, on_invalidate=cache_context.invalidate,
-        )
-
-        for row in forgotten:
-            user_id = row["user_id"]
-            event_id = row["event_id"]
-
-            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
-            if event_id == mem_id:
-                user_ids.discard(user_id)
-
         rules_by_user = yield self.bulk_get_push_rules(
             user_ids, on_invalidate=cache_context.invalidate,
         )
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index a27702a7a0..027bf8c85e 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -461,17 +461,29 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     def _get_joined_hosts_cache(self, room_id):
         return _JoinedHostsCache(self, room_id)
 
-    @cached()
-    def who_forgot_in_room(self, room_id):
-        return self._simple_select_list(
-            table="room_memberships",
-            retcols=("user_id", "event_id"),
-            keyvalues={
-                "room_id": room_id,
-                "forgotten": 1,
-            },
-            desc="who_forgot"
-        )
+    @cachedInlineCallbacks(num_args=2)
+    def did_forget(self, user_id, room_id):
+        """Returns whether user_id has elected to discard history for room_id.
+
+        Returns False if they have since re-joined."""
+        def f(txn):
+            sql = (
+                "SELECT"
+                "  COUNT(*)"
+                " FROM"
+                "  room_memberships"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+                " AND"
+                "  forgotten = 0"
+            )
+            txn.execute(sql, (user_id, room_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+        count = yield self.runInteraction("did_forget_membership", f)
+        defer.returnValue(count == 0)
 
 
 class RoomMemberStore(RoomMemberWorkerStore):
@@ -580,36 +592,11 @@ class RoomMemberStore(RoomMemberWorkerStore):
             )
             txn.execute(sql, (user_id, room_id))
 
-            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
             self._invalidate_cache_and_stream(
-                txn, self.who_forgot_in_room, (room_id,)
+                txn, self.did_forget, (user_id, room_id,),
             )
         return self.runInteraction("forget_membership", f)
 
-    @cachedInlineCallbacks(num_args=2)
-    def did_forget(self, user_id, room_id):
-        """Returns whether user_id has elected to discard history for room_id.
-
-        Returns False if they have since re-joined."""
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  COUNT(*)"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  forgotten = 0"
-            )
-            txn.execute(sql, (user_id, room_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        count = yield self.runInteraction("did_forget_membership", f)
-        defer.returnValue(count == 0)
-
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index c3bc94f56d..b4b479d94c 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -22,6 +22,7 @@ from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches.descriptors import cached
 
 from ._base import SQLBaseStore
@@ -57,7 +58,7 @@ class TransactionStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
         super(TransactionStore, self).__init__(db_conn, hs)
 
-        self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+        self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
     def get_received_txn_response(self, transaction_id, origin):
         """For an incoming transaction from a given origin, check if we have
@@ -271,6 +272,9 @@ class TransactionStore(SQLBaseStore):
         txn.execute(query, (self._clock.time_msec(),))
         return self.cursor_to_dict(txn)
 
+    def _start_cleanup_transactions(self):
+        run_as_background_process("cleanup_transactions", self._cleanup_transactions)
+
     def _cleanup_transactions(self):
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
diff --git a/synapse/visibility.py b/synapse/visibility.py
index ba0499a022..d4680863d3 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -24,7 +24,6 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.utils import prune_event
 from synapse.types import get_domain_from_id
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 
 logger = logging.getLogger(__name__)
 
@@ -76,19 +75,6 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
         types=types,
     )
 
-    forgotten = yield make_deferred_yieldable(defer.gatherResults([
-        defer.maybeDeferred(
-            preserve_fn(store.who_forgot_in_room),
-            room_id,
-        )
-        for room_id in frozenset(e.room_id for e in events)
-    ], consumeErrors=True))
-
-    # Set of membership event_ids that have been forgotten
-    event_id_forgotten = frozenset(
-        row["event_id"] for rows in forgotten for row in rows
-    )
-
     ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
         "m.ignored_user_list", user_id,
     )
@@ -177,10 +163,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
         if membership is None:
             membership_event = state.get((EventTypes.Member, user_id), None)
             if membership_event:
-                # XXX why do we do this?
-                # https://github.com/matrix-org/synapse/issues/3350
-                if membership_event.event_id not in event_id_forgotten:
-                    membership = membership_event.membership
+                membership = membership_event.membership
 
         # if the user was a member of the room at the time of the event,
         # they can see it.
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 8708c8a196..a103e7be80 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -11,23 +11,44 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import tempfile
 
 from mock import Mock, NonCallableMock
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
 
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
 
+class TestReplicationClientHandler(ReplicationClientHandler):
+    """Overrides on_rdata so that we can wait for it to happen"""
+    def __init__(self, store):
+        super(TestReplicationClientHandler, self).__init__(store)
+        self._rdata_awaiters = []
+
+    def await_replication(self):
+        d = Deferred()
+        self._rdata_awaiters.append(d)
+        return make_deferred_yieldable(d)
+
+    def on_rdata(self, stream_name, token, rows):
+        awaiters = self._rdata_awaiters
+        self._rdata_awaiters = []
+        super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
+        with PreserveLoggingContext():
+            for a in awaiters:
+                a.callback(None)
+
+
 class BaseSlavedStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
@@ -52,7 +73,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.addCleanup(listener.stopListening)
         self.streamer = server_factory.streamer
 
-        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        self.replication_handler = TestReplicationClientHandler(self.slaved_store)
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
@@ -60,12 +81,14 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.addCleanup(client_factory.stopTrying)
         self.addCleanup(client_connector.disconnect)
 
-    @defer.inlineCallbacks
     def replicate(self):
-        yield self.streamer.on_notifier_poke()
-        d = self.replication_handler.await_sync("replication_test")
-        self.streamer.send_sync_to_all_connections("replication_test")
-        yield d
+        """Tell the master side of replication that something has happened, and then
+        wait for the replication to occur.
+        """
+        # xxx: should we be more specific in what we wait for?
+        d = self.replication_handler.await_replication()
+        self.streamer.on_notifier_poke()
+        return d
 
     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):