summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/api/filtering.py261
-rw-r--r--synapse/app/appservice.py8
-rw-r--r--synapse/app/client_reader.py10
-rw-r--r--synapse/app/federation_reader.py8
-rw-r--r--synapse/app/federation_sender.py8
-rwxr-xr-xsynapse/app/homeserver.py9
-rw-r--r--synapse/app/media_repository.py10
-rw-r--r--synapse/app/pusher.py9
-rw-r--r--synapse/app/synchrotron.py21
-rwxr-xr-xsynapse/app/synctl.py47
-rw-r--r--synapse/appservice/__init__.py38
-rw-r--r--synapse/crypto/keyring.py71
-rw-r--r--synapse/events/snapshot.py26
-rw-r--r--synapse/federation/federation_client.py40
-rw-r--r--synapse/federation/send_queue.py7
-rw-r--r--synapse/federation/transaction_queue.py223
-rw-r--r--synapse/federation/transport/client.py7
-rw-r--r--synapse/handlers/directory.py1
-rw-r--r--synapse/handlers/e2e_keys.py34
-rw-r--r--synapse/handlers/federation.py79
-rw-r--r--synapse/handlers/presence.py3
-rw-r--r--synapse/handlers/profile.py6
-rw-r--r--synapse/http/matrixfederationclient.py314
-rw-r--r--synapse/push/push_rule_evaluator.py104
-rw-r--r--synapse/push/push_tools.py7
-rw-r--r--synapse/python_dependencies.py1
-rw-r--r--synapse/replication/resource.py4
-rw-r--r--synapse/replication/slave/storage/events.py45
-rw-r--r--synapse/replication/slave/storage/presence.py1
-rw-r--r--synapse/rest/client/v2_alpha/account.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/state.py17
-rw-r--r--synapse/storage/_base.py25
-rw-r--r--synapse/storage/account_data.py4
-rw-r--r--synapse/storage/background_updates.py15
-rw-r--r--synapse/storage/deviceinbox.py10
-rw-r--r--synapse/storage/devices.py11
-rw-r--r--synapse/storage/end_to_end_keys.py63
-rw-r--r--synapse/storage/event_federation.py35
-rw-r--r--synapse/storage/event_push_actions.py2
-rw-r--r--synapse/storage/events.py407
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/presence.py4
-rw-r--r--synapse/storage/receipts.py5
-rw-r--r--synapse/storage/registration.py2
-rw-r--r--synapse/storage/room.py4
-rw-r--r--synapse/storage/roommember.py53
-rw-r--r--synapse/storage/signatures.py2
-rw-r--r--synapse/storage/state.py78
-rw-r--r--synapse/storage/tags.py4
-rw-r--r--synapse/types.py4
-rw-r--r--synapse/util/__init__.py10
-rw-r--r--synapse/util/async.py7
-rw-r--r--synapse/util/caches/descriptors.py205
-rw-r--r--synapse/util/logcontext.py77
-rw-r--r--synapse/util/retryutils.py34
-rw-r--r--synapse/visibility.py10
60 files changed, 1543 insertions, 965 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index ff251ce597..7628e7c505 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.19.2"
+__version__ = "0.19.3"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 03a215ab1b..9dbc7993df 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes
 from synapse.types import UserID
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util import logcontext
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -209,8 +209,7 @@ class Auth(object):
                 default=[""]
             )[0]
             if user and access_token and ip_addr:
-                preserve_context_over_fn(
-                    self.store.insert_client_ip,
+                logcontext.preserve_fn(self.store.insert_client_ip)(
                     user=user,
                     access_token=access_token,
                     ip=ip_addr,
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 47f0cf0fa9..83206348e5 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,10 +15,172 @@
 from synapse.api.errors import SynapseError
 from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, RoomID
-
 from twisted.internet import defer
 
 import ujson as json
+import jsonschema
+from jsonschema import FormatChecker
+
+FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        # TODO: We don't limit event type values but we probably should...
+        # check types are valid event types
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        }
+    }
+}
+
+ROOM_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "ephemeral": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "include_leave": {
+            "type": "boolean"
+        },
+        "state": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "timeline": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+    }
+}
+
+ROOM_EVENT_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "contains_url": {
+            "type": "boolean"
+        }
+    }
+}
+
+USER_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_user_id"
+    }
+}
+
+ROOM_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_room_id"
+    }
+}
+
+USER_FILTER_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-04/schema#",
+    "description": "schema for a Sync filter",
+    "type": "object",
+    "definitions": {
+        "room_id_array": ROOM_ID_ARRAY_SCHEMA,
+        "user_id_array": USER_ID_ARRAY_SCHEMA,
+        "filter": FILTER_SCHEMA,
+        "room_filter": ROOM_FILTER_SCHEMA,
+        "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+    },
+    "properties": {
+        "presence": {
+            "$ref": "#/definitions/filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/filter"
+        },
+        "room": {
+            "$ref": "#/definitions/room_filter"
+        },
+        "event_format": {
+            "type": "string",
+            "enum": ["client", "federation"]
+        },
+        "event_fields": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                # Don't allow '\\' in event field filters. This makes matching
+                # events a lot easier as we can then use a negative lookbehind
+                # assertion to split '\.' If we allowed \\ then it would
+                # incorrectly split '\\.' See synapse.events.utils.serialize_event
+                "pattern": "^((?!\\\).)*$"
+            }
+        }
+    },
+    "additionalProperties": False
+}
+
+
+@FormatChecker.cls_checks('matrix_room_id')
+def matrix_room_id_validator(room_id_str):
+    return RoomID.from_string(room_id_str)
+
+
+@FormatChecker.cls_checks('matrix_user_id')
+def matrix_user_id_validator(user_id_str):
+    return UserID.from_string(user_id_str)
 
 
 class Filtering(object):
@@ -53,98 +215,11 @@ class Filtering(object):
         # NB: Filters are the complete json blobs. "Definitions" are an
         # individual top-level key e.g. public_user_data. Filters are made of
         # many definitions.
-
-        top_level_definitions = [
-            "presence", "account_data"
-        ]
-
-        room_level_definitions = [
-            "state", "timeline", "ephemeral", "account_data"
-        ]
-
-        for key in top_level_definitions:
-            if key in user_filter_json:
-                self._check_definition(user_filter_json[key])
-
-        if "room" in user_filter_json:
-            self._check_definition_room_lists(user_filter_json["room"])
-            for key in room_level_definitions:
-                if key in user_filter_json["room"]:
-                    self._check_definition(user_filter_json["room"][key])
-
-        if "event_fields" in user_filter_json:
-            if type(user_filter_json["event_fields"]) != list:
-                raise SynapseError(400, "event_fields must be a list of strings")
-            for field in user_filter_json["event_fields"]:
-                if not isinstance(field, basestring):
-                    raise SynapseError(400, "Event field must be a string")
-                # Don't allow '\\' in event field filters. This makes matching
-                # events a lot easier as we can then use a negative lookbehind
-                # assertion to split '\.' If we allowed \\ then it would
-                # incorrectly split '\\.' See synapse.events.utils.serialize_event
-                if r'\\' in field:
-                    raise SynapseError(
-                        400, r'The escape character \ cannot itself be escaped'
-                    )
-
-    def _check_definition_room_lists(self, definition):
-        """Check that "rooms" and "not_rooms" are lists of room ids if they
-        are present
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # check rooms are valid room IDs
-        room_id_keys = ["rooms", "not_rooms"]
-        for key in room_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for room_id in definition[key]:
-                    RoomID.from_string(room_id)
-
-    def _check_definition(self, definition):
-        """Check if the provided definition is valid.
-
-        This inspects not only the types but also the values to make sure they
-        make sense.
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # NB: Filters are the complete json blobs. "Definitions" are an
-        # individual top-level key e.g. public_user_data. Filters are made of
-        # many definitions.
-        if type(definition) != dict:
-            raise SynapseError(
-                400, "Expected JSON object, not %s" % (definition,)
-            )
-
-        self._check_definition_room_lists(definition)
-
-        # check senders are valid user IDs
-        user_id_keys = ["senders", "not_senders"]
-        for key in user_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for user_id in definition[key]:
-                    UserID.from_string(user_id)
-
-        # TODO: We don't limit event type values but we probably should...
-        # check types are valid event types
-        event_keys = ["types", "not_types"]
-        for key in event_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for event_type in definition[key]:
-                    if not isinstance(event_type, basestring):
-                        raise SynapseError(400, "Event type should be a string")
+        try:
+            jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
+                                format_checker=FormatChecker())
+        except jsonschema.ValidationError as e:
+            raise SynapseError(400, e.message)
 
 
 class FilterCollection(object):
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 83ee3e3ce3..a6f1e7594e 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -29,7 +29,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.storage.engines import create_engine
 from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -187,7 +187,11 @@ def start(config_options):
     ps.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 7ed0de4117..e4ea3ab933 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -29,13 +29,14 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.rest.client.v1.room import PublicRoomListRestServlet
 from synapse.server import HomeServer
 from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
 from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
     DirectoryStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
+    TransactionStore,
     BaseSlavedStore,
     ClientIpStore,  # After BaseSlavedStore because the constructor is different
 ):
@@ -193,7 +195,11 @@ def start(config_options):
     ss.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index ca742de6b2..e52b0f240d 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -31,7 +31,7 @@ from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
 from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -184,7 +184,11 @@ def start(config_options):
     ss.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 0cf5b196e6..76c4cc54d1 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -35,7 +35,7 @@ from synapse.storage.engines import create_engine
 from synapse.storage.presence import UserPresenceState
 from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -193,7 +193,11 @@ def start(config_options):
     ps.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0b9d78c13c..2cdd2d39ff 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -52,7 +52,7 @@ from synapse.api.urls import (
 )
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.metrics import register_memory_metrics, get_metrics_for
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
@@ -456,7 +456,12 @@ def run(hs):
     def in_thread():
         # Uncomment to enable tracing of log context changes.
         # sys.settrace(logcontext_tracer)
-        with LoggingContext("run"):
+
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             change_resource_limit(hs.config.soft_file_limit)
             if hs.config.gc_thresholds:
                 gc.set_threshold(*hs.config.gc_thresholds)
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index c5579d9e38..1444e69a42 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.rest.media.v1.media_repository import MediaRepositoryResource
 from synapse.server import HomeServer
@@ -32,7 +33,7 @@ from synapse.storage.engines import create_engine
 from synapse.storage.media_repository import MediaRepositoryStore
 from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
 class MediaRepositorySlavedStore(
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
+    TransactionStore,
     BaseSlavedStore,
     MediaRepositoryStore,
     ClientIpStore,
@@ -190,7 +192,11 @@ def start(config_options):
     ss.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index b025db54d4..ab682e52ec 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -31,7 +31,8 @@ from synapse.storage.engines import create_engine
 from synapse.storage import DataStore
 from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn, \
+    PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
@@ -275,7 +276,11 @@ def start(config_options):
     ps.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 449fac771b..34e34e5580 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -20,7 +20,6 @@ from synapse.api.constants import EventTypes, PresenceState
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
-from synapse.events import FrozenEvent
 from synapse.handlers.presence import PresenceHandler
 from synapse.http.site import SynapseSite
 from synapse.http.server import JsonResource
@@ -48,7 +47,8 @@ from synapse.storage.presence import PresenceStore, UserPresenceState
 from synapse.storage.roommember import RoomMemberStore
 from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn, \
+    PreserveLoggingContext
 from synapse.util.manhole import manhole
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.stringutils import random_string
@@ -410,11 +410,16 @@ class SynchrotronServer(HomeServer):
             stream = result.get("events")
             if stream:
                 max_position = stream["position"]
+
+                event_map = yield store.get_events([row[1] for row in stream["rows"]])
+
                 for row in stream["rows"]:
                     position = row[0]
-                    internal = json.loads(row[1])
-                    event_json = json.loads(row[2])
-                    event = FrozenEvent(event_json, internal_metadata_dict=internal)
+                    event_id = row[1]
+                    event = event_map.get(event_id, None)
+                    if not event:
+                        continue
+
                     extra_users = ()
                     if event.type == EventTypes.Member:
                         extra_users = (event.state_key,)
@@ -496,7 +501,11 @@ def start(config_options):
     ss.start_listening(config.worker_listeners)
 
     def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
             logger.info("Running")
             change_resource_limit(config.soft_file_limit)
             if config.gc_thresholds:
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index c045588866..23eb6a1ec4 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -23,14 +23,27 @@ import signal
 import subprocess
 import sys
 import yaml
+import errno
+import time
 
 SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
 
 GREEN = "\x1b[1;32m"
+YELLOW = "\x1b[1;33m"
 RED = "\x1b[1;31m"
 NORMAL = "\x1b[m"
 
 
+def pid_running(pid):
+    try:
+        os.kill(pid, 0)
+        return True
+    except OSError, err:
+        if err.errno == errno.EPERM:
+            return True
+        return False
+
+
 def write(message, colour=NORMAL, stream=sys.stdout):
     if colour == NORMAL:
         stream.write(message + "\n")
@@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
         stream.write(colour + message + NORMAL + "\n")
 
 
+def abort(message, colour=RED, stream=sys.stderr):
+    write(message, colour, stream)
+    sys.exit(1)
+
+
 def start(configfile):
     write("Starting ...")
     args = SYNAPSE
@@ -45,7 +63,8 @@ def start(configfile):
 
     try:
         subprocess.check_call(args)
-        write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+        write("started synapse.app.homeserver(%r)" %
+              (configfile,), colour=GREEN)
     except subprocess.CalledProcessError as e:
         write(
             "error starting (exit code: %d); see above for logs" % e.returncode,
@@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
 def stop(pidfile, app):
     if os.path.exists(pidfile):
         pid = int(open(pidfile).read())
-        os.kill(pid, signal.SIGTERM)
-        write("stopped %s" % (app,), colour=GREEN)
+        try:
+            os.kill(pid, signal.SIGTERM)
+            write("stopped %s" % (app,), colour=GREEN)
+        except OSError, err:
+            if err.errno == errno.ESRCH:
+                write("%s not running" % (app,), colour=YELLOW)
+            elif err.errno == errno.EPERM:
+                abort("Cannot stop %s: Operation not permitted" % (app,))
+            else:
+                abort("Cannot stop %s: Unknown error" % (app,))
 
 
 Worker = collections.namedtuple("Worker", [
@@ -190,7 +217,19 @@ def main():
         if start_stop_synapse:
             stop(pidfile, "synapse.app.homeserver")
 
-        # TODO: Wait for synapse to actually shutdown before starting it again
+    # Wait for synapse to actually shutdown before starting it again
+    if action == "restart":
+        running_pids = []
+        if start_stop_synapse and os.path.exists(pidfile):
+            running_pids.append(int(open(pidfile).read()))
+        for worker in workers:
+            if os.path.exists(worker.pidfile):
+                running_pids.append(int(open(worker.pidfile).read()))
+        if len(running_pids) > 0:
+            write("Waiting for process to exit before restarting...")
+            for running_pid in running_pids:
+                while pid_running(running_pid):
+                    time.sleep(0.2)
 
     if action == "start" or action == "restart":
         if start_stop_synapse:
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b0106a3597..7346206bb1 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.api.constants import EventTypes
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -124,29 +125,23 @@ class ApplicationService(object):
                     raise ValueError(
                         "Expected bool for 'exclusive' in ns '%s'" % ns
                     )
-                if not isinstance(regex_obj.get("regex"), basestring):
+                regex = regex_obj.get("regex")
+                if isinstance(regex, basestring):
+                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
+                else:
                     raise ValueError(
                         "Expected string for 'regex' in ns '%s'" % ns
                     )
         return namespaces
 
-    def _matches_regex(self, test_string, namespace_key, return_obj=False):
-        if not isinstance(test_string, basestring):
-            logger.error(
-                "Expected a string to test regex against, but got %s",
-                test_string
-            )
-            return False
-
+    def _matches_regex(self, test_string, namespace_key):
         for regex_obj in self.namespaces[namespace_key]:
-            if re.match(regex_obj["regex"], test_string):
-                if return_obj:
-                    return regex_obj
-                return True
-        return False
+            if regex_obj["regex"].match(test_string):
+                return regex_obj
+        return None
 
     def _is_exclusive(self, ns_key, test_string):
-        regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+        regex_obj = self._matches_regex(test_string, ns_key)
         if regex_obj:
             return regex_obj["exclusive"]
         return False
@@ -166,7 +161,14 @@ class ApplicationService(object):
         if not store:
             defer.returnValue(False)
 
-        member_list = yield store.get_users_in_room(event.room_id)
+        does_match = yield self._matches_user_in_member_list(event.room_id, store)
+        defer.returnValue(does_match)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _matches_user_in_member_list(self, room_id, store, cache_context):
+        member_list = yield store.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
 
         # check joined member events
         for user_id in member_list:
@@ -219,10 +221,10 @@ class ApplicationService(object):
         )
 
     def is_interested_in_alias(self, alias):
-        return self._matches_regex(alias, ApplicationService.NS_ALIASES)
+        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
 
     def is_interested_in_room(self, room_id):
-        return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
 
     def is_exclusive_user(self, user_id):
         return (
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d7211ee9b3..1bb27edc0f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,7 +15,6 @@
 
 from synapse.crypto.keyclient import fetch_server_key
 from synapse.api.errors import SynapseError, Codes
-from synapse.util.retryutils import get_retry_limiter
 from synapse.util import unwrapFirstError
 from synapse.util.async import ObservableDeferred
 from synapse.util.logcontext import (
@@ -96,10 +95,11 @@ class Keyring(object):
         verify_requests = []
 
         for server_name, json_object in server_and_json:
-            logger.debug("Verifying for %s", server_name)
 
             key_ids = signature_ids(json_object, server_name)
             if not key_ids:
+                logger.warn("Request from %s: no supported signature keys",
+                            server_name)
                 deferred = defer.fail(SynapseError(
                     400,
                     "Not signed with a supported algorithm",
@@ -108,6 +108,9 @@ class Keyring(object):
             else:
                 deferred = defer.Deferred()
 
+            logger.debug("Verifying for %s with key_ids %s",
+                         server_name, key_ids)
+
             verify_request = VerifyKeyRequest(
                 server_name, key_ids, json_object, deferred
             )
@@ -142,6 +145,9 @@ class Keyring(object):
 
             json_object = verify_request.json_object
 
+            logger.debug("Got key %s %s:%s for server %s, verifying" % (
+                key_id, verify_key.alg, verify_key.version, server_name,
+            ))
             try:
                 verify_signed_json(json_object, server_name, verify_key)
             except:
@@ -231,8 +237,14 @@ class Keyring(object):
             d.addBoth(rm, server_name)
 
     def get_server_verify_keys(self, verify_requests):
-        """Takes a dict of KeyGroups and tries to find at least one key for
-        each group.
+        """Tries to find at least one key for each verify request
+
+        For each verify_request, verify_request.deferred is called back with
+        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
+        with a SynapseError if none of the keys are found.
+
+        Args:
+            verify_requests (list[VerifyKeyRequest]): list of verify requests
         """
 
         # These are functions that produce keys given a list of key ids
@@ -245,8 +257,11 @@ class Keyring(object):
         @defer.inlineCallbacks
         def do_iterations():
             with Measure(self.clock, "get_server_verify_keys"):
+                # dict[str, dict[str, VerifyKey]]: results so far.
+                # map server_name -> key_id -> VerifyKey
                 merged_results = {}
 
+                # dict[str, set(str)]: keys to fetch for each server
                 missing_keys = {}
                 for verify_request in verify_requests:
                     missing_keys.setdefault(verify_request.server_name, set()).update(
@@ -308,6 +323,16 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
+        """
+
+        Args:
+            server_name_and_key_ids (list[(str, iterable[str])]):
+                list of (server_name, iterable[key_id]) tuples to fetch keys for
+
+        Returns:
+            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+                server_name -> key_id -> VerifyKey
+        """
         res = yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store.get_server_verify_keys)(
@@ -356,30 +381,24 @@ class Keyring(object):
     def get_keys_from_server(self, server_name_and_key_ids):
         @defer.inlineCallbacks
         def get_key(server_name, key_ids):
-            limiter = yield get_retry_limiter(
-                server_name,
-                self.clock,
-                self.store,
-            )
-            with limiter:
-                keys = None
-                try:
-                    keys = yield self.get_server_verify_key_v2_direct(
-                        server_name, key_ids
-                    )
-                except Exception as e:
-                    logger.info(
-                        "Unable to get key %r for %r directly: %s %s",
-                        key_ids, server_name,
-                        type(e).__name__, str(e.message),
-                    )
+            keys = None
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(
+                    server_name, key_ids
+                )
+            except Exception as e:
+                logger.info(
+                    "Unable to get key %r for %r directly: %s %s",
+                    key_ids, server_name,
+                    type(e).__name__, str(e.message),
+                )
 
-                if not keys:
-                    keys = yield self.get_server_verify_key_v1_direct(
-                        server_name, key_ids
-                    )
+            if not keys:
+                keys = yield self.get_server_verify_key_v1_direct(
+                    server_name, key_ids
+                )
 
-                    keys = {server_name: keys}
+                keys = {server_name: keys}
 
             defer.returnValue(keys)
 
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 11605b34a3..6be18880b9 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,6 +15,32 @@
 
 
 class EventContext(object):
+    """
+    Attributes:
+        current_state_ids (dict[(str, str), str]):
+            The current state map including the current event.
+            (type, state_key) -> event_id
+
+        prev_state_ids (dict[(str, str), str]):
+            The current state map excluding the current event.
+            (type, state_key) -> event_id
+
+        state_group (int): state group id
+        rejected (bool|str): A rejection reason if the event was rejected, else
+            False
+
+        push_actions (list[(str, list[object])]): list of (user_id, actions)
+            tuples
+
+        prev_group (int): Previously persisted state group. ``None`` for an
+            outlier.
+        delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
+            (type, state_key) -> event_id. ``None`` for an outlier.
+
+        prev_state_events (?): XXX: is this ever set to anything other than
+            the empty list?
+    """
+
     __slots__ = [
         "current_state_ids",
         "prev_state_ids",
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 5dcd4eecce..deee0f4904 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent, builder
 import synapse.metrics
 
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.util.retryutils import NotRetryingDestination
 
 import copy
 import itertools
@@ -88,7 +88,7 @@ class FederationClient(FederationBase):
 
     @log_function
     def make_query(self, destination, query_type, args,
-                   retry_on_dns_fail=False):
+                   retry_on_dns_fail=False, ignore_backoff=False):
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
@@ -98,6 +98,8 @@ class FederationClient(FederationBase):
                 handler name used in register_query_handler().
             args (dict): Mapping of strings to strings containing the details
                 of the query request.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
 
         Returns:
             a Deferred which will eventually yield a JSON object from the
@@ -106,7 +108,8 @@ class FederationClient(FederationBase):
         sent_queries_counter.inc(query_type)
 
         return self.transport_layer.make_query(
-            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )
 
     @log_function
@@ -234,31 +237,24 @@ class FederationClient(FederationBase):
                 continue
 
             try:
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self._clock,
-                    self.store,
+                transaction_data = yield self.transport_layer.get_event(
+                    destination, event_id, timeout=timeout,
                 )
 
-                with limiter:
-                    transaction_data = yield self.transport_layer.get_event(
-                        destination, event_id, timeout=timeout,
-                    )
-
-                    logger.debug("transaction_data %r", transaction_data)
+                logger.debug("transaction_data %r", transaction_data)
 
-                    pdu_list = [
-                        self.event_from_pdu_json(p, outlier=outlier)
-                        for p in transaction_data["pdus"]
-                    ]
+                pdu_list = [
+                    self.event_from_pdu_json(p, outlier=outlier)
+                    for p in transaction_data["pdus"]
+                ]
 
-                    if pdu_list and pdu_list[0]:
-                        pdu = pdu_list[0]
+                if pdu_list and pdu_list[0]:
+                    pdu = pdu_list[0]
 
-                        # Check signatures are correct.
-                        signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                    # Check signatures are correct.
+                    signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
 
-                        break
+                    break
 
                 pdu_attempts[destination] = now
 
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5c9f7a86f0..bbb0195228 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
     def __init__(self, hs):
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
 
         self.presence_map = {}
         self.presence_changed = sorteddict()
@@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
         else:
             self.edus[pos] = edu
 
+        self.notifier.on_new_replication_data()
+
     def send_presence(self, destination, states):
         """As per TransactionQueue"""
         pos = self._next_pos()
@@ -199,16 +202,20 @@ class FederationRemoteSendQueue(object):
             (destination, state.user_id) for state in states
         ]
 
+        self.notifier.on_new_replication_data()
+
     def send_failure(self, failure, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
 
         self.failures[pos] = (destination, str(failure))
+        self.notifier.on_new_replication_data()
 
     def send_device_messages(self, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
         self.device_messages[pos] = destination
+        self.notifier.on_new_replication_data()
 
     def get_current_token(self):
         return self.pos - 1
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index c802dd67a3..c27ce7c5f3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import datetime
 
 from twisted.internet import defer
 
@@ -22,9 +22,7 @@ from .units import Transaction, Edu
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
 from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.retryutils import (
-    get_retry_limiter, NotRetryingDestination,
-)
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
 from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state
@@ -305,20 +303,20 @@ class TransactionQueue(object):
             )
             return
 
+        pending_pdus = []
         try:
             self.pending_transactions[destination] = 1
 
+            # This will throw if we wouldn't retry. We do this here so we fail
+            # quickly, but we will later check this again in the http client,
+            # hence why we throw the result away.
+            yield get_retry_limiter(destination, self.clock, self.store)
+
             # XXX: what's this for?
             yield run_on_reactor()
 
+            pending_pdus = []
             while True:
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self.clock,
-                    self.store,
-                    backoff_on_404=True,  # If we get a 404 the other side has gone
-                )
-
                 device_message_edus, device_stream_id, dev_list_id = (
                     yield self._get_new_device_messages(destination)
                 )
@@ -374,7 +372,6 @@ class TransactionQueue(object):
 
                 success = yield self._send_new_transaction(
                     destination, pending_pdus, pending_edus, pending_failures,
-                    limiter=limiter,
                 )
                 if success:
                     # Remove the acknowledged device messages from the database
@@ -392,12 +389,24 @@ class TransactionQueue(object):
                     self.last_device_list_stream_id_by_dest[destination] = dev_list_id
                 else:
                     break
-        except NotRetryingDestination:
+        except NotRetryingDestination as e:
             logger.debug(
-                "TX [%s] not ready for retry yet - "
+                "TX [%s] not ready for retry yet (next retry at %s) - "
                 "dropping transaction for now",
                 destination,
+                datetime.datetime.fromtimestamp(
+                    (e.retry_last_ts + e.retry_interval) / 1000.0
+                ),
+            )
+        except Exception as e:
+            logger.warn(
+                "TX [%s] Failed to send transaction: %s",
+                destination,
+                e,
             )
+            for p, _ in pending_pdus:
+                logger.info("Failed to send event %s to %s", p.event_id,
+                            destination)
         finally:
             # We want to be *very* sure we delete this after we stop processing
             self.pending_transactions.pop(destination, None)
@@ -437,7 +446,7 @@ class TransactionQueue(object):
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, limiter):
+                              pending_failures):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -447,132 +456,104 @@ class TransactionQueue(object):
 
         success = True
 
-        try:
-            logger.debug("TX [%s] _attempt_new_transaction", destination)
+        logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-            txn_id = str(self._next_txn_id)
+        txn_id = str(self._next_txn_id)
 
-            logger.debug(
-                "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
-                destination, txn_id,
-                len(pdus),
-                len(edus),
-                len(failures)
-            )
+        logger.debug(
+            "TX [%s] {%s} Attempting new transaction"
+            " (pdus: %d, edus: %d, failures: %d)",
+            destination, txn_id,
+            len(pdus),
+            len(edus),
+            len(failures)
+        )
 
-            logger.debug("TX [%s] Persisting transaction...", destination)
+        logger.debug("TX [%s] Persisting transaction...", destination)
 
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self.clock.time_msec()),
-                transaction_id=txn_id,
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
+        transaction = Transaction.create_new(
+            origin_server_ts=int(self.clock.time_msec()),
+            transaction_id=txn_id,
+            origin=self.server_name,
+            destination=destination,
+            pdus=pdus,
+            edus=edus,
+            pdu_failures=failures,
+        )
 
-            self._next_txn_id += 1
+        self._next_txn_id += 1
 
-            yield self.transaction_actions.prepare_to_send(transaction)
+        yield self.transaction_actions.prepare_to_send(transaction)
 
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
-                destination, txn_id,
-                transaction.transaction_id,
-                len(pdus),
-                len(edus),
-                len(failures),
-            )
+        logger.debug("TX [%s] Persisted transaction", destination)
+        logger.info(
+            "TX [%s] {%s} Sending transaction [%s],"
+            " (PDUs: %d, EDUs: %d, failures: %d)",
+            destination, txn_id,
+            transaction.transaction_id,
+            len(pdus),
+            len(edus),
+            len(failures),
+        )
 
-            with limiter:
-                # Actually send the transaction
-
-                # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                # keys work
-                def json_data_cb():
-                    data = transaction.get_dict()
-                    now = int(self.clock.time_msec())
-                    if "pdus" in data:
-                        for p in data["pdus"]:
-                            if "age_ts" in p:
-                                unsigned = p.setdefault("unsigned", {})
-                                unsigned["age"] = now - int(p["age_ts"])
-                                del p["age_ts"]
-                    return data
-
-                try:
-                    response = yield self.transport_layer.send_transaction(
-                        transaction, json_data_cb
-                    )
-                    code = 200
-
-                    if response:
-                        for e_id, r in response.get("pdus", {}).items():
-                            if "error" in r:
-                                logger.warn(
-                                    "Transaction returned error for %s: %s",
-                                    e_id, r,
-                                )
-                except HttpResponseException as e:
-                    code = e.code
-                    response = e.response
-
-                    if e.code in (401, 404, 429) or 500 <= e.code:
-                        logger.info(
-                            "TX [%s] {%s} got %d response",
-                            destination, txn_id, code
+        # Actually send the transaction
+
+        # FIXME (erikj): This is a bit of a hack to make the Pdu age
+        # keys work
+        def json_data_cb():
+            data = transaction.get_dict()
+            now = int(self.clock.time_msec())
+            if "pdus" in data:
+                for p in data["pdus"]:
+                    if "age_ts" in p:
+                        unsigned = p.setdefault("unsigned", {})
+                        unsigned["age"] = now - int(p["age_ts"])
+                        del p["age_ts"]
+            return data
+
+        try:
+            response = yield self.transport_layer.send_transaction(
+                transaction, json_data_cb
+            )
+            code = 200
+
+            if response:
+                for e_id, r in response.get("pdus", {}).items():
+                    if "error" in r:
+                        logger.warn(
+                            "Transaction returned error for %s: %s",
+                            e_id, r,
                         )
-                        raise e
+        except HttpResponseException as e:
+            code = e.code
+            response = e.response
 
+            if e.code in (401, 404, 429) or 500 <= e.code:
                 logger.info(
                     "TX [%s] {%s} got %d response",
                     destination, txn_id, code
                 )
+                raise e
 
-                logger.debug("TX [%s] Sent transaction", destination)
-                logger.debug("TX [%s] Marking as delivered...", destination)
+        logger.info(
+            "TX [%s] {%s} got %d response",
+            destination, txn_id, code
+        )
 
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
+        logger.debug("TX [%s] Sent transaction", destination)
+        logger.debug("TX [%s] Marking as delivered...", destination)
 
-            logger.debug("TX [%s] Marked as delivered", destination)
+        yield self.transaction_actions.delivered(
+            transaction, code, response
+        )
 
-            if code != 200:
-                for p in pdus:
-                    logger.info(
-                        "Failed to send event %s to %s", p.event_id, destination
-                    )
-                success = False
-        except RuntimeError as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
-            success = False
+        logger.debug("TX [%s] Marked as delivered", destination)
 
+        if code != 200:
             for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
+                logger.info(
+                    "Failed to send event %s to %s", p.event_id, destination
+                )
             success = False
 
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-
         defer.returnValue(success)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index f49e8a2cc4..15a03378f5 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -163,6 +163,7 @@ class TransportLayerClient(object):
             data=json_data,
             json_data_callback=json_data_callback,
             long_retries=True,
+            backoff_on_404=True,  # If we get a 404 the other side has gone
         )
 
         logger.debug(
@@ -174,7 +175,8 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def make_query(self, destination, query_type, args, retry_on_dns_fail):
+    def make_query(self, destination, query_type, args, retry_on_dns_fail,
+                   ignore_backoff=False):
         path = PREFIX + "/query/%s" % query_type
 
         content = yield self.client.get_json(
@@ -183,6 +185,7 @@ class TransportLayerClient(object):
             args=args,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=10000,
+            ignore_backoff=ignore_backoff,
         )
 
         defer.returnValue(content)
@@ -242,6 +245,7 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=content,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -269,6 +273,7 @@ class TransportLayerClient(object):
             destination=remote_server,
             path=path,
             args=args,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1b5317edf5..943554ce98 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
                         "room_alias": room_alias.to_string(),
                     },
                     retry_on_dns_fail=False,
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 logging.warn("Error retrieving alias")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index e40495d1ab..c2b38d72a9 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, CodeMessageException
 from synapse.types import get_domain_from_id
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -121,15 +121,11 @@ class E2eKeysHandler(object):
         def do_remote_query(destination):
             destination_query = remote_queries_not_in_cache[destination]
             try:
-                limiter = yield get_retry_limiter(
-                    destination, self.clock, self.store
+                remote_result = yield self.federation.query_client_keys(
+                    destination,
+                    {"device_keys": destination_query},
+                    timeout=timeout
                 )
-                with limiter:
-                    remote_result = yield self.federation.query_client_keys(
-                        destination,
-                        {"device_keys": destination_query},
-                        timeout=timeout
-                    )
 
                 for user_id, keys in remote_result["device_keys"].items():
                     if user_id in destination_query:
@@ -239,18 +235,14 @@ class E2eKeysHandler(object):
         def claim_client_keys(destination):
             device_keys = remote_queries[destination]
             try:
-                limiter = yield get_retry_limiter(
-                    destination, self.clock, self.store
+                remote_result = yield self.federation.claim_client_keys(
+                    destination,
+                    {"one_time_keys": device_keys},
+                    timeout=timeout
                 )
-                with limiter:
-                    remote_result = yield self.federation.claim_client_keys(
-                        destination,
-                        {"one_time_keys": device_keys},
-                        timeout=timeout
-                    )
-                    for user_id, keys in remote_result["one_time_keys"].items():
-                        if user_id in device_keys:
-                            json_result[user_id] = keys
+                for user_id, keys in remote_result["one_time_keys"].items():
+                    if user_id in device_keys:
+                        json_result[user_id] = keys
             except CodeMessageException as e:
                 failures[destination] = {
                     "status": e.code, "message": e.message
@@ -316,7 +308,7 @@ class E2eKeysHandler(object):
         # old access_token without an associated device_id. Either way, we
         # need to double-check the device is registered to avoid ending up with
         # keys without a corresponding device.
-        self.device_handler.check_device_registered(user_id, device_id)
+        yield self.device_handler.check_device_registered(user_id, device_id)
 
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d0c2b4d6ed..888dd01240 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 """Contains handlers for federation events."""
+import synapse.util.logcontext
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -114,6 +115,14 @@ class FederationHandler(BaseHandler):
             logger.debug("Already seen pdu %s", pdu.event_id)
             return
 
+        # If we are currently in the process of joining this room, then we
+        # queue up events for later processing.
+        if pdu.room_id in self.room_queues:
+            logger.info("Ignoring PDU %s for room %s from %s for now; join "
+                        "in progress", pdu.event_id, pdu.room_id, origin)
+            self.room_queues[pdu.room_id].append((pdu, origin))
+            return
+
         state = None
 
         auth_chain = []
@@ -274,26 +283,13 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, pdu, state=None, auth_chain=None):
+    def _process_received_pdu(self, origin, pdu, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
-
-        auth_chain and state are None if we already have the necessary state
-        and prev_events in the db
         """
         event = pdu
 
-        logger.debug("Got event: %s", event.event_id)
-
-        # If we are currently in the process of joining this room, then we
-        # queue up events for later processing.
-        if event.room_id in self.room_queues:
-            self.room_queues[event.room_id].append((pdu, origin))
-            return
-
-        logger.debug("Processing event: %s", event.event_id)
-
-        logger.debug("Event: %s", event)
+        logger.debug("Processing event: %s", event)
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
         # in the room work
@@ -862,8 +858,6 @@ class FederationHandler(BaseHandler):
         """
         logger.debug("Joining %s to %s", joinee, room_id)
 
-        yield self.store.clean_room_for_join(room_id)
-
         origin, event = yield self._make_and_verify_event(
             target_hosts,
             room_id,
@@ -872,7 +866,15 @@ class FederationHandler(BaseHandler):
             content,
         )
 
+        # This shouldn't happen, because the RoomMemberHandler has a
+        # linearizer lock which only allows one operation per user per room
+        # at a time - so this is just paranoia.
+        assert (room_id not in self.room_queues)
+
         self.room_queues[room_id] = []
+
+        yield self.store.clean_room_for_join(room_id)
+
         handled_events = set()
 
         try:
@@ -925,18 +927,37 @@ class FederationHandler(BaseHandler):
             room_queue = self.room_queues[room_id]
             del self.room_queues[room_id]
 
-            for p, origin in room_queue:
-                if p.event_id in handled_events:
-                    continue
+            # we don't need to wait for the queued events to be processed -
+            # it's just a best-effort thing at this point. We do want to do
+            # them roughly in order, though, otherwise we'll end up making
+            # lots of requests for missing prev_events which we do actually
+            # have. Hence we fire off the deferred, but don't wait for it.
 
-                try:
-                    self._process_received_pdu(origin, p)
-                except:
-                    logger.exception("Couldn't handle pdu")
+            synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
+                room_queue
+            )
 
         defer.returnValue(True)
 
     @defer.inlineCallbacks
+    def _handle_queued_pdus(self, room_queue):
+        """Process PDUs which got queued up while we were busy send_joining.
+
+        Args:
+            room_queue (list[FrozenEvent, str]): list of PDUs to be processed
+                and the servers that sent them
+        """
+        for p, origin in room_queue:
+            try:
+                logger.info("Processing queued PDU %s which was received "
+                            "while we were joining %s", p.event_id, p.room_id)
+                yield self.on_receive_pdu(origin, p)
+            except Exception as e:
+                logger.warn(
+                    "Error handling queued PDU %s from %s: %s",
+                    p.event_id, origin, e)
+
+    @defer.inlineCallbacks
     @log_function
     def on_make_join_request(self, room_id, user_id):
         """ We've received a /make_join/ request, so we create a partial
@@ -1517,7 +1538,17 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, auth_events=None):
+        """
 
+        Args:
+            origin:
+            event:
+            state:
+            auth_events:
+
+        Returns:
+            Deferred, which resolves to synapse.events.snapshot.EventContext
+        """
         context = yield self.state_handler.compute_event_context(
             event, old_state=state,
         )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 059260a8aa..1ede117c79 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -575,8 +575,7 @@ class PresenceHandler(object):
                 if not local_states:
                     continue
 
-                users = yield self.store.get_users_in_room(room_id)
-                hosts = set(get_domain_from_id(u) for u in users)
+                hosts = yield self.store.get_hosts_in_room(room_id)
 
                 for host in hosts:
                     hosts_to_states.setdefault(host, []).extend(local_states)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index abd1fb28cb..9bf638f818 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
                     args={
                         "user_id": target_user.to_string(),
                         "field": "displayname",
-                    }
+                    },
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 if e.code != 404:
@@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
                     args={
                         "user_id": target_user.to_string(),
                         "field": "avatar_url",
-                    }
+                    },
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 if e.code != 404:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 82586e3dea..62b4d7e93d 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -12,8 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+import synapse.util.retryutils
 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web.client import readBody, HTTPConnectionPool, Agent
@@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone
 
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util import logcontext
 import synapse.metrics
 
 from canonicaljson import encode_canonical_json
@@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
             reactor, MatrixFederationEndpointFactory(hs), pool=pool
         )
         self.clock = hs.get_clock()
+        self._store = hs.get_datastore()
         self.version_string = hs.version_string
         self._next_id = 1
 
@@ -103,129 +103,152 @@ class MatrixFederationHttpClient(object):
         )
 
     @defer.inlineCallbacks
-    def _create_request(self, destination, method, path_bytes,
-                        body_callback, headers_dict={}, param_bytes=b"",
-                        query_bytes=b"", retry_on_dns_fail=True,
-                        timeout=None, long_retries=False):
-        """ Creates and sends a request to the given url
+    def _request(self, destination, method, path,
+                 body_callback, headers_dict={}, param_bytes=b"",
+                 query_bytes=b"", retry_on_dns_fail=True,
+                 timeout=None, long_retries=False,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
+        """ Creates and sends a request to the given server
+        Args:
+            destination (str): The remote server to send the HTTP request to.
+            method (str): HTTP method
+            path (str): The HTTP path
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): Back off if we get a 404
 
         Returns:
             Deferred: resolves with the http response object on success.
 
             Fails with ``HTTPRequestException``: if we get an HTTP response
-            code >= 300.
+                code >= 300.
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+                to retry this server.
         """
-        headers_dict[b"User-Agent"] = [self.version_string]
-        headers_dict[b"Host"] = [destination]
-
-        url_bytes = self._create_url(
-            destination, path_bytes, param_bytes, query_bytes
+        limiter = yield synapse.util.retryutils.get_retry_limiter(
+            destination,
+            self.clock,
+            self._store,
+            backoff_on_404=backoff_on_404,
+            ignore_backoff=ignore_backoff,
         )
 
-        txn_id = "%s-O-%s" % (method, self._next_id)
-        self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+        destination = destination.encode("ascii")
+        path_bytes = path.encode("ascii")
+        with limiter:
+            headers_dict[b"User-Agent"] = [self.version_string]
+            headers_dict[b"Host"] = [destination]
 
-        outbound_logger.info(
-            "{%s} [%s] Sending request: %s %s",
-            txn_id, destination, method, url_bytes
-        )
+            url_bytes = self._create_url(
+                destination, path_bytes, param_bytes, query_bytes
+            )
 
-        # XXX: Would be much nicer to retry only at the transaction-layer
-        # (once we have reliable transactions in place)
-        if long_retries:
-            retries_left = MAX_LONG_RETRIES
-        else:
-            retries_left = MAX_SHORT_RETRIES
+            txn_id = "%s-O-%s" % (method, self._next_id)
+            self._next_id = (self._next_id + 1) % (sys.maxint - 1)
 
-        http_url_bytes = urlparse.urlunparse(
-            ("", "", path_bytes, param_bytes, query_bytes, "")
-        )
+            outbound_logger.info(
+                "{%s} [%s] Sending request: %s %s",
+                txn_id, destination, method, url_bytes
+            )
 
-        log_result = None
-        try:
-            while True:
-                producer = None
-                if body_callback:
-                    producer = body_callback(method, http_url_bytes, headers_dict)
-
-                try:
-                    def send_request():
-                        request_deferred = preserve_context_over_fn(
-                            self.agent.request,
-                            method,
-                            url_bytes,
-                            Headers(headers_dict),
-                            producer
-                        )
+            # XXX: Would be much nicer to retry only at the transaction-layer
+            # (once we have reliable transactions in place)
+            if long_retries:
+                retries_left = MAX_LONG_RETRIES
+            else:
+                retries_left = MAX_SHORT_RETRIES
 
-                        return self.clock.time_bound_deferred(
-                            request_deferred,
-                            time_out=timeout / 1000. if timeout else 60,
-                        )
+            http_url_bytes = urlparse.urlunparse(
+                ("", "", path_bytes, param_bytes, query_bytes, "")
+            )
 
-                    response = yield preserve_context_over_fn(send_request)
+            log_result = None
+            try:
+                while True:
+                    producer = None
+                    if body_callback:
+                        producer = body_callback(method, http_url_bytes, headers_dict)
+
+                    try:
+                        def send_request():
+                            request_deferred = self.agent.request(
+                                method,
+                                url_bytes,
+                                Headers(headers_dict),
+                                producer
+                            )
+
+                            return self.clock.time_bound_deferred(
+                                request_deferred,
+                                time_out=timeout / 1000. if timeout else 60,
+                            )
+
+                        with logcontext.PreserveLoggingContext():
+                            response = yield send_request()
+
+                        log_result = "%d %s" % (response.code, response.phrase,)
+                        break
+                    except Exception as e:
+                        if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+                            logger.warn(
+                                "DNS Lookup failed to %s with %s",
+                                destination,
+                                e
+                            )
+                            log_result = "DNS Lookup failed to %s with %s" % (
+                                destination, e
+                            )
+                            raise
 
-                    log_result = "%d %s" % (response.code, response.phrase,)
-                    break
-                except Exception as e:
-                    if not retry_on_dns_fail and isinstance(e, DNSLookupError):
                         logger.warn(
-                            "DNS Lookup failed to %s with %s",
+                            "{%s} Sending request failed to %s: %s %s: %s - %s",
+                            txn_id,
                             destination,
-                            e
+                            method,
+                            url_bytes,
+                            type(e).__name__,
+                            _flatten_response_never_received(e),
                         )
-                        log_result = "DNS Lookup failed to %s with %s" % (
-                            destination, e
+
+                        log_result = "%s - %s" % (
+                            type(e).__name__, _flatten_response_never_received(e),
                         )
-                        raise
-
-                    logger.warn(
-                        "{%s} Sending request failed to %s: %s %s: %s - %s",
-                        txn_id,
-                        destination,
-                        method,
-                        url_bytes,
-                        type(e).__name__,
-                        _flatten_response_never_received(e),
-                    )
-
-                    log_result = "%s - %s" % (
-                        type(e).__name__, _flatten_response_never_received(e),
-                    )
-
-                    if retries_left and not timeout:
-                        if long_retries:
-                            delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
-                            delay = min(delay, 60)
-                            delay *= random.uniform(0.8, 1.4)
+
+                        if retries_left and not timeout:
+                            if long_retries:
+                                delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+                                delay = min(delay, 60)
+                                delay *= random.uniform(0.8, 1.4)
+                            else:
+                                delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+                                delay = min(delay, 2)
+                                delay *= random.uniform(0.8, 1.4)
+
+                            yield sleep(delay)
+                            retries_left -= 1
                         else:
-                            delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
-                            delay = min(delay, 2)
-                            delay *= random.uniform(0.8, 1.4)
-
-                        yield sleep(delay)
-                        retries_left -= 1
-                    else:
-                        raise
-        finally:
-            outbound_logger.info(
-                "{%s} [%s] Result: %s",
-                txn_id,
-                destination,
-                log_result,
-            )
+                            raise
+            finally:
+                outbound_logger.info(
+                    "{%s} [%s] Result: %s",
+                    txn_id,
+                    destination,
+                    log_result,
+                )
 
-        if 200 <= response.code < 300:
-            pass
-        else:
-            # :'(
-            # Update transactions table?
-            body = yield preserve_context_over_fn(readBody, response)
-            raise HttpResponseException(
-                response.code, response.phrase, body
-            )
+            if 200 <= response.code < 300:
+                pass
+            else:
+                # :'(
+                # Update transactions table?
+                with logcontext.PreserveLoggingContext():
+                    body = yield readBody(response)
+                raise HttpResponseException(
+                    response.code, response.phrase, body
+                )
 
-        defer.returnValue(response)
+            defer.returnValue(response)
 
     def sign_request(self, destination, method, url_bytes, headers_dict,
                      content=None):
@@ -254,7 +277,9 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False, timeout=None):
+                 long_retries=False, timeout=None,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
         """ Sends the specifed json data using PUT
 
         Args:
@@ -269,11 +294,19 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): True if we should count a 404 response as
+                a failure of the server (and should therefore back off future
+                requests)
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body. On a 4xx or 5xx error response a
             CodeMessageException is raised.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
 
         if not json_data_callback:
@@ -288,26 +321,29 @@ class MatrixFederationHttpClient(object):
             producer = _JsonProducer(json_data)
             return producer
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "PUT",
-            path.encode("ascii"),
+            path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
+            backoff_on_404=backoff_on_404,
         )
 
         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)
 
-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
     def post_json(self, destination, path, data={}, long_retries=False,
-                  timeout=None):
+                  timeout=None, ignore_backoff=False):
         """ Sends the specifed json data using POST
 
         Args:
@@ -320,11 +356,15 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
-
+            ignore_backoff (bool): true to ignore the historical backoff data and
+                try the request anyway.
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body. On a 4xx or 5xx error response a
             CodeMessageException is raised.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
 
         def body_callback(method, url_bytes, headers_dict):
@@ -333,27 +373,29 @@ class MatrixFederationHttpClient(object):
             )
             return _JsonProducer(data)
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "POST",
-            path.encode("ascii"),
+            path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
         )
 
         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)
 
-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
     def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
-                 timeout=None):
+                 timeout=None, ignore_backoff=False):
         """ GETs some json from the given host homeserver and path
 
         Args:
@@ -365,11 +407,16 @@ class MatrixFederationHttpClient(object):
             timeout (int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout and that the request will
                 be retried.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
         Returns:
             Deferred: Succeeds when we get *any* HTTP response.
 
             The result of the deferred is a tuple of `(code, response)`,
             where `response` is a dict representing the decoded JSON body.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
         logger.debug("get_json args: %s", args)
 
@@ -386,39 +433,47 @@ class MatrixFederationHttpClient(object):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
+            path,
             query_bytes=query_bytes,
             body_callback=body_callback,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
         )
 
         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)
 
-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
     def get_file(self, destination, path, output_stream, args={},
-                 retry_on_dns_fail=True, max_size=None):
+                 retry_on_dns_fail=True, max_size=None,
+                 ignore_backoff=False):
         """GETs a file from a given homeserver
         Args:
             destination (str): The remote server to send the HTTP request to.
             path (str): The HTTP path to GET.
             output_stream (file): File to write the response body to.
             args (dict): Optional dictionary used to create the query string.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
         Returns:
             Deferred: resolves with an (int,dict) tuple of the file length and
             a dict of the response headers.
 
             Fails with ``HTTPRequestException`` if we get an HTTP response code
             >= 300
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
 
         encoded_args = {}
@@ -434,22 +489,23 @@ class MatrixFederationHttpClient(object):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
+            path,
             query_bytes=query_bytes,
             body_callback=body_callback,
-            retry_on_dns_fail=retry_on_dns_fail
+            retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )
 
         headers = dict(response.headers.getAllRawHeaders())
 
         try:
-            length = yield preserve_context_over_fn(
-                _readBodyToFile,
-                response, output_stream, max_size
-            )
+            with logcontext.PreserveLoggingContext():
+                length = yield _readBodyToFile(
+                    response, output_stream, max_size
+                )
         except:
             logger.exception("Failed to download body")
             raise
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4db76f18bd..4d88046579 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -17,6 +17,7 @@ import logging
 import re
 
 from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
         return self._value_cache.get(dotted_key, None)
 
 
+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
 def _glob_matches(glob, value, word_boundary=False):
     """Tests if value matches glob.
 
@@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
     Returns:
         bool
     """
-    try:
-        if IS_GLOB.search(glob):
-            r = re.escape(glob)
-
-            r = r.replace(r'\*', '.*?')
-            r = r.replace(r'\?', '.')
-
-            # handle [abc], [a-z] and [!a-z] style ranges.
-            r = GLOB_REGEX.sub(
-                lambda x: (
-                    '[%s%s]' % (
-                        x.group(1) and '^' or '',
-                        x.group(2).replace(r'\\\-', '-')
-                    )
-                ),
-                r,
-            )
-            if word_boundary:
-                r = r"\b%s\b" % (r,)
-                r = _compile_regex(r)
-
-                return r.search(value)
-            else:
-                r = r + "$"
-                r = _compile_regex(r)
-
-                return r.match(value)
-        elif word_boundary:
-            r = re.escape(glob)
-            r = r"\b%s\b" % (r,)
-            r = _compile_regex(r)
 
-            return r.search(value)
-        else:
-            return value.lower() == glob.lower()
+    try:
+        r = regex_cache.get((glob, word_boundary), None)
+        if not r:
+            r = _glob_to_re(glob, word_boundary)
+            regex_cache[(glob, word_boundary)] = r
+        return r.search(value)
     except re.error:
         logger.warn("Failed to parse glob to regex: %r", glob)
         return False
 
 
+def _glob_to_re(glob, word_boundary):
+    """Generates regex for a given glob.
+
+    Args:
+        glob (string)
+        word_boundary (bool): Whether to match against word boundaries or entire
+            string. Defaults to False.
+
+    Returns:
+        regex object
+    """
+    if IS_GLOB.search(glob):
+        r = re.escape(glob)
+
+        r = r.replace(r'\*', '.*?')
+        r = r.replace(r'\?', '.')
+
+        # handle [abc], [a-z] and [!a-z] style ranges.
+        r = GLOB_REGEX.sub(
+            lambda x: (
+                '[%s%s]' % (
+                    x.group(1) and '^' or '',
+                    x.group(2).replace(r'\\\-', '-')
+                )
+            ),
+            r,
+        )
+        if word_boundary:
+            r = r"\b%s\b" % (r,)
+
+            return re.compile(r, flags=re.IGNORECASE)
+        else:
+            r = "^" + r + "$"
+
+            return re.compile(r, flags=re.IGNORECASE)
+    elif word_boundary:
+        r = re.escape(glob)
+        r = r"\b%s\b" % (r,)
+
+        return re.compile(r, flags=re.IGNORECASE)
+    else:
+        r = "^" + re.escape(glob) + "$"
+        return re.compile(r, flags=re.IGNORECASE)
+
+
 def _flatten_dict(d, prefix=[], result={}):
     for key, value in d.items():
         if isinstance(value, basestring):
@@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
             _flatten_dict(value, prefix=(prefix + [key]), result=result)
 
     return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
-    r = regex_cache.get(regex_str, None)
-    if r:
-        return r
-
-    r = re.compile(regex_str, flags=re.IGNORECASE)
-    regex_cache[regex_str] = r
-    return r
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 287df94b4f..6835f54e97 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,15 +17,12 @@ from twisted.internet import defer
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites = yield store.get_invited_rooms_for_user(user_id)
+    joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c4777b2a2b..ed7f1c89ad 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -19,6 +19,7 @@ from distutils.version import LooseVersion
 logger = logging.getLogger(__name__)
 
 REQUIREMENTS = {
+    "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
     "frozendict>=0.4": ["frozendict"],
     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index d8eb14592b..03930fe958 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -283,12 +283,12 @@ class ReplicationResource(Resource):
 
             if request_events != upto_events_token:
                 writer.write_header_and_rows("events", res.new_forward_events, (
-                    "position", "internal", "json", "state_group"
+                    "position", "event_id", "room_id", "type", "state_key",
                 ), position=upto_events_token)
 
             if request_backfill != upto_backfill_token:
                 writer.write_header_and_rows("backfill", res.new_backfill_events, (
-                    "position", "internal", "json", "state_group",
+                    "position", "event_id", "room_id", "type", "state_key", "redacts",
                 ), position=upto_backfill_token)
 
             writer.write_header_and_rows(
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 518c9ea2e9..d4db1e452e 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -16,7 +16,6 @@ from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 from synapse.api.constants import EventTypes
-from synapse.events import FrozenEvent
 from synapse.storage import DataStore
 from synapse.storage.roommember import RoomMemberStore
 from synapse.storage.event_federation import EventFederationStore
@@ -25,7 +24,6 @@ from synapse.storage.state import StateStore
 from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-import ujson as json
 import logging
 
 
@@ -169,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
     _get_rooms_for_user_where_membership_is_txn = (
         DataStore._get_rooms_for_user_where_membership_is_txn.__func__
     )
-    _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
     _get_state_for_groups = DataStore._get_state_for_groups.__func__
     _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
     _get_events_around_txn = DataStore._get_events_around_txn.__func__
@@ -242,46 +239,32 @@ class SlavedEventStore(BaseSlavedStore):
         return super(SlavedEventStore, self).process_replication(result)
 
     def _process_replication_row(self, row, backfilled):
-        internal = json.loads(row[1])
-        event_json = json.loads(row[2])
-        event = FrozenEvent(event_json, internal_metadata_dict=internal)
+        stream_ordering = row[0] if not backfilled else -row[0]
         self.invalidate_caches_for_event(
-            event, backfilled,
+            stream_ordering, row[1], row[2], row[3], row[4], row[5],
+            backfilled=backfilled,
         )
 
-    def invalidate_caches_for_event(self, event, backfilled):
-        self._invalidate_get_event_cache(event.event_id)
+    def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
+                                    etype, state_key, redacts, backfilled):
+        self._invalidate_get_event_cache(event_id)
 
-        self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+        self.get_latest_event_ids_in_room.invalidate((room_id,))
 
         self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
-            (event.room_id,)
+            (room_id,)
         )
 
         if not backfilled:
             self._events_stream_cache.entity_has_changed(
-                event.room_id, event.internal_metadata.stream_ordering
+                room_id, stream_ordering
             )
 
-        # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
-        #     (event.room_id,)
-        # )
+        if redacts:
+            self._invalidate_get_event_cache(redacts)
 
-        if event.type == EventTypes.Redaction:
-            self._invalidate_get_event_cache(event.redacts)
-
-        if event.type == EventTypes.Member:
+        if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(
-                event.state_key, event.internal_metadata.stream_ordering
+                state_key, stream_ordering
             )
-            self.get_invited_rooms_for_user.invalidate((event.state_key,))
-
-        if not event.is_state():
-            return
-
-        if backfilled:
-            return
-
-        if (not event.internal_metadata.is_invite_from_remote()
-                and event.internal_metadata.is_outlier()):
-            return
+            self.get_invited_rooms_for_user.invalidate((state_key,))
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 40f6c9a386..e4a2414d78 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
                 self.presence_stream_cache.entity_has_changed(
                     user_id, position
                 )
+                self._get_presence_for_user.invalidate((user_id,))
 
         return super(SlavedPresenceStore, self).process_replication(result)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index aac76edf1c..4990b22b9f 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -268,7 +268,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         if existingUid is not None:
             raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
 
-        ret = yield self.identity_handler.requestEmailToken(**body)
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
         defer.returnValue((200, ret))
 
 
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index dcd13b876f..3acf4eacdd 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -537,7 +537,7 @@ class RegisterRestServlet(RestServlet):
         # we have nowhere to store it.
         device_id = synapse.api.auth.GUEST_DEVICE_ID
         initial_display_name = params.get("initial_device_display_name")
-        self.device_handler.check_device_registered(
+        yield self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
 
diff --git a/synapse/state.py b/synapse/state.py
index 383d32b163..f6b83d888a 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -177,17 +177,12 @@ class StateHandler(object):
 
     @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
-        """ Fills out the context with the `current state` of the graph. The
-        `current state` here is defined to be the state of the event graph
-        just before the event - i.e. it never includes `event`
-
-        If `event` has `auth_events` then this will also fill out the
-        `auth_events` field on `context` from the `current_state`.
+        """Build an EventContext structure for the event.
 
         Args:
-            event (EventBase)
+            event (synapse.events.EventBase):
         Returns:
-            an EventContext
+            synapse.events.snapshot.EventContext:
         """
         context = EventContext()
 
@@ -200,11 +195,11 @@ class StateHandler(object):
                     (s.type, s.state_key): s.event_id for s in old_state
                 }
                 if event.is_state():
-                    context.current_state_events = dict(context.prev_state_ids)
+                    context.current_state_ids = dict(context.prev_state_ids)
                     key = (event.type, event.state_key)
-                    context.current_state_events[key] = event.event_id
+                    context.current_state_ids[key] = event.event_id
                 else:
-                    context.current_state_events = context.prev_state_ids
+                    context.current_state_ids = context.prev_state_ids
             else:
                 context.current_state_ids = {}
                 context.prev_state_ids = {}
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13b106bba1..c659004e8d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -73,6 +73,9 @@ class LoggingTransaction(object):
     def __setattr__(self, name, value):
         setattr(self.txn, name, value)
 
+    def __iter__(self):
+        return self.txn.__iter__()
+
     def execute(self, sql, *args):
         self._do_execute(self.txn.execute, sql, *args)
 
@@ -132,7 +135,7 @@ class PerformanceCounters(object):
 
     def interval(self, interval_duration, limit=3):
         counters = []
-        for name, (count, cum_time) in self.current_counters.items():
+        for name, (count, cum_time) in self.current_counters.iteritems():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
             counters.append((
                 (cum_time - prev_time) / interval_duration,
@@ -357,7 +360,7 @@ class SQLBaseStore(object):
         """
         col_headers = list(intern(column[0]) for column in cursor.description)
         results = list(
-            dict(zip(col_headers, row)) for row in cursor.fetchall()
+            dict(zip(col_headers, row)) for row in cursor
         )
         return results
 
@@ -565,7 +568,7 @@ class SQLBaseStore(object):
     @staticmethod
     def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
         if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
         else:
             where = ""
 
@@ -579,7 +582,7 @@ class SQLBaseStore(object):
 
         txn.execute(sql, keyvalues.values())
 
-        return [r[0] for r in txn.fetchall()]
+        return [r[0] for r in txn]
 
     def _simple_select_onecol(self, table, keyvalues, retcol,
                               desc="_simple_select_onecol"):
@@ -712,7 +715,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.items():
+        for key, value in keyvalues.iteritems():
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -753,7 +756,7 @@ class SQLBaseStore(object):
     @staticmethod
     def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
         if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
         else:
             where = ""
 
@@ -870,7 +873,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.items():
+        for key, value in keyvalues.iteritems():
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -901,16 +904,16 @@ class SQLBaseStore(object):
 
         txn = db_conn.cursor()
         txn.execute(sql, (int(max_value),))
-        rows = txn.fetchall()
-        txn.close()
 
         cache = {
             row[0]: int(row[1])
-            for row in rows
+            for row in txn
         }
 
+        txn.close()
+
         if cache:
-            min_val = min(cache.values())
+            min_val = min(cache.itervalues())
         else:
             min_val = max_value
 
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 3fa226e92d..aa84ffc2b0 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
             txn.execute(sql, (user_id, stream_id))
 
             global_account_data = {
-                row[0]: json.loads(row[1]) for row in txn.fetchall()
+                row[0]: json.loads(row[1]) for row in txn
             }
 
             sql = (
@@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room = {}
-            for row in txn.fetchall():
+            for row in txn:
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = json.loads(row[2])
 
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 94b2bcc54a..813ad59e56 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import synapse.util.async
 
 from ._base import SQLBaseStore
 from . import engines
@@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
-        self._background_update_timer = None
 
     @defer.inlineCallbacks
     def start_doing_background_updates(self):
-        assert self._background_update_timer is None, \
-            "background updates already running"
-
         logger.info("Starting background schema updates")
 
         while True:
-            sleep = defer.Deferred()
-            self._background_update_timer = self._clock.call_later(
-                self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
-            )
-            try:
-                yield sleep
-            finally:
-                self._background_update_timer = None
+            yield synapse.util.async.sleep(
+                self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
 
             try:
                 result = yield self.do_next_background_update(
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 7925cb5f1b..2714519d21 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 )
                 txn.execute(sql, (user_id,))
                 message_json = ujson.dumps(messages_by_device["*"])
-                for row in txn.fetchall():
+                for row in txn:
                     # Add the message for all devices for this user on this
                     # server.
                     device = row[0]
@@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 # TODO: Maybe this needs to be done in batches if there are
                 # too many local devices for a given user.
                 txn.execute(sql, [user_id] + devices)
-                for row in txn.fetchall():
+                for row in txn:
                     # Only insert into the local inbox if the device exists on
                     # this server
                     device = row[0]
@@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 user_id, device_id, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
                 messages.append(ujson.loads(row[1]))
             if len(messages) < limit:
@@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 " ORDER BY stream_id ASC"
             )
             txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn.fetchall())
+            rows.extend(txn)
 
             return rows
 
@@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 destination, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
                 messages.append(ujson.loads(row[1]))
             if len(messages) < limit:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index e545b62e39..53e36791d5 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -329,17 +329,20 @@ class DeviceStore(SQLBaseStore):
             SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
             GROUP BY user_id, device_id
+            LIMIT 20
         """
         txn.execute(
             sql, (destination, from_stream_id, now_stream_id, False)
         )
-        rows = txn.fetchall()
 
-        if not rows:
+        # maps (user_id, device_id) -> stream_id
+        query_map = {(r[0], r[1]): r[2] for r in txn}
+        if not query_map:
             return (now_stream_id, [])
 
-        # maps (user_id, device_id) -> stream_id
-        query_map = {(r[0], r[1]): r[2] for r in rows}
+        if len(query_map) >= 20:
+            now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+
         devices = self._get_e2e_device_keys_txn(
             txn, query_map.keys(), include_all_devices=True
         )
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b9f1365f92..7cbc1470fd 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 from twisted.internet import defer
 
+from synapse.api.errors import SynapseError
+
 from canonicaljson import encode_canonical_json
 import ujson as json
 
@@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
 
         return result
 
+    @defer.inlineCallbacks
     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+        """Insert some new one time keys for a device.
+
+        Checks if any of the keys are already inserted, if they are then check
+        if they match. If they don't then we raise an error.
+        """
+
+        # First we check if we have already persisted any of the keys.
+        rows = yield self._simple_select_many_batch(
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            iterable=[key_id for _, key_id, _ in key_list],
+            retcols=("algorithm", "key_id", "key_json",),
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            desc="add_e2e_one_time_keys_check",
+        )
+
+        existing_key_map = {
+            (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
+        }
+
+        new_keys = []  # Keys that we need to insert
+        for algorithm, key_id, json_bytes in key_list:
+            ex_bytes = existing_key_map.get((algorithm, key_id), None)
+            if ex_bytes:
+                if json_bytes != ex_bytes:
+                    raise SynapseError(
+                        400, "One time key with key_id %r already exists" % (key_id,)
+                    )
+            else:
+                new_keys.append((algorithm, key_id, json_bytes))
+
         def _add_e2e_one_time_keys(txn):
-            for (algorithm, key_id, json_bytes) in key_list:
-                self._simple_upsert_txn(
-                    txn, table="e2e_one_time_keys_json",
-                    keyvalues={
+            # We are protected from race between lookup and insertion due to
+            # a unique constraint. If there is a race of two calls to
+            # `add_e2e_one_time_keys` then they'll conflict and we will only
+            # insert one set.
+            self._simple_insert_many_txn(
+                txn, table="e2e_one_time_keys_json",
+                values=[
+                    {
                         "user_id": user_id,
                         "device_id": device_id,
                         "algorithm": algorithm,
                         "key_id": key_id,
-                    },
-                    values={
                         "ts_added_ms": time_now,
                         "key_json": json_bytes,
                     }
-                )
-        return self.runInteraction(
-            "add_e2e_one_time_keys", _add_e2e_one_time_keys
+                    for algorithm, key_id, json_bytes in new_keys
+                ],
+            )
+        yield self.runInteraction(
+            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
     def count_e2e_one_time_keys(self, user_id, device_id):
@@ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             txn.execute(sql, (user_id, device_id))
             result = {}
-            for algorithm, key_count in txn.fetchall():
+            for algorithm, key_count in txn:
                 result[algorithm] = key_count
             return result
         return self.runInteraction(
@@ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore):
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn.fetchall():
+                for key_id, key_json in txn:
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
             sql = (
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 256e50dc20..519059c306 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
                     base_sql % (",".join(["?"] * len(chunk)),),
                     chunk
                 )
-                new_front.update([r[0] for r in txn.fetchall()])
+                new_front.update([r[0] for r in txn])
 
             new_front -= results
 
@@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
 
         txn.execute(sql, (room_id, False,))
 
-        return dict(txn.fetchall())
+        return dict(txn)
 
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
@@ -201,19 +201,19 @@ class EventFederationStore(SQLBaseStore):
     def _update_min_depth_for_room_txn(self, txn, room_id, depth):
         min_depth = self._get_min_depth_interaction(txn, room_id)
 
-        do_insert = depth < min_depth if min_depth else True
+        if min_depth and depth >= min_depth:
+            return
 
-        if do_insert:
-            self._simple_upsert_txn(
-                txn,
-                table="room_depth",
-                keyvalues={
-                    "room_id": room_id,
-                },
-                values={
-                    "min_depth": depth,
-                },
-            )
+        self._simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={
+                "room_id": room_id,
+            },
+            values={
+                "min_depth": depth,
+            },
+        )
 
     def _handle_mult_prev_events(self, txn, events):
         """
@@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
 
         def get_forward_extremeties_for_room_txn(txn):
             txn.execute(sql, (stream_ordering, room_id))
-            rows = txn.fetchall()
-            return [event_id for event_id, in rows]
+            return [event_id for event_id, in txn]
 
         return self.runInteraction(
             "get_forward_extremeties_for_room",
@@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
                 (room_id, event_id, False, limit - len(event_results))
             )
 
-            for row in txn.fetchall():
+            for row in txn:
                 if row[1] not in event_results:
                     queue.put((-row[0], row[1]))
 
@@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
                     (room_id, event_id, False, limit - len(event_results))
                 )
 
-                for e_id, in txn.fetchall():
+                for e_id, in txn:
                     new_front.add(e_id)
 
             new_front -= earliest_events
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 14543b4269..d6d8723b4a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
                 " stream_ordering >= ? AND stream_ordering <= ?"
             )
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
-            return [r[0] for r in txn.fetchall()]
+            return [r[0] for r in txn]
         ret = yield self.runInteraction("get_push_action_users_in_range", f)
         defer.returnValue(ret)
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 72319c35ae..3f6833fad2 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
 from functools import wraps
 
-import synapse
 import synapse.metrics
 
-
 import logging
 import math
 import ujson as json
 
+# these are only included to make the type annotations work
+from synapse.events import EventBase    # noqa: F401
+from synapse.events.snapshot import EventContext   # noqa: F401
+
 logger = logging.getLogger(__name__)
 
 
@@ -82,6 +84,11 @@ class _EventPeristenceQueue(object):
 
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
+
+        Args:
+            room_id (str):
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
@@ -210,14 +217,14 @@ class EventsStore(SQLBaseStore):
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
         deferreds = []
-        for room_id, evs_ctxs in partitioned.items():
+        for room_id, evs_ctxs in partitioned.iteritems():
             d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
             deferreds.append(d)
 
-        for room_id in partitioned.keys():
+        for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
         return preserve_context_over_deferred(
@@ -227,6 +234,17 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context, backfilled=False):
+        """
+
+        Args:
+            event (EventBase):
+            context (EventContext):
+            backfilled (bool):
+
+        Returns:
+            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            and the stream ordering of the latest persisted event
+        """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
@@ -253,6 +271,16 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     def _persist_events(self, events_and_contexts, backfilled=False,
                         delete_existing=False):
+        """Persist events to db
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+            delete_existing (bool):
+
+        Returns:
+            Deferred: resolves when the events have been persisted
+        """
         if not events_and_contexts:
             return
 
@@ -295,7 +323,7 @@ class EventsStore(SQLBaseStore):
                                 (event, context)
                             )
 
-                        for room_id, ev_ctx_rm in events_by_room.items():
+                        for room_id, ev_ctx_rm in events_by_room.iteritems():
                             # Work out new extremities by recursively adding and removing
                             # the new events.
                             latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -400,6 +428,7 @@ class EventsStore(SQLBaseStore):
         # Now we need to work out the different state sets for
         # each state extremities
         state_sets = []
+        state_groups = set()
         missing_event_ids = []
         was_updated = False
         for event_id in new_latest_event_ids:
@@ -409,9 +438,17 @@ class EventsStore(SQLBaseStore):
                 if event_id == ev.event_id:
                     if ctx.current_state_ids is None:
                         raise Exception("Unknown current state")
-                    state_sets.append(ctx.current_state_ids)
-                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                        was_updated = True
+
+                    # If we've already seen the state group don't bother adding
+                    # it to the state sets again
+                    if ctx.state_group not in state_groups:
+                        state_sets.append(ctx.current_state_ids)
+                        if ctx.delta_ids or hasattr(ev, "state_key"):
+                            was_updated = True
+                        if ctx.state_group:
+                            # Add this as a seen state group (if it has a state
+                            # group)
+                            state_groups.add(ctx.state_group)
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
@@ -425,45 +462,51 @@ class EventsStore(SQLBaseStore):
                 missing_event_ids,
             )
 
-            groups = set(event_to_groups.values())
-            group_to_state = yield self._get_state_for_groups(groups)
+            groups = set(event_to_groups.itervalues()) - state_groups
 
-            state_sets.extend(group_to_state.values())
+            if groups:
+                group_to_state = yield self._get_state_for_groups(groups)
+                state_sets.extend(group_to_state.itervalues())
 
         if not new_latest_event_ids:
             current_state = {}
         elif was_updated:
-            # We work out the current state by passing the state sets to the
-            # state resolution algorithm. It may ask for some events, including
-            # the events we have yet to persist, so we need a slightly more
-            # complicated event lookup function than simply looking the events
-            # up in the db.
-            events_map = {ev.event_id: ev for ev, _ in events_context}
-
-            @defer.inlineCallbacks
-            def get_events(ev_ids):
-                # We get the events by first looking at the list of events we
-                # are trying to persist, and then fetching the rest from the DB.
-                db = []
-                to_return = {}
-                for ev_id in ev_ids:
-                    ev = events_map.get(ev_id, None)
-                    if ev:
-                        to_return[ev_id] = ev
-                    else:
-                        db.append(ev_id)
-
-                if db:
-                    evs = yield self.get_events(
-                        ev_ids, get_prev_content=False, check_redacted=False,
-                    )
-                    to_return.update(evs)
-                defer.returnValue(to_return)
+            if len(state_sets) == 1:
+                # If there is only one state set, then we know what the current
+                # state is.
+                current_state = state_sets[0]
+            else:
+                # We work out the current state by passing the state sets to the
+                # state resolution algorithm. It may ask for some events, including
+                # the events we have yet to persist, so we need a slightly more
+                # complicated event lookup function than simply looking the events
+                # up in the db.
+                events_map = {ev.event_id: ev for ev, _ in events_context}
+
+                @defer.inlineCallbacks
+                def get_events(ev_ids):
+                    # We get the events by first looking at the list of events we
+                    # are trying to persist, and then fetching the rest from the DB.
+                    db = []
+                    to_return = {}
+                    for ev_id in ev_ids:
+                        ev = events_map.get(ev_id, None)
+                        if ev:
+                            to_return[ev_id] = ev
+                        else:
+                            db.append(ev_id)
 
-            current_state = yield resolve_events(
-                state_sets,
-                state_map_factory=get_events,
-            )
+                    if db:
+                        evs = yield self.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
+                        )
+                        to_return.update(evs)
+                    defer.returnValue(to_return)
+
+                current_state = yield resolve_events(
+                    state_sets,
+                    state_map_factory=get_events,
+                )
         else:
             return
 
@@ -554,11 +597,91 @@ class EventsStore(SQLBaseStore):
         and the rejections table. Things reading from those table will need to check
         whether the event was rejected.
 
-        If delete_existing is True then existing events will be purged from the
-        database before insertion. This is useful when retrying due to IntegrityError.
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]):
+                events to persist
+            backfilled (bool): True if the events were backfilled
+            delete_existing (bool): True to purge existing table rows for the
+                events from the database. This is useful when retrying due to
+                IntegrityError.
+            current_state_for_room (dict[str, (list[str], list[str])]):
+                The current-state delta for each room. For each room, a tuple
+                (to_delete, to_insert), being a list of event ids to be removed
+                from the current state, and a list of event ids to be added to
+                the current state.
+            new_forward_extremeties (dict[str, list[str]]):
+                The new forward extremities for each room. For each room, a
+                list of the event ids which are the forward extremities.
+
         """
+        self._update_current_state_txn(txn, current_state_for_room)
+
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-        for room_id, current_state_tuple in current_state_for_room.iteritems():
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremeties,
+            max_stream_order=max_stream_order,
+        )
+
+        # Ensure that we don't have the same event twice.
+        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+            events_and_contexts,
+        )
+
+        self._update_room_depths_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # _update_outliers_txn filters out any events which have already been
+        # persisted, and returns the filtered list.
+        events_and_contexts = self._update_outliers_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only events that we haven't
+        # seen before.
+
+        if delete_existing:
+            # For paranoia reasons, we go and delete all the existing entries
+            # for these events so we can reinsert them.
+            # This gets around any problems with some tables already having
+            # entries.
+            self._delete_existing_rows_txn(
+                txn,
+                events_and_contexts=events_and_contexts,
+            )
+
+        self._store_event_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # Insert into the state_groups, state_groups_state, and
+        # event_to_state_groups tables.
+        self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+    def _update_current_state_txn(self, txn, state_delta_by_room):
+        for room_id, current_state_tuple in state_delta_by_room.iteritems():
                 to_delete, to_insert = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
@@ -608,7 +731,9 @@ class EventsStore(SQLBaseStore):
                     txn, self.get_current_state_ids, (room_id,)
                 )
 
-        for room_id, new_extrem in new_forward_extremeties.items():
+    def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+                                        max_stream_order):
+        for room_id, new_extrem in new_forward_extremities.iteritems():
             self._simple_delete_txn(
                 txn,
                 table="event_forward_extremities",
@@ -626,7 +751,7 @@ class EventsStore(SQLBaseStore):
                     "event_id": ev_id,
                     "room_id": room_id,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for ev_id in new_extrem
             ],
         )
@@ -643,13 +768,22 @@ class EventsStore(SQLBaseStore):
                     "event_id": event_id,
                     "stream_ordering": max_stream_order,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for event_id in new_extrem
             ]
         )
 
-        # Ensure that we don't have the same event twice.
-        # Pick the earliest non-outlier if there is one, else the earliest one.
+    @classmethod
+    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+        """Ensure that we don't have the same event twice.
+
+        Pick the earliest non-outlier if there is one, else the earliest one.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+        Returns:
+            list[(EventBase, EventContext)]: filtered list
+        """
         new_events_and_contexts = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -662,9 +796,17 @@ class EventsStore(SQLBaseStore):
                         new_events_and_contexts[event.event_id] = (event, context)
             else:
                 new_events_and_contexts[event.event_id] = (event, context)
+        return new_events_and_contexts.values()
 
-        events_and_contexts = new_events_and_contexts.values()
+    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+        """Update min_depth for each room
 
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -680,9 +822,24 @@ class EventsStore(SQLBaseStore):
                     event.depth, depth_updates.get(event.room_id, event.depth)
                 )
 
-        for room_id, depth in depth_updates.items():
+        for room_id, depth in depth_updates.iteritems():
             self._update_min_depth_for_room_txn(txn, room_id, depth)
 
+    def _update_outliers_txn(self, txn, events_and_contexts):
+        """Update any outliers with new event info.
+
+        This turns outliers into ex-outliers (unless the new event was
+        rejected).
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without events which
+            are already in the events table.
+        """
         txn.execute(
             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
                 ",".join(["?"] * len(events_and_contexts)),
@@ -692,24 +849,21 @@ class EventsStore(SQLBaseStore):
 
         have_persisted = {
             event_id: outlier
-            for event_id, outlier in txn.fetchall()
+            for event_id, outlier in txn
         }
 
         to_remove = set()
         for event, context in events_and_contexts:
-            if context.rejected:
-                # If the event is rejected then we don't care if the event
-                # was an outlier or not.
-                if event.event_id in have_persisted:
-                    # If we have already seen the event then ignore it.
-                    to_remove.add(event)
-                continue
-
             if event.event_id not in have_persisted:
                 continue
 
             to_remove.add(event)
 
+            if context.rejected:
+                # If the event is rejected then we don't care if the event
+                # was an outlier or not.
+                continue
+
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
@@ -764,37 +918,19 @@ class EventsStore(SQLBaseStore):
                 # event isn't an outlier any more.
                 self._update_backward_extremeties(txn, [event])
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    @classmethod
+    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only events that we haven't
-        # seen before.
-
-        def event_dict(event):
-            return {
-                k: v
-                for k, v in event.get_dict().items()
-                if k not in [
-                    "redacted",
-                    "redacted_because",
-                ]
-            }
-
-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
-
-            logger.info("Deleting existing")
+        logger.info("Deleting existing")
 
-            for table in (
+        for table in (
                 "events",
                 "event_auth",
                 "event_json",
@@ -817,11 +953,30 @@ class EventsStore(SQLBaseStore):
                 "redactions",
                 "room_memberships",
                 "topics"
-            ):
-                txn.executemany(
-                    "DELETE FROM %s WHERE event_id = ?" % (table,),
-                    [(ev.event_id,) for ev, _ in events_and_contexts]
-                )
+        ):
+            txn.executemany(
+                "DELETE FROM %s WHERE event_id = ?" % (table,),
+                [(ev.event_id,) for ev, _ in events_and_contexts]
+            )
+
+    def _store_event_txn(self, txn, events_and_contexts):
+        """Insert new events into the event and event_json tables
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+        """
+
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        def event_dict(event):
+            d = event.get_dict()
+            d.pop("redacted", None)
+            d.pop("redacted_because", None)
+            return d
 
         self._simple_insert_many_txn(
             txn,
@@ -865,6 +1020,19 @@ class EventsStore(SQLBaseStore):
             ],
         )
 
+    def _store_rejected_events_txn(self, txn, events_and_contexts):
+        """Add rows to the 'rejections' table for received events which were
+        rejected
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without the rejected
+                events.
+        """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
         to_remove = set()
@@ -876,17 +1044,24 @@ class EventsStore(SQLBaseStore):
                 )
                 to_remove.add(event)
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+        """Update all the miscellaneous tables for new events
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
+
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only ones that weren't rejected.
-
         for event, context in events_and_contexts:
             # Insert all the push actions into the event_push_actions table.
             if context.push_actions:
@@ -915,10 +1090,6 @@ class EventsStore(SQLBaseStore):
             ],
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
-
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
@@ -1005,13 +1176,6 @@ class EventsStore(SQLBaseStore):
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-        if backfilled:
-            # Backfilled events come before the current state so we don't need
-            # to update the current state table
-            return
-
-        return
-
     def _add_to_cache(self, txn, events_and_contexts):
         to_prefill = []
 
@@ -1620,14 +1784,13 @@ class EventsStore(SQLBaseStore):
 
         def get_all_new_events_txn(txn):
             sql = (
-                "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
             if have_forward_events:
@@ -1653,15 +1816,13 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = []
 
             sql = (
-                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
-                " eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
-                " ORDER BY e.stream_ordering DESC"
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
                 " LIMIT ?"
             )
             if have_backfill_events:
@@ -1848,7 +2009,7 @@ class EventsStore(SQLBaseStore):
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in curr_state.items()
+                    for key, state_id in curr_state.iteritems()
                 ],
             )
 
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 86b37b9ddd..3b5e0a4fb9 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
         key_ids
         Args:
             server_name (str): The name of the server.
-            key_ids (list of str): List of key_ids to try and look up.
+            key_ids (iterable[str]): key_ids to try and look up.
         Returns:
-            (list of VerifyKey): The verification keys.
+            Deferred: resolves to dict[str, VerifyKey]: map from
+               key_id to verification key.
         """
         keys = {}
         for key_id in key_ids:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ed84db6b4b..6e623843d5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
             ),
             (current_version,)
         )
-        applied_deltas = [d for d, in txn.fetchall()]
+        applied_deltas = [d for d, in txn]
         return current_version, applied_deltas, upgraded
 
     return None
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 4d1590d2b4..9e9d3c2591 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
                 self.presence_stream_cache.entity_has_changed,
                 state.user_id, stream_id,
             )
-            self._invalidate_cache_and_stream(
-                txn, self._get_presence_for_user, (state.user_id,)
+            txn.call_after(
+                self._get_presence_for_user.invalidate, (state.user_id,)
             )
 
         # Actually insert new rows
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 5cf41501ea..6b0f8c2787 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
         )
 
         txn.execute(sql, (room_id, receipt_type, user_id))
-        results = txn.fetchall()
 
-        if results and topological_ordering:
-            for to, so, _ in results:
+        if topological_ordering:
+            for to, so, _ in txn:
                 if int(to) > topological_ordering:
                     return False
                 elif int(to) == topological_ordering and int(so) >= stream_ordering:
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26be6060c3..ec2c52ab93 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
-            return dict(txn.fetchall())
+            return dict(txn)
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8a2fe2fdf5..e4c56cc175 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
                     sql % ("AND appservice_id IS NULL",),
                     (stream_id,)
                 )
-            return dict(txn.fetchall())
+            return dict(txn)
         else:
             # We want to get from all lists, so we need to aggregate the results
 
@@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
 
             results = {}
             # A room is visible if its visible on any list.
-            for room_id, visibility in txn.fetchall():
+            for room_id, visibility in txn:
                 results[room_id] = bool(visibility) or results.get(room_id, False)
 
             return results
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index e38d8927bf..367dbbbcf6 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)
 
+    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+    def get_hosts_in_room(self, room_id, cache_context):
+        """Returns the set of all hosts currently in the room
+        """
+        user_ids = yield self.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate,
+        )
+        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+        defer.returnValue(hosts)
+
     @cached(max_entries=500000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
-
-            rows = self._get_members_rows_txn(
-                txn,
-                room_id=room_id,
-                membership=Membership.JOIN,
+            sql = (
+                "SELECT m.user_id FROM room_memberships as m"
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id "
+                " AND m.room_id = c.room_id "
+                " AND m.user_id = c.state_key"
+                " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
             )
 
-            return [r["user_id"] for r in rows]
+            txn.execute(sql, (room_id, Membership.JOIN,))
+            return [r[0] for r in txn]
         return self.runInteraction("get_users_in_room", f)
 
     @cached()
@@ -246,34 +259,6 @@ class RoomMemberStore(SQLBaseStore):
 
         return results
 
-    def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
-        where_clause = "c.room_id = ?"
-        where_values = [room_id]
-
-        if membership:
-            where_clause += " AND m.membership = ?"
-            where_values.append(membership)
-
-        if user_id:
-            where_clause += " AND m.user_id = ?"
-            where_values.append(user_id)
-
-        sql = (
-            "SELECT m.* FROM room_memberships as m"
-            " INNER JOIN current_state_events as c"
-            " ON m.event_id = c.event_id "
-            " AND m.room_id = c.room_id "
-            " AND m.user_id = c.state_key"
-            " WHERE c.type = 'm.room.member' AND %(where)s"
-        ) % {
-            "where": where_clause,
-        }
-
-        txn.execute(sql, where_values)
-        rows = self.cursor_to_dict(txn)
-
-        return rows
-
     @cachedInlineCallbacks(max_entries=500000, iterable=True)
     def get_rooms_for_user(self, user_id):
         """Returns a set of room_ids the user is currently joined to
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index e1dca927d7..67d5d9969a 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
             " WHERE event_id = ?"
         )
         txn.execute(query, (event_id, ))
-        return {k: v for k, v in txn.fetchall()}
+        return {k: v for k, v in txn}
 
     def _store_event_reference_hashes_txn(self, txn, events):
         """Store a hash for a PDU
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 27f1ec89ec..fb23f6f462 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -90,7 +90,7 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups)
 
         defer.returnValue(group_to_state)
@@ -108,17 +108,18 @@ class StateStore(SQLBaseStore):
 
         state_event_map = yield self.get_events(
             [
-                ev_id for group_ids in group_to_ids.values()
-                for ev_id in group_ids.values()
+                ev_id for group_ids in group_to_ids.itervalues()
+                for ev_id in group_ids.itervalues()
             ],
             get_prev_content=False
         )
 
         defer.returnValue({
             group: [
-                state_event_map[v] for v in event_id_map.values() if v in state_event_map
+                state_event_map[v] for v in event_id_map.itervalues()
+                if v in state_event_map
             ]
-            for group, event_id_map in group_to_ids.items()
+            for group, event_id_map in group_to_ids.iteritems()
         })
 
     def _have_persisted_state_group_txn(self, txn, state_group):
@@ -136,6 +137,16 @@ class StateStore(SQLBaseStore):
                 continue
 
             if context.current_state_ids is None:
+                # AFAIK, this can never happen
+                logger.error(
+                    "Non-outlier event %s had current_state_ids==None",
+                    event.event_id)
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
                 continue
 
             state_groups[event.event_id] = context.state_group
@@ -180,7 +191,7 @@ class StateStore(SQLBaseStore):
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.delta_ids.items()
+                        for key, state_id in context.delta_ids.iteritems()
                     ],
                 )
             else:
@@ -195,7 +206,7 @@ class StateStore(SQLBaseStore):
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.current_state_ids.items()
+                        for key, state_id in context.current_state_ids.iteritems()
                     ],
                 )
 
@@ -207,7 +218,7 @@ class StateStore(SQLBaseStore):
                     "state_group": state_group_id,
                     "event_id": event_id,
                 }
-                for event_id, state_group_id in state_groups.items()
+                for event_id, state_group_id in state_groups.iteritems()
             ],
         )
 
@@ -331,10 +342,10 @@ class StateStore(SQLBaseStore):
                     args.extend(where_args)
 
                     txn.execute(sql % (where_clause,), args)
-                    rows = self.cursor_to_dict(txn)
-                    for row in rows:
-                        key = (row["type"], row["state_key"])
-                        results[group][key] = row["event_id"]
+                    for row in txn:
+                        typ, state_key, event_id = row
+                        key = (typ, state_key)
+                        results[group][key] = event_id
         else:
             if types is not None:
                 where_clause = "AND (%s)" % (
@@ -363,12 +374,11 @@ class StateStore(SQLBaseStore):
                         " WHERE state_group = ? %s" % (where_clause,),
                         args
                     )
-                    rows = txn.fetchall()
-                    results[group].update({
-                        (typ, state_key): event_id
-                        for typ, state_key, event_id in rows
+                    results[group].update(
+                        ((typ, state_key), event_id)
+                        for typ, state_key, event_id in txn
                         if (typ, state_key) not in results[group]
-                    })
+                    )
 
                     # If the lengths match then we must have all the types,
                     # so no need to go walk further down the tree.
@@ -405,21 +415,21 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         state_event_map = yield self.get_events(
-            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+            [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
             get_prev_content=False
         )
 
         event_to_state = {
             event_id: {
                 k: state_event_map[v]
-                for k, v in group_to_state[group].items()
+                for k, v in group_to_state[group].iteritems()
                 if v in state_event_map
             }
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -442,12 +452,12 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         event_to_state = {
             event_id: group_to_state[group]
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -486,7 +496,7 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, max_entries=10000)
+    @cached(num_args=2, max_entries=100000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
@@ -559,7 +569,7 @@ class StateStore(SQLBaseStore):
         got_all = not (missing_types or types is None)
 
         return {
-            k: v for k, v in state_dict_ids.items()
+            k: v for k, v in state_dict_ids.iteritems()
             if include(k[0], k[1])
         }, missing_types, got_all
 
@@ -618,7 +628,7 @@ class StateStore(SQLBaseStore):
 
             # Now we want to update the cache with all the things we fetched
             # from the database.
-            for group, group_state_dict in group_to_state_dict.items():
+            for group, group_state_dict in group_to_state_dict.iteritems():
                 if types:
                     # We delibrately put key -> None mappings into the cache to
                     # cache absence of the key, on the assumption that if we've
@@ -633,10 +643,10 @@ class StateStore(SQLBaseStore):
                 else:
                     state_dict = results[group]
 
-                state_dict.update({
-                    (intern_string(k[0]), intern_string(k[1])): v
-                    for k, v in group_state_dict.items()
-                })
+                state_dict.update(
+                    ((intern_string(k[0]), intern_string(k[1])), v)
+                    for k, v in group_state_dict.iteritems()
+                )
 
                 self._state_group_cache.update(
                     cache_seq_num,
@@ -647,10 +657,10 @@ class StateStore(SQLBaseStore):
 
         # Remove all the entries with None values. The None values were just
         # used for bookkeeping in the cache.
-        for group, state_dict in results.items():
+        for group, state_dict in results.iteritems():
             results[group] = {
                 key: event_id
-                for key, event_id in state_dict.items()
+                for key, event_id in state_dict.iteritems()
                 if event_id
             }
 
@@ -739,7 +749,7 @@ class StateStore(SQLBaseStore):
                         # of keys
 
                         delta_state = {
-                            key: value for key, value in curr_state.items()
+                            key: value for key, value in curr_state.iteritems()
                             if prev_state.get(key, None) != value
                         }
 
@@ -779,7 +789,7 @@ class StateStore(SQLBaseStore):
                                     "state_key": key[1],
                                     "event_id": state_id,
                                 }
-                                for key, state_id in delta_state.items()
+                                for key, state_id in delta_state.iteritems()
                             ],
                         )
 
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 5a2c1aa59b..bff73f3f04 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
             for stream_id, user_id, room_id in tag_ids:
                 txn.execute(sql, (user_id, room_id))
                 tags = []
-                for tag, content in txn.fetchall():
+                for tag, content in txn:
                     tags.append(json.dumps(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
                 results.append((stream_id, user_id, room_id, tag_json))
@@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
                 " WHERE user_id = ? AND stream_id > ?"
             )
             txn.execute(sql, (user_id, stream_id))
-            room_ids = [row[0] for row in txn.fetchall()]
+            room_ids = [row[0] for row in txn]
             return room_ids
 
         changed = self._account_data_stream_cache.has_entity_changed(
diff --git a/synapse/types.py b/synapse/types.py
index 9666f9d73f..c87ed813b9 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -216,9 +216,7 @@ class StreamToken(
             return self
 
     def copy_and_replace(self, key, new_value):
-        d = self._asdict()
-        d[key] = new_value
-        return StreamToken(**d)
+        return self._replace(**{key: new_value})
 
 
 StreamToken.START = StreamToken(
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 30fc480108..98a5a26ac5 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class DeferredTimedOutError(SynapseError):
     def __init__(self):
-        super(SynapseError).__init__(504, "Timed out")
+        super(SynapseError, self).__init__(504, "Timed out")
 
 
 def unwrapFirstError(failure):
@@ -93,8 +93,10 @@ class Clock(object):
         ret_deferred = defer.Deferred()
 
         def timed_out_fn():
+            e = DeferredTimedOutError()
+
             try:
-                ret_deferred.errback(DeferredTimedOutError())
+                ret_deferred.errback(e)
             except:
                 pass
 
@@ -114,7 +116,7 @@ class Clock(object):
 
         ret_deferred.addBoth(cancel)
 
-        def sucess(res):
+        def success(res):
             try:
                 ret_deferred.callback(res)
             except:
@@ -128,7 +130,7 @@ class Clock(object):
             except:
                 pass
 
-        given_deferred.addCallbacks(callback=sucess, errback=err)
+        given_deferred.addCallbacks(callback=success, errback=err)
 
         timer = self.call_later(time_out, timed_out_fn)
 
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..1453faf0ef 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
         deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d29..9d0d0be1f9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -15,12 +15,9 @@
 import logging
 
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
 
 from . import DEBUG_CACHES, register_cache
 
@@ -189,7 +186,67 @@ class Cache(object):
         self.cache.clear()
 
 
-class CacheDescriptor(object):
+class _CacheDescriptorBase(object):
+    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        arg_spec = inspect.getargspec(orig)
+        all_args = arg_spec.args
+
+        if "cache_context" in all_args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
+
+        if num_args is None:
+            num_args = len(all_args) - 1
+            if cache_context:
+                num_args -= 1
+
+        if len(all_args) < num_args + 1:
+            raise Exception(
+                "Not enough explicit positional arguments to key off for %r: "
+                "got %i args, but wanted %i. (@cached cannot key off *args or "
+                "**kwargs)"
+                % (orig.__name__, len(all_args), num_args)
+            )
+
+        self.num_args = num_args
+
+        # list of the names of the args used as the cache key
+        self.arg_names = all_args[1:num_args + 1]
+
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
+        if "cache_context" in self.arg_names:
+            raise Exception(
+                "cache_context arg cannot be included among the cache keys"
+            )
+
+        self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +274,24 @@ class CacheDescriptor(object):
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)
 
+    Args:
+        num_args (int): number of positional arguments (excluding ``self`` and
+            ``cache_context``) to use as cache keys. Defaults to all named
+            args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
                  inlineCallbacks=False, cache_context=False, iterable=False):
-        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
-        self.orig = orig
+        super(CacheDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.max_entries = max_entries
-        self.num_args = num_args
         self.tree = tree
-
         self.iterable = iterable
 
-        all_args = inspect.getargspec(orig)
-        self.arg_names = all_args.args[1:num_args + 1]
-
-        if "cache_context" in all_args.args:
-            if not cache_context:
-                raise ValueError(
-                    "Cannot have a 'cache_context' arg without setting"
-                    " cache_context=True"
-                )
-            try:
-                self.arg_names.remove("cache_context")
-            except ValueError:
-                pass
-        elif cache_context:
-            raise ValueError(
-                "Cannot have cache_context=True without having an arg"
-                " named `cache_context`"
-            )
-
-        self.add_cache_context = cache_context
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwargs)"
-                % (orig.__name__,)
-            )
-
     def __get__(self, obj, objtype=None):
         cache = Cache(
             name=self.orig.__name__,
@@ -272,18 +301,31 @@ class CacheDescriptor(object):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -308,11 +350,9 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return preserve_context_over_deferred(observer)
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
 
@@ -322,10 +362,14 @@ class CacheDescriptor(object):
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()
 
-                return preserve_context_over_deferred(ret.observe())
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
@@ -338,48 +382,40 @@ class CacheDescriptor(object):
         return wrapped
 
 
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+    def __init__(self, orig, cached_method_name, list_name, num_args=None,
                  inlineCallbacks=False):
         """
         Args:
             orig (function)
-            method_name (str); The name of the chached method.
+            cached_method_name (str): The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int)
+            num_args (int): number of positional arguments (excluding ``self``,
+                but including list_name) to use as cache keys. Defaults to all
+                named args of the function.
             inlineCallbacks (bool): Whether orig is a generator that should
                 be wrapped by defer.inlineCallbacks
         """
-        self.orig = orig
+        super(CacheListDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
-
         self.cached_method_name = cached_method_name
 
         self.sentinel = object()
 
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
@@ -425,8 +461,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
 
@@ -435,8 +470,7 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -463,7 +497,7 @@ class CacheListDescriptor(object):
                     results.update(res)
                     return results
 
-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
@@ -487,7 +521,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
            iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
@@ -499,8 +533,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
-                          iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+                          cache_context=False, iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -512,7 +546,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
     )
 
 
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +559,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
         cache (Cache): The underlying cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
-        num_args (int): Number of arguments to use as the key in the cache.
+        num_args (int): Number of arguments to use as the key in the cache
+            (including list_name). Defaults to all named parameters.
         inlineCallbacks (bool): Should the function be wrapped in an
             `defer.inlineCallbacks`?
 
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 6c83eb213d..857afee7cb 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
 from twisted.internet import defer
 
 import threading
@@ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
 def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
+
+    Deprecated: this almost certainly doesn't do want you want, ie make
+    the deferred follow the synapse logcontext rules: try
+    ``make_deferred_yieldable`` instead.
     """
     if context is None:
         context = LoggingContext.current_context()
@@ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None):
 
 
 def preserve_fn(f):
-    """Ensures that function is called with correct context and that context is
-    restored after return. Useful for wrapping functions that return a deferred
-    which you don't yield on.
+    """Wraps a function, to ensure that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the funtion completes.
+
+    Useful for wrapping functions that return a deferred which you don't yield
+    on.
     """
+    def reset_context(result):
+        LoggingContext.set_current_context(LoggingContext.sentinel)
+        return result
+
+    # XXX: why is this here rather than inside g? surely we want to preserve
+    # the context from the time the function was called, not when it was
+    # wrapped?
     current = LoggingContext.current_context()
 
     def g(*args, **kwargs):
-        with PreserveLoggingContext(current):
-            res = f(*args, **kwargs)
-            if isinstance(res, defer.Deferred):
-                return preserve_context_over_deferred(
-                    res, context=LoggingContext.sentinel
-                )
-            else:
-                return res
+        res = f(*args, **kwargs)
+        if isinstance(res, defer.Deferred) and not res.called:
+            # The function will have reset the context before returning, so
+            # we need to restore it now.
+            LoggingContext.set_current_context(current)
+
+            # The original context will be restored when the deferred
+            # completes, but there is nothing waiting for it, so it will
+            # get leaked into the reactor or some other function which
+            # wasn't expecting it. We therefore need to reset the context
+            # here.
+            #
+            # (If this feels asymmetric, consider it this way: we are
+            # effectively forking a new thread of execution. We are
+            # probably currently within a ``with LoggingContext()`` block,
+            # which is supposed to have a single entry and exit point. But
+            # by spawning off another deferred, we are effectively
+            # adding a new exit point.)
+            res.addBoth(reset_context)
+        return res
     return g
 
 
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed (or is not actually a Deferred), essentially
+    does nothing (just returns another completed deferred with the
+    result/failure).
+
+    If the deferred has not yet completed, resets the logcontext before
+    returning a deferred. Then, when the deferred completes, restores the
+    current logcontext before running callbacks/errbacks.
+
+    (This is more-or-less the opposite operation to preserve_fn.)
+    """
+    with PreserveLoggingContext():
+        r = yield deferred
+    defer.returnValue(r)
+
+
 # modules to ignore in `logcontext_tracer`
 _to_ignore = [
     "synapse.util.logcontext",
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 153ef001ad..4fa9d1a03c 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import synapse.util.logcontext
 from twisted.internet import defer
 
 from synapse.api.errors import CodeMessageException
@@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
 
 
 @defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False,
+                      **kwargs):
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
     that will mark the destination as down if an exception is thrown (excluding
     CodeMessageException with code < 500)
 
+    Args:
+        destination (str): name of homeserver
+        clock (synapse.util.clock): timing source
+        store (synapse.storage.transactions.TransactionStore): datastore
+        ignore_backoff (bool): true to ignore the historical backoff data and
+            try the request anyway. We will still update the next
+            retry_interval on success/failure.
+
     Example usage:
 
         try:
@@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
 
         now = int(clock.time_msec())
 
-        if retry_last_ts + retry_interval > now:
+        if not ignore_backoff and retry_last_ts + retry_interval > now:
             raise NotRetryingDestination(
                 retry_last_ts=retry_last_ts,
                 retry_interval=retry_interval,
@@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         valid_err_code = False
-        if exc_type is not None and issubclass(exc_type, CodeMessageException):
+        if exc_type is None:
+            valid_err_code = True
+        elif not issubclass(exc_type, Exception):
+            # avoid treating exceptions which don't derive from Exception as
+            # failures; this is mostly so as not to catch defer._DefGen.
+            valid_err_code = True
+        elif issubclass(exc_type, CodeMessageException):
             # Some error codes are perfectly fine for some APIs, whereas other
             # APIs may expect to never received e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
@@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
             else:
                 valid_err_code = False
 
-        if exc_type is None or valid_err_code:
+        if valid_err_code:
             # We connected successfully.
             if not self.retry_interval:
                 return
 
+            logger.debug("Connection to %s was successful; clearing backoff",
+                         self.destination)
             retry_last_ts = 0
             self.retry_interval = 0
         else:
@@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
             else:
                 self.retry_interval = self.min_retry_interval
 
+            logger.debug(
+                "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
+                self.destination, exc_type, exc_val, self.retry_interval
+            )
             retry_last_ts = int(self.clock.time_msec())
 
         @defer.inlineCallbacks
@@ -173,4 +194,5 @@ class RetryDestinationLimiter(object):
                     "Failed to store set_destination_retry_timings",
                 )
 
-        store_retry_timings()
+        # we deliberately do this in the background.
+        synapse.util.logcontext.preserve_fn(store_retry_timings)()
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 199b16d827..c4dd9ae2c7 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         events ([synapse.events.EventBase]): list of events to filter
     """
     forgotten = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.who_forgot_in_room)(
+        defer.maybeDeferred(
+            preserve_fn(store.who_forgot_in_room),
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)
@@ -134,6 +135,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
             if prev_membership not in MEMBERSHIP_PRIORITY:
                 prev_membership = "leave"
 
+            # Always allow the user to see their own leave events, otherwise
+            # they won't see the room disappear if they reject the invite
+            if membership == "leave" and (
+                prev_membership == "join" or prev_membership == "invite"
+            ):
+                return True
+
             new_priority = MEMBERSHIP_PRIORITY.index(membership)
             old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
             if old_priority < new_priority: