summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xdemo/start.sh1
-rw-r--r--docs/application_services.rst2
-rw-r--r--scripts-dev/convert_server_keys.py113
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/crypto/keyclient.py17
-rw-r--r--synapse/federation/federation_base.py4
-rw-r--r--synapse/federation/federation_server.py46
-rw-r--r--synapse/handlers/_base.py11
-rw-r--r--synapse/handlers/events.py8
-rw-r--r--synapse/handlers/federation.py173
-rw-r--r--synapse/handlers/message.py19
-rw-r--r--synapse/handlers/presence.py70
-rw-r--r--synapse/handlers/profile.py17
-rw-r--r--synapse/handlers/room.py8
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/http/client.py6
-rw-r--r--synapse/http/matrixfederationclient.py32
-rw-r--r--synapse/http/server.py6
-rw-r--r--synapse/notifier.py19
-rw-r--r--synapse/storage/_base.py12
-rw-r--r--synapse/storage/event_federation.py66
-rw-r--r--synapse/storage/events.py53
-rw-r--r--synapse/storage/push_rule.py112
-rw-r--r--synapse/storage/stream.py138
-rw-r--r--synapse/storage/util/id_generators.py12
-rw-r--r--synapse/streams/events.py6
-rw-r--r--synapse/types.py52
-rw-r--r--synapse/util/__init__.py17
-rw-r--r--synapse/util/async.py6
-rw-r--r--synapse/util/distributor.py53
-rw-r--r--synapse/util/logcontext.py52
31 files changed, 766 insertions, 371 deletions
diff --git a/demo/start.sh b/demo/start.sh
index 5b3daef57f..b9cc14b9d2 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -31,6 +31,7 @@ for port in 8080 8081 8082; do
     #rm $DIR/etc/$port.config
     python -m synapse.app.homeserver \
         --generate-config \
+        --enable_registration \
         -H "localhost:$https_port" \
         --config-path "$DIR/etc/$port.config" \
 
diff --git a/docs/application_services.rst b/docs/application_services.rst
index a57bae6194..7e87ac9ad6 100644
--- a/docs/application_services.rst
+++ b/docs/application_services.rst
@@ -20,7 +20,7 @@ The format of the AS configuration file is as follows:
 
     url: <base url of AS>
     as_token: <token AS will add to requests to HS>
-    hs_token: <token HS will ad to requests to AS>
+    hs_token: <token HS will add to requests to AS>
     sender_localpart: <localpart of AS user>
     namespaces:
       users:  # List of users we're interested in
diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py
new file mode 100644
index 0000000000..024ddcdbd0
--- /dev/null
+++ b/scripts-dev/convert_server_keys.py
@@ -0,0 +1,113 @@
+import psycopg2
+import yaml
+import sys
+import json
+import time
+import hashlib
+from syutil.base64util import encode_base64
+from syutil.crypto.signing_key import read_signing_keys
+from syutil.crypto.jsonsign import sign_json
+from syutil.jsonutil import encode_canonical_json
+
+
+def select_v1_keys(connection):
+    cursor = connection.cursor()
+    cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
+    rows = cursor.fetchall()
+    cursor.close()
+    results = {}
+    for server_name, key_id, verify_key in rows:
+        results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
+    return results
+
+
+def select_v1_certs(connection):
+    cursor = connection.cursor()
+    cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
+    rows = cursor.fetchall()
+    cursor.close()
+    results = {}
+    for server_name, tls_certificate in rows:
+        results[server_name] = tls_certificate
+    return results
+
+
+def select_v2_json(connection):
+    cursor = connection.cursor()
+    cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
+    rows = cursor.fetchall()
+    cursor.close()
+    results = {}
+    for server_name, key_id, key_json in rows:
+        results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
+    return results
+
+
+def convert_v1_to_v2(server_name, valid_until, keys, certificate):
+    return {
+        "old_verify_keys": {},
+        "server_name": server_name,
+        "verify_keys": keys,
+        "valid_until_ts": valid_until,
+        "tls_fingerprints": [fingerprint(certificate)],
+    }
+
+
+def fingerprint(certificate):
+    finger = hashlib.sha256(certificate)
+    return {"sha256": encode_base64(finger.digest())}
+
+
+def rows_v2(server, json):
+    valid_until = json["valid_until_ts"]
+    key_json = encode_canonical_json(json)
+    for key_id in json["verify_keys"]:
+        yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))
+
+
+def main():
+    config = yaml.load(open(sys.argv[1]))
+    valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24
+
+    server_name = config["server_name"]
+    signing_key = read_signing_keys(open(config["signing_key_path"]))[0]
+
+    database = config["database"]
+    assert database["name"] == "psycopg2", "Can only convert for postgresql"
+    args = database["args"]
+    args.pop("cp_max")
+    args.pop("cp_min")
+    connection = psycopg2.connect(**args)
+    keys = select_v1_keys(connection)
+    certificates = select_v1_certs(connection)
+    json = select_v2_json(connection)
+
+    result = {}
+    for server in keys:
+        if not server in json:
+            v2_json = convert_v1_to_v2(
+                server, valid_until, keys[server], certificates[server]
+            )
+            v2_json = sign_json(v2_json, server_name, signing_key)
+            result[server] = v2_json
+
+    yaml.safe_dump(result, sys.stdout, default_flow_style=False)
+
+    rows = list(
+        row for server, json in result.items()
+        for row in rows_v2(server, json)
+    )
+
+    cursor = connection.cursor()
+    cursor.executemany(
+        "INSERT INTO server_keys_json ("
+        " server_name, key_id, from_server,"
+        " ts_added_ms, ts_valid_until_ms, key_json"
+        ") VALUES (%s, %s, %s, %s, %s, %s)",
+        rows
+    )
+    connection.commit()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 18b8ff7759..041e2151b0 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.9.0-r1"
+__version__ = "0.9.0-r4"
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 4911f0896b..24f15f3154 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
 from twisted.internet.protocol import Factory
 from twisted.internet import defer, reactor
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+    preserve_context_over_fn, preserve_context_over_deferred
+)
 import simplejson as json
 import logging
 
@@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
 
     for i in range(5):
         try:
-            with PreserveLoggingContext():
-                protocol = yield endpoint.connect(factory)
-                server_response, server_certificate = yield protocol.remote_key
-                defer.returnValue((server_response, server_certificate))
-                return
+            protocol = yield preserve_context_over_fn(
+                endpoint.connect, factory
+            )
+            server_response, server_certificate = yield preserve_context_over_deferred(
+                protocol.remote_key
+            )
+            defer.returnValue((server_response, server_certificate))
+            return
         except SynapseKeyClientError as e:
             logger.exception("Error getting key for %r" % (server_name,))
             if e.status.startswith("4"):
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 21a763214b..5217d91aab 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash
 
 from synapse.api.errors import SynapseError
 
+from synapse.util import unwrapFirstError
+
 import logging
 
 
@@ -94,7 +96,7 @@ class FederationBase(object):
         yield defer.gatherResults(
             [do(pdu) for pdu in pdus],
             consumeErrors=True
-        )
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(signed_pdus)
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2b46188c91..cd79e23f4b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -20,7 +20,6 @@ from .federation_base import FederationBase
 from .units import Transaction, Edu
 
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.events import FrozenEvent
 import synapse.metrics
 
@@ -123,29 +122,28 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        with PreserveLoggingContext():
-            results = []
-
-            for pdu in pdu_list:
-                d = self._handle_new_pdu(transaction.origin, pdu)
-
-                try:
-                    yield d
-                    results.append({})
-                except FederationError as e:
-                    self.send_failure(e, transaction.origin)
-                    results.append({"error": str(e)})
-                except Exception as e:
-                    results.append({"error": str(e)})
-                    logger.exception("Failed to handle PDU")
-
-            if hasattr(transaction, "edus"):
-                for edu in [Edu(**x) for x in transaction.edus]:
-                    self.received_edu(
-                        transaction.origin,
-                        edu.edu_type,
-                        edu.content
-                    )
+        results = []
+
+        for pdu in pdu_list:
+            d = self._handle_new_pdu(transaction.origin, pdu)
+
+            try:
+                yield d
+                results.append({})
+            except FederationError as e:
+                self.send_failure(e, transaction.origin)
+                results.append({"error": str(e)})
+            except Exception as e:
+                results.append({"error": str(e)})
+                logger.exception("Failed to handle PDU")
+
+        if hasattr(transaction, "edus"):
+            for edu in [Edu(**x) for x in transaction.edus]:
+                self.received_edu(
+                    transaction.origin,
+                    edu.edu_type,
+                    edu.content
+                )
 
             for failure in getattr(transaction, "pdu_failures", []):
                 logger.info("Got failure %r", failure)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 4b3f4eadab..ddc5c21e7d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.api.constants import Membership, EventTypes
 from synapse.types import UserID
 
+from synapse.util.logcontext import PreserveLoggingContext
+
 import logging
 
 
@@ -137,10 +139,11 @@ class BaseHandler(object):
                     "Failed to get destination from event %s", s.event_id
                 )
 
-        # Don't block waiting on waking up all the listeners.
-        notify_d = self.notifier.on_new_room_event(
-            event, extra_users=extra_users
-        )
+        with PreserveLoggingContext():
+            # Don't block waiting on waking up all the listeners.
+            notify_d = self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )
 
         def log_failure(f):
             logger.warn(
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f9f855213b..993d33ba47 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -15,7 +15,6 @@
 
 from twisted.internet import defer
 
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.types import UserID
 from synapse.events.utils import serialize_event
@@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
                 # thundering herds on restart.
                 timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
 
-            with PreserveLoggingContext():
-                events, tokens = yield self.notifier.get_events_for(
-                    auth_user, room_ids, pagin_config, timeout
-                )
+            events, tokens = yield self.notifier.get_events_for(
+                auth_user, room_ids, pagin_config, timeout
+            )
 
             time_now = self.clock.time_msec()
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 85e2757227..7d9906039e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -18,9 +18,11 @@
 from ._base import BaseHandler
 
 from synapse.api.errors import (
-    AuthError, FederationError, StoreError,
+    AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.util import unwrapFirstError
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
@@ -29,6 +31,8 @@ from synapse.crypto.event_signing import (
 )
 from synapse.types import UserID
 
+from synapse.util.retryutils import NotRetryingDestination
+
 from twisted.internet import defer
 
 import itertools
@@ -197,9 +201,10 @@ class FederationHandler(BaseHandler):
                 target_user = UserID.from_string(target_user_id)
                 extra_users.append(target_user)
 
-            d = self.notifier.on_new_room_event(
-                event, extra_users=extra_users
-            )
+            with PreserveLoggingContext():
+                d = self.notifier.on_new_room_event(
+                    event, extra_users=extra_users
+                )
 
             def log_failure(f):
                 logger.warn(
@@ -218,10 +223,11 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit):
+    def backfill(self, dest, room_id, limit, extremities=[]):
         """ Trigger a backfill request to `dest` for the given `room_id`
         """
-        extremities = yield self.store.get_oldest_events_in_room(room_id)
+        if not extremities:
+            extremities = yield self.store.get_oldest_events_in_room(room_id)
 
         pdus = yield self.replication_layer.backfill(
             dest,
@@ -249,6 +255,138 @@ class FederationHandler(BaseHandler):
         defer.returnValue(events)
 
     @defer.inlineCallbacks
+    def maybe_backfill(self, room_id, current_depth):
+        """Checks the database to see if we should backfill before paginating,
+        and if so do.
+        """
+        extremities = yield self.store.get_oldest_events_with_depth_in_room(
+            room_id
+        )
+
+        if not extremities:
+            logger.debug("Not backfilling as no extremeties found.")
+            return
+
+        # Check if we reached a point where we should start backfilling.
+        sorted_extremeties_tuple = sorted(
+            extremities.items(),
+            key=lambda e: -int(e[1])
+        )
+        max_depth = sorted_extremeties_tuple[0][1]
+
+        if current_depth > max_depth:
+            logger.debug(
+                "Not backfilling as we don't need to. %d < %d",
+                max_depth, current_depth,
+            )
+            return
+
+        # Now we need to decide which hosts to hit first.
+
+        # First we try hosts that are already in the room
+        # TODO: HEURISTIC ALERT.
+
+        curr_state = yield self.state_handler.get_current_state(room_id)
+
+        def get_domains_from_state(state):
+            joined_users = [
+                (state_key, int(event.depth))
+                for (e_type, state_key), event in state.items()
+                if e_type == EventTypes.Member
+                and event.membership == Membership.JOIN
+            ]
+
+            joined_domains = {}
+            for u, d in joined_users:
+                try:
+                    dom = UserID.from_string(u).domain
+                    old_d = joined_domains.get(dom)
+                    if old_d:
+                        joined_domains[dom] = min(d, old_d)
+                    else:
+                        joined_domains[dom] = d
+                except:
+                    pass
+
+            return sorted(joined_domains.items(), key=lambda d: d[1])
+
+        curr_domains = get_domains_from_state(curr_state)
+
+        likely_domains = [
+            domain for domain, depth in curr_domains
+        ]
+
+        @defer.inlineCallbacks
+        def try_backfill(domains):
+            # TODO: Should we try multiple of these at a time?
+            for dom in domains:
+                try:
+                    events = yield self.backfill(
+                        dom, room_id,
+                        limit=100,
+                        extremities=[e for e in extremities.keys()]
+                    )
+                except SynapseError:
+                    logger.info(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+                except CodeMessageException as e:
+                    if 400 <= e.code < 500:
+                        raise
+
+                    logger.info(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+                except NotRetryingDestination as e:
+                    logger.info(e.message)
+                    continue
+                except Exception as e:
+                    logger.warn(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+
+                if events:
+                    defer.returnValue(True)
+            defer.returnValue(False)
+
+        success = yield try_backfill(likely_domains)
+        if success:
+            defer.returnValue(True)
+
+        # Huh, well *those* domains didn't work out. Lets try some domains
+        # from the time.
+
+        tried_domains = set(likely_domains)
+
+        event_ids = list(extremities.keys())
+
+        states = yield defer.gatherResults([
+            self.state_handler.resolve_state_groups([e])
+            for e in event_ids
+        ])
+        states = dict(zip(event_ids, [s[1] for s in states]))
+
+        for e_id, _ in sorted_extremeties_tuple:
+            likely_domains = get_domains_from_state(states[e_id])
+
+            success = yield try_backfill([
+                dom for dom in likely_domains
+                if dom not in tried_domains
+            ])
+            if success:
+                defer.returnValue(True)
+
+            tried_domains.update(likely_domains)
+
+        defer.returnValue(False)
+
+    @defer.inlineCallbacks
     def send_invite(self, target_host, event):
         """ Sends the invite to the remote server for signing.
 
@@ -431,9 +569,10 @@ class FederationHandler(BaseHandler):
                 auth_events=auth_events,
             )
 
-            d = self.notifier.on_new_room_event(
-                new_event, extra_users=[joinee]
-            )
+            with PreserveLoggingContext():
+                d = self.notifier.on_new_room_event(
+                    new_event, extra_users=[joinee]
+                )
 
             def log_failure(f):
                 logger.warn(
@@ -512,9 +651,10 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        d = self.notifier.on_new_room_event(
-            event, extra_users=extra_users
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )
 
         def log_failure(f):
             logger.warn(
@@ -594,9 +734,10 @@ class FederationHandler(BaseHandler):
         )
 
         target_user = UserID.from_string(event.state_key)
-        d = self.notifier.on_new_room_event(
-            event, extra_users=[target_user],
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                event, extra_users=[target_user],
+            )
 
         def log_failure(f):
             logger.warn(
@@ -921,7 +1062,7 @@ class FederationHandler(BaseHandler):
                     if d in have_events and not have_events[d]
                 ],
                 consumeErrors=True
-            )
+            ).addErrback(unwrapFirstError)
 
             if different_events:
                 local_view = dict(auth_events)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 22e19af17f..867fdbefb0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -20,8 +20,9 @@ from synapse.api.errors import RoomError, SynapseError
 from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
+from synapse.util import unwrapFirstError
 from synapse.util.logcontext import PreserveLoggingContext
-from synapse.types import UserID
+from synapse.types import UserID, RoomStreamToken
 
 from ._base import BaseHandler
 
@@ -89,9 +90,19 @@ class MessageHandler(BaseHandler):
 
         if not pagin_config.from_token:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token()
+                yield self.hs.get_event_sources().get_current_token(
+                    direction='b'
+                )
             )
 
+        room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
+        if room_token.topological is None:
+            raise SynapseError(400, "Invalid token")
+
+        yield self.hs.get_handlers().federation_handler.maybe_backfill(
+            room_id, room_token.topological
+        )
+
         user = UserID.from_string(user_id)
 
         events, next_key = yield data_source.get_pagination_rows(
@@ -303,7 +314,7 @@ class MessageHandler(BaseHandler):
                             event.room_id
                         ),
                     ]
-                )
+                ).addErrback(unwrapFirstError)
 
                 start_token = now_token.copy_and_replace("room_key", token[0])
                 end_token = now_token.copy_and_replace("room_key", token[1])
@@ -328,7 +339,7 @@ class MessageHandler(BaseHandler):
         yield defer.gatherResults(
             [handle_room(e) for e in room_list],
             consumeErrors=True
-        )
+        ).addErrback(unwrapFirstError)
 
         ret = {
             "rooms": rooms_ret,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 9e15610401..28688d532d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -18,14 +18,15 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, AuthError
 from synapse.api.constants import PresenceState
 
-from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logutils import log_function
 from synapse.types import UserID
 import synapse.metrics
 
 from ._base import BaseHandler
 
 import logging
+from collections import OrderedDict
 
 
 logger = logging.getLogger(__name__)
@@ -143,7 +144,7 @@ class PresenceHandler(BaseHandler):
         self._remote_offline_serials = []
 
         # map any user to a UserPresenceCache
-        self._user_cachemap = {}
+        self._user_cachemap = OrderedDict()  # keep them sorted by serial
         self._user_cachemap_latest_serial = 0
 
         metrics.register_callback(
@@ -165,6 +166,14 @@ class PresenceHandler(BaseHandler):
         else:
             return UserPresenceCache()
 
+    def _bump_serial(self, user=None):
+        self._user_cachemap_latest_serial += 1
+
+        if user:
+            # Move to end
+            cache = self._user_cachemap.pop(user)
+            self._user_cachemap[user] = cache
+
     def registered_user(self, user):
         return self.store.create_presence(user.localpart)
 
@@ -278,15 +287,14 @@ class PresenceHandler(BaseHandler):
         now_online = state["presence"] != PresenceState.OFFLINE
         was_polling = target_user in self._user_cachemap
 
-        with PreserveLoggingContext():
-            if now_online and not was_polling:
-                self.start_polling_presence(target_user, state=state)
-            elif not now_online and was_polling:
-                self.stop_polling_presence(target_user)
+        if now_online and not was_polling:
+            self.start_polling_presence(target_user, state=state)
+        elif not now_online and was_polling:
+            self.stop_polling_presence(target_user)
 
-            # TODO(paul): perform a presence push as part of start/stop poll so
-            #   we don't have to do this all the time
-            self.changed_presencelike_data(target_user, state)
+        # TODO(paul): perform a presence push as part of start/stop poll so
+        #   we don't have to do this all the time
+        self.changed_presencelike_data(target_user, state)
 
     def bump_presence_active_time(self, user, now=None):
         if now is None:
@@ -301,7 +309,7 @@ class PresenceHandler(BaseHandler):
     def changed_presencelike_data(self, user, state):
         statuscache = self._get_or_make_usercache(user)
 
-        self._user_cachemap_latest_serial += 1
+        self._bump_serial(user=user)
         statuscache.update(state, serial=self._user_cachemap_latest_serial)
 
         return self.push_presence(user, statuscache=statuscache)
@@ -323,7 +331,7 @@ class PresenceHandler(BaseHandler):
 
             # No actual update but we need to bump the serial anyway for the
             # event source
-            self._user_cachemap_latest_serial += 1
+            self._bump_serial()
             statuscache.update({}, serial=self._user_cachemap_latest_serial)
 
             self.push_update_to_local_and_remote(
@@ -408,10 +416,10 @@ class PresenceHandler(BaseHandler):
         yield self.store.set_presence_list_accepted(
             observer_user.localpart, observed_user.to_string()
         )
-        with PreserveLoggingContext():
-            self.start_polling_presence(
-                observer_user, target_user=observed_user
-            )
+
+        self.start_polling_presence(
+            observer_user, target_user=observed_user
+        )
 
     @defer.inlineCallbacks
     def deny_presence(self, observed_user, observer_user):
@@ -430,10 +438,9 @@ class PresenceHandler(BaseHandler):
             observer_user.localpart, observed_user.to_string()
         )
 
-        with PreserveLoggingContext():
-            self.stop_polling_presence(
-                observer_user, target_user=observed_user
-            )
+        self.stop_polling_presence(
+            observer_user, target_user=observed_user
+        )
 
     @defer.inlineCallbacks
     def get_presence_list(self, observer_user, accepted=None):
@@ -706,7 +713,7 @@ class PresenceHandler(BaseHandler):
 
             statuscache = self._get_or_make_usercache(user)
 
-            self._user_cachemap_latest_serial += 1
+            self._bump_serial(user=user)
             statuscache.update(state, serial=self._user_cachemap_latest_serial)
 
             if not observers and not room_ids:
@@ -766,8 +773,7 @@ class PresenceHandler(BaseHandler):
                 if not self._remote_sendmap[user]:
                     del self._remote_sendmap[user]
 
-        with PreserveLoggingContext():
-            yield defer.DeferredList(deferreds, consumeErrors=True)
+        yield defer.DeferredList(deferreds, consumeErrors=True)
 
     @defer.inlineCallbacks
     def push_update_to_local_and_remote(self, observed_user, statuscache,
@@ -812,10 +818,11 @@ class PresenceHandler(BaseHandler):
 
     def push_update_to_clients(self, observed_user, users_to_push=[],
                                room_ids=[], statuscache=None):
-        self.notifier.on_new_user_event(
-            users_to_push,
-            room_ids,
-        )
+        with PreserveLoggingContext():
+            self.notifier.on_new_user_event(
+                users_to_push,
+                room_ids,
+            )
 
 
 class PresenceEventSource(object):
@@ -866,10 +873,15 @@ class PresenceEventSource(object):
 
         updates = []
         # TODO(paul): use a DeferredList ? How to limit concurrency.
-        for observed_user in cachemap.keys():
+        for observed_user in reversed(cachemap.keys()):
             cached = cachemap[observed_user]
 
-            if cached.serial <= from_key or cached.serial > max_serial:
+            # Since this is ordered in descending order of serial, we can just
+            # stop once we've seen enough
+            if cached.serial <= from_key:
+                break
+
+            if cached.serial > max_serial:
                 continue
 
             if not (yield self.is_visible(observer_user, observed_user)):
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index ee2732b848..71ff78ab23 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,8 +17,8 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.api.constants import EventTypes, Membership
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import UserID
+from synapse.util import unwrapFirstError
 
 from ._base import BaseHandler
 
@@ -154,14 +154,13 @@ class ProfileHandler(BaseHandler):
         if not self.hs.is_mine(user):
             defer.returnValue(None)
 
-        with PreserveLoggingContext():
-            (displayname, avatar_url) = yield defer.gatherResults(
-                [
-                    self.store.get_profile_displayname(user.localpart),
-                    self.store.get_profile_avatar_url(user.localpart),
-                ],
-                consumeErrors=True
-            )
+        (displayname, avatar_url) = yield defer.gatherResults(
+            [
+                self.store.get_profile_displayname(user.localpart),
+                self.store.get_profile_avatar_url(user.localpart),
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)
 
         state["displayname"] = displayname
         state["avatar_url"] = avatar_url
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index cfa2e38ed2..dac683616a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -21,7 +21,7 @@ from ._base import BaseHandler
 from synapse.types import UserID, RoomAlias, RoomID
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import StoreError, SynapseError
-from synapse.util import stringutils
+from synapse.util import stringutils, unwrapFirstError
 from synapse.util.async import run_on_reactor
 from synapse.events.utils import serialize_event
 
@@ -537,7 +537,7 @@ class RoomListHandler(BaseHandler):
                 for room in chunk
             ],
             consumeErrors=True,
-        )
+        ).addErrback(unwrapFirstError)
 
         for i, room in enumerate(chunk):
             room["num_joined_members"] = len(results[i])
@@ -577,8 +577,8 @@ class RoomEventSource(object):
 
         defer.returnValue((events, end_key))
 
-    def get_current_key(self):
-        return self.store.get_room_events_max_id()
+    def get_current_key(self, direction='f'):
+        return self.store.get_room_events_max_id(direction)
 
     @defer.inlineCallbacks
     def get_pagination_rows(self, user, config, key):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c0b2bd7db0..64fe51aa3e 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from ._base import BaseHandler
 
 from synapse.api.errors import SynapseError, AuthError
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import UserID
 
 import logging
@@ -216,7 +217,8 @@ class TypingNotificationHandler(BaseHandler):
         self._latest_room_serial += 1
         self._room_serials[room_id] = self._latest_room_serial
 
-        self.notifier.on_new_user_event(rooms=[room_id])
+        with PreserveLoggingContext():
+            self.notifier.on_new_user_event(rooms=[room_id])
 
 
 class TypingNotificationEventSource(object):
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e8a5dedab4..5b3cefb2dc 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.api.errors import CodeMessageException
+from synapse.util.logcontext import preserve_context_over_fn
 from syutil.jsonutil import encode_canonical_json
 import synapse.metrics
 
@@ -61,7 +62,10 @@ class SimpleHttpClient(object):
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
         outgoing_requests_counter.inc(method)
-        d = self.agent.request(method, *args, **kwargs)
+        d = preserve_context_over_fn(
+            self.agent.request,
+            method, *args, **kwargs
+        )
 
         def _cb(response):
             incoming_responses_counter.inc(method, response.code)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7fa295cad5..c99d237c73 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
 
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util.async import sleep
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
 import synapse.metrics
 
 from syutil.jsonutil import encode_canonical_json
@@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object):
                 producer = body_callback(method, url_bytes, headers_dict)
 
             try:
-                with PreserveLoggingContext():
-                    request_deferred = self.agent.request(
-                        destination,
-                        endpoint,
-                        method,
-                        path_bytes,
-                        param_bytes,
-                        query_bytes,
-                        Headers(headers_dict),
-                        producer
-                    )
+                request_deferred = preserve_context_over_fn(
+                    self.agent.request,
+                    destination,
+                    endpoint,
+                    method,
+                    path_bytes,
+                    param_bytes,
+                    query_bytes,
+                    Headers(headers_dict),
+                    producer
+                )
 
-                    response = yield self.clock.time_bound_deferred(
-                        request_deferred,
-                        time_out=60,
-                    )
+                response = yield self.clock.time_bound_deferred(
+                    request_deferred,
+                    time_out=60,
+                )
 
                 logger.debug("Got response to %s", method)
                 break
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 93ecbd7589..73efbff4f2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -17,7 +17,7 @@
 from synapse.api.errors import (
     cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
 )
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 import synapse.metrics
 
 from syutil.jsonutil import (
@@ -85,7 +85,9 @@ def request_handler(request_handler):
                     "Received request: %s %s",
                     request.method, request.path
                 )
-                yield request_handler(self, request)
+                d = request_handler(self, request)
+                with PreserveLoggingContext():
+                    yield d
                 code = request.code
             except CodeMessageException as e:
                 code = e.code
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 78eb28e4b2..7282dfd7f3 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -16,7 +16,6 @@
 from twisted.internet import defer
 
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import StreamToken
 import synapse.metrics
 
@@ -223,11 +222,10 @@ class Notifier(object):
         def eb(failure):
             logger.exception("Failed to notify listener", failure)
 
-        with PreserveLoggingContext():
-            yield defer.DeferredList(
-                [notify(l).addErrback(eb) for l in listeners],
-                consumeErrors=True,
-            )
+        yield defer.DeferredList(
+            [notify(l).addErrback(eb) for l in listeners],
+            consumeErrors=True,
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -298,11 +296,10 @@ class Notifier(object):
                     failure.getTracebackObject())
             )
 
-        with PreserveLoggingContext():
-            yield defer.DeferredList(
-                [notify(l).addErrback(eb) for l in listeners],
-                consumeErrors=True,
-            )
+        yield defer.DeferredList(
+            [notify(l).addErrback(eb) for l in listeners],
+            consumeErrors=True,
+        )
 
     @defer.inlineCallbacks
     def wait_for_events(self, user, rooms, filter, timeout, callback):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ee5587c721..c9fe5a3555 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,7 @@ from synapse.api.errors import StoreError
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
 
@@ -308,6 +308,7 @@ class SQLBaseStore(object):
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
+        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -419,10 +420,11 @@ class SQLBaseStore(object):
                     self._txn_perf_counters.update(desc, start, end)
                     sql_txn_timer.inc_by(duration, desc)
 
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
+
         for after_callback, after_args in after_callbacks:
             after_callback(*after_args)
         defer.returnValue(result)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 74b4e23590..a1982dfbb5 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -79,6 +79,28 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
+    def get_oldest_events_with_depth_in_room(self, room_id):
+        return self.runInteraction(
+            "get_oldest_events_with_depth_in_room",
+            self.get_oldest_events_with_depth_in_room_txn,
+            room_id,
+        )
+
+    def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
+        sql = (
+            "SELECT b.event_id, MAX(e.depth) FROM events as e"
+            " INNER JOIN event_edges as g"
+            " ON g.event_id = e.event_id AND g.room_id = e.room_id"
+            " INNER JOIN event_backward_extremities as b"
+            " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
+            " WHERE b.room_id = ? AND g.is_state is ?"
+            " GROUP BY b.event_id"
+        )
+
+        txn.execute(sql, (room_id, False,))
+
+        return dict(txn.fetchall())
+
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
             txn,
@@ -247,11 +269,13 @@ class EventFederationStore(SQLBaseStore):
         do_insert = depth < min_depth if min_depth else True
 
         if do_insert:
-            self._simple_insert_txn(
+            self._simple_upsert_txn(
                 txn,
                 table="room_depth",
-                values={
+                keyvalues={
                     "room_id": room_id,
+                },
+                values={
                     "min_depth": depth,
                 },
             )
@@ -306,31 +330,27 @@ class EventFederationStore(SQLBaseStore):
 
                 txn.execute(query, (event_id, room_id))
 
-            # Insert all the prev_events as a backwards thing, they'll get
-            # deleted in a second if they're incorrect anyway.
-            self._simple_insert_many_txn(
-                txn,
-                table="event_backward_extremities",
-                values=[
-                    {
-                        "event_id": e_id,
-                        "room_id": room_id,
-                    }
-                    for e_id, _ in prev_events
-                ],
+            query = (
+                "INSERT INTO event_backward_extremities (event_id, room_id)"
+                " SELECT ?, ? WHERE NOT EXISTS ("
+                " SELECT 1 FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
+                " )"
+                " AND NOT EXISTS ("
+                " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
+                " )"
             )
 
-            # Also delete from the backwards extremities table all ones that
-            # reference events that we have already seen
+            txn.executemany(query, [
+                (e_id, room_id, e_id, room_id, e_id, room_id, )
+                for e_id, _ in prev_events
+            ])
+
             query = (
-                "DELETE FROM event_backward_extremities WHERE EXISTS ("
-                "SELECT 1 FROM events "
-                "WHERE "
-                "event_backward_extremities.event_id = events.event_id "
-                "AND not events.outlier "
-                ")"
+                "DELETE FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
             )
-            txn.execute(query)
+            txn.execute(query, (event_id, room_id))
 
             txn.call_after(
                 self.get_latest_event_ids_in_room.invalidate, room_id
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 38395c66ab..a5a6869079 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,6 +23,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
 
 from syutil.base64util import decode_base64
 from syutil.jsonutil import encode_canonical_json
+from contextlib import contextmanager
 
 import logging
 
@@ -41,17 +42,25 @@ class EventsStore(SQLBaseStore):
             self.min_token -= 1
             stream_ordering = self.min_token
 
+        if stream_ordering is None:
+            stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+        else:
+            @contextmanager
+            def stream_ordering_manager():
+                yield stream_ordering
+
         try:
-            yield self.runInteraction(
-                "persist_event",
-                self._persist_event_txn,
-                event=event,
-                context=context,
-                backfilled=backfilled,
-                stream_ordering=stream_ordering,
-                is_new_state=is_new_state,
-                current_state=current_state,
-            )
+            with stream_ordering_manager as stream_ordering:
+                yield self.runInteraction(
+                    "persist_event",
+                    self._persist_event_txn,
+                    event=event,
+                    context=context,
+                    backfilled=backfilled,
+                    stream_ordering=stream_ordering,
+                    is_new_state=is_new_state,
+                    current_state=current_state,
+                )
         except _RollbackButIsFineException:
             pass
 
@@ -95,15 +104,6 @@ class EventsStore(SQLBaseStore):
         # Remove the any existing cache entries for the event_id
         txn.call_after(self._invalidate_get_event_cache, event.event_id)
 
-        if stream_ordering is None:
-            with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
-                return self._persist_event_txn(
-                    txn, event, context, backfilled,
-                    stream_ordering=stream_ordering,
-                    is_new_state=is_new_state,
-                    current_state=current_state,
-                )
-
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
@@ -135,19 +135,17 @@ class EventsStore(SQLBaseStore):
         outlier = event.internal_metadata.is_outlier()
 
         if not outlier:
-            self._store_state_groups_txn(txn, event, context)
-
             self._update_min_depth_for_room_txn(
                 txn,
                 event.room_id,
                 event.depth
             )
 
-        have_persisted = self._simple_select_one_onecol_txn(
+        have_persisted = self._simple_select_one_txn(
             txn,
-            table="event_json",
+            table="events",
             keyvalues={"event_id": event.event_id},
-            retcol="event_id",
+            retcols=["event_id", "outlier"],
             allow_none=True,
         )
 
@@ -162,7 +160,9 @@ class EventsStore(SQLBaseStore):
         # if we are persisting an event that we had persisted as an outlier,
         # but is no longer one.
         if have_persisted:
-            if not outlier:
+            if not outlier and have_persisted["outlier"]:
+                self._store_state_groups_txn(txn, event, context)
+
                 sql = (
                     "UPDATE event_json SET internal_metadata = ?"
                     " WHERE event_id = ?"
@@ -182,6 +182,9 @@ class EventsStore(SQLBaseStore):
                 )
             return
 
+        if not outlier:
+            self._store_state_groups_txn(txn, event, context)
+
         self._handle_prev_events(
             txn,
             outlier=outlier,
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index ee7718d5ed..34805e276e 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -19,7 +19,6 @@ from ._base import SQLBaseStore, Table
 from twisted.internet import defer
 
 import logging
-import copy
 import simplejson as json
 
 logger = logging.getLogger(__name__)
@@ -28,46 +27,45 @@ logger = logging.getLogger(__name__)
 class PushRuleStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_push_rules_for_user(self, user_name):
-        sql = (
-            "SELECT "+",".join(PushRuleTable.fields)+" "
-            "FROM "+PushRuleTable.table_name+" "
-            "WHERE user_name = ? "
-            "ORDER BY priority_class DESC, priority DESC"
+        rows = yield self._simple_select_list(
+            table=PushRuleTable.table_name,
+            keyvalues={
+                "user_name": user_name,
+            },
+            retcols=PushRuleTable.fields,
         )
-        rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
 
-        dicts = []
-        for r in rows:
-            d = {}
-            for i, f in enumerate(PushRuleTable.fields):
-                d[f] = r[i]
-            dicts.append(d)
+        rows.sort(
+            key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
+        )
 
-        defer.returnValue(dicts)
+        defer.returnValue(rows)
 
     @defer.inlineCallbacks
     def get_push_rules_enabled_for_user(self, user_name):
         results = yield self._simple_select_list(
-            PushRuleEnableTable.table_name,
-            {'user_name': user_name},
-            PushRuleEnableTable.fields,
+            table=PushRuleEnableTable.table_name,
+            keyvalues={
+                'user_name': user_name
+            },
+            retcols=PushRuleEnableTable.fields,
             desc="get_push_rules_enabled_for_user",
         )
-        defer.returnValue(
-            {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
-        )
+        defer.returnValue({
+            r['rule_id']: False if r['enabled'] == 0 else True for r in results
+        })
 
     @defer.inlineCallbacks
     def add_push_rule(self, before, after, **kwargs):
-        vals = copy.copy(kwargs)
+        vals = kwargs
         if 'conditions' in vals:
             vals['conditions'] = json.dumps(vals['conditions'])
         if 'actions' in vals:
             vals['actions'] = json.dumps(vals['actions'])
+
         # we could check the rest of the keys are valid column names
         # but sqlite will do that anyway so I think it's just pointless.
-        if 'id' in vals:
-            del vals['id']
+        vals.pop("id", None)
 
         if before or after:
             ret = yield self.runInteraction(
@@ -87,39 +85,39 @@ class PushRuleStore(SQLBaseStore):
             defer.returnValue(ret)
 
     def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
-        after = None
-        relative_to_rule = None
-        if 'after' in kwargs and kwargs['after']:
-            after = kwargs['after']
-            relative_to_rule = after
-        if 'before' in kwargs and kwargs['before']:
-            relative_to_rule = kwargs['before']
-
-        # get the priority of the rule we're inserting after/before
-        sql = (
-            "SELECT priority_class, priority FROM ? "
-            "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
+        after = kwargs.pop("after", None)
+        relative_to_rule = kwargs.pop("before", after)
+
+        res = self._simple_select_one_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            keyvalues={
+                "user_name": user_name,
+                "rule_id": relative_to_rule,
+            },
+            retcols=["priority_class", "priority"],
+            allow_none=True,
         )
-        txn.execute(sql, (user_name, relative_to_rule))
-        res = txn.fetchall()
+
         if not res:
             raise RuleNotFoundException(
                 "before/after rule not found: %s" % (relative_to_rule,)
             )
-        priority_class, base_rule_priority = res[0]
+
+        priority_class = res["priority_class"]
+        base_rule_priority = res["priority"]
 
         if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
             raise InconsistentRuleException(
                 "Given priority class does not match class of relative rule"
             )
 
-        new_rule = copy.copy(kwargs)
-        if 'before' in new_rule:
-            del new_rule['before']
-        if 'after' in new_rule:
-            del new_rule['after']
+        new_rule = kwargs
+        new_rule.pop("before", None)
+        new_rule.pop("after", None)
         new_rule['priority_class'] = priority_class
         new_rule['user_name'] = user_name
+        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
 
         # check if the priority before/after is free
         new_rule_priority = base_rule_priority
@@ -153,12 +151,11 @@ class PushRuleStore(SQLBaseStore):
 
             txn.execute(sql, (user_name, priority_class, new_rule_priority))
 
-        # now insert the new rule
-        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
-        sql += ",".join(new_rule.keys())+") VALUES ("
-        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
-        txn.execute(sql, new_rule.values())
+        self._simple_insert_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            values=new_rule,
+        )
 
     def _add_push_rule_highest_priority_txn(self, txn, user_name,
                                             priority_class, **kwargs):
@@ -176,18 +173,17 @@ class PushRuleStore(SQLBaseStore):
             new_prio = highest_prio + 1
 
         # and insert the new rule
-        new_rule = copy.copy(kwargs)
-        if 'id' in new_rule:
-            del new_rule['id']
+        new_rule = kwargs
+        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
         new_rule['user_name'] = user_name
         new_rule['priority_class'] = priority_class
         new_rule['priority'] = new_prio
 
-        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
-        sql += ",".join(new_rule.keys())+") VALUES ("
-        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
-        txn.execute(sql, new_rule.values())
+        self._simple_insert_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            values=new_rule,
+        )
 
     @defer.inlineCallbacks
     def delete_push_rule(self, user_name, rule_id):
@@ -211,7 +207,7 @@ class PushRuleStore(SQLBaseStore):
         yield self._simple_upsert(
             PushRuleEnableTable.table_name,
             {'user_name': user_name, 'rule_id': rule_id},
-            {'enabled': enabled},
+            {'enabled': 1 if enabled else 0},
             desc="set_push_rule_enabled",
         )
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 280d4ad605..8045e17fd7 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,11 +37,9 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore
 from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function
 
-from collections import namedtuple
-
 import logging
 
 
@@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
-class _StreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
-
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, follewed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
-    """
-    __slots__ = []
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == 't':
-                parts = string[1:].split('-', 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
+def lower_bound(token):
+    if token.topological is None:
+        return "(%d < %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d < %s OR (%d = %s AND %d < %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
 
-    def lower_bound(self):
-        if self.topological is None:
-            return "(%d < %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d < %s OR (%d = %s AND %d < %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
 
-    def upper_bound(self):
-        if self.topological is None:
-            return "(%d >= %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d > %s OR (%d = %s AND %d >= %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
+def upper_bound(token):
+    if token.topological is None:
+        return "(%d >= %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
 
 
 class StreamStore(SQLBaseStore):
@@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE
 
         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)
 
         if from_key == to_key:
             defer.returnValue(([], to_key))
@@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE
 
         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)
 
         if from_key == to_key:
             return defer.succeed(([], to_key))
@@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
         args = [False, room_id]
         if direction == 'b':
             order = "DESC"
-            bounds = _StreamToken.parse(from_key).upper_bound()
+            bounds = upper_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).lower_bound()
+                    bounds, lower_bound(RoomStreamToken.parse(to_key))
                 )
         else:
             order = "ASC"
-            bounds = _StreamToken.parse(from_key).lower_bound()
+            bounds = lower_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).upper_bound()
+                    bounds, upper_bound(RoomStreamToken.parse(to_key))
                 )
 
         if int(limit) > 0:
@@ -333,7 +281,7 @@ class StreamStore(SQLBaseStore):
                     # when we are going backwards so we subtract one from the
                     # stream part.
                     toke -= 1
-                next_token = str(_StreamToken(topo, toke))
+                next_token = str(RoomStreamToken(topo, toke))
             else:
                 # TODO (erikj): We should work out what to do here instead.
                 next_token = to_key if to_key else from_key
@@ -354,7 +302,7 @@ class StreamStore(SQLBaseStore):
                                    with_feedback=False, from_token=None):
         # TODO (erikj): Handle compressed feedback
 
-        end_token = _StreamToken.parse_stream_token(end_token)
+        end_token = RoomStreamToken.parse_stream_token(end_token)
 
         if from_token is None:
             sql = (
@@ -365,7 +313,7 @@ class StreamStore(SQLBaseStore):
                 " LIMIT ?"
             )
         else:
-            from_token = _StreamToken.parse_stream_token(from_token)
+            from_token = RoomStreamToken.parse_stream_token(from_token)
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
@@ -395,7 +343,7 @@ class StreamStore(SQLBaseStore):
                 # stream part.
                 topo = rows[0]["topological_ordering"]
                 toke = rows[0]["stream_ordering"] - 1
-                start_token = str(_StreamToken(topo, toke))
+                start_token = str(RoomStreamToken(topo, toke))
 
                 token = (start_token, str(end_token))
             else:
@@ -416,9 +364,25 @@ class StreamStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_room_events_max_id(self):
+    def get_room_events_max_id(self, direction='f'):
         token = yield self._stream_id_gen.get_max_token(self)
-        defer.returnValue("s%d" % (token,))
+        if direction != 'b':
+            defer.returnValue("s%d" % (token,))
+        else:
+            topo = yield self.runInteraction(
+                "_get_max_topological_txn", self._get_max_topological_txn
+            )
+            defer.returnValue("t%d-%d" % (topo, token))
+
+    def _get_max_topological_txn(self, txn):
+        txn.execute(
+            "SELECT MAX(topological_ordering) FROM events"
+            " WHERE outlier = ?",
+            (False,)
+        )
+
+        rows = txn.fetchall()
+        return rows[0][0] if rows else 0
 
     @defer.inlineCallbacks
     def _get_min_token(self):
@@ -439,5 +403,5 @@ class StreamStore(SQLBaseStore):
             stream = row["stream_ordering"]
             topo = event.depth
             internal = event.internal_metadata
-            internal.before = str(_StreamToken(topo, stream - 1))
-            internal.after = str(_StreamToken(topo, stream))
+            internal.before = str(RoomStreamToken(topo, stream - 1))
+            internal.after = str(RoomStreamToken(topo, stream))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e40eb8a8c4..89d1643f10 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -78,14 +78,18 @@ class StreamIdGenerator(object):
         self._current_max = None
         self._unfinished_ids = deque()
 
-    def get_next_txn(self, txn):
+    @defer.inlineCallbacks
+    def get_next(self, store):
         """
         Usage:
-            with stream_id_gen.get_next_txn(txn) as stream_id:
+            with yield stream_id_gen.get_next as stream_id:
                 # ... persist event ...
         """
         if not self._current_max:
-            self._get_or_compute_current_max(txn)
+            yield store.runInteraction(
+                "_compute_current_max",
+                self._get_or_compute_current_max,
+            )
 
         with self._lock:
             self._current_max += 1
@@ -101,7 +105,7 @@ class StreamIdGenerator(object):
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        defer.returnValue(manager())
 
     @defer.inlineCallbacks
     def get_max_token(self, store):
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 5c8e54b78b..dff7970bea 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -31,7 +31,7 @@ class NullSource(object):
     def get_new_events_for_user(self, user, from_key, limit):
         return defer.succeed(([], from_key))
 
-    def get_current_key(self):
+    def get_current_key(self, direction='f'):
         return defer.succeed(0)
 
     def get_pagination_rows(self, user, pagination_config, key):
@@ -52,10 +52,10 @@ class EventSources(object):
         }
 
     @defer.inlineCallbacks
-    def get_current_token(self):
+    def get_current_token(self, direction='f'):
         token = StreamToken(
             room_key=(
-                yield self.sources["room"].get_current_key()
+                yield self.sources["room"].get_current_key(direction)
             ),
             presence_key=(
                 yield self.sources["presence"].get_current_key()
diff --git a/synapse/types.py b/synapse/types.py
index f6a1b0bbcf..0f16867d75 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -121,4 +121,56 @@ class StreamToken(
         return StreamToken(**d)
 
 
+class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
+    """Tokens are positions between events. The token "s1" comes after event 1.
+
+            s0    s1
+            |     |
+        [0] V [1] V [2]
+
+    Tokens can either be a point in the live event stream or a cursor going
+    through historic events.
+
+    When traversing the live event stream events are ordered by when they
+    arrived at the homeserver.
+
+    When traversing historic events the events are ordered by their depth in
+    the event graph "topological_ordering" and then by when they arrived at the
+    homeserver "stream_ordering".
+
+    Live tokens start with an "s" followed by the "stream_ordering" id of the
+    event it comes after. Historic tokens start with a "t" followed by the
+    "topological_ordering" id of the event it comes after, follewed by "-",
+    followed by the "stream_ordering" id of the event it comes after.
+    """
+    __slots__ = []
+
+    @classmethod
+    def parse(cls, string):
+        try:
+            if string[0] == 's':
+                return cls(topological=None, stream=int(string[1:]))
+            if string[0] == 't':
+                parts = string[1:].split('-', 1)
+                return cls(topological=int(parts[0]), stream=int(parts[1]))
+        except:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    @classmethod
+    def parse_stream_token(cls, string):
+        try:
+            if string[0] == 's':
+                return cls(topological=None, stream=int(string[1:]))
+        except:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    def __str__(self):
+        if self.topological is not None:
+            return "t%d-%d" % (self.topological, self.stream)
+        else:
+            return "s%d" % (self.stream,)
+
+
 ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 79109d0b19..c1a16b639a 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
 
@@ -23,6 +23,12 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+def unwrapFirstError(failure):
+    # defer.gatherResults and DeferredLists wrap failures.
+    failure.trap(defer.FirstError)
+    return failure.value.subFailure
+
+
 class Clock(object):
     """A small utility that obtains current time-of-day so that time may be
     mocked during unit-tests.
@@ -50,9 +56,12 @@ class Clock(object):
         current_context = LoggingContext.current_context()
 
         def wrapped_callback():
-            LoggingContext.thread_local.current_context = current_context
-            callback()
-        return reactor.callLater(delay, wrapped_callback)
+            with PreserveLoggingContext():
+                LoggingContext.thread_local.current_context = current_context
+                callback()
+
+        with PreserveLoggingContext():
+            return reactor.callLater(delay, wrapped_callback)
 
     def cancel_call_later(self, timer):
         timer.cancel()
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 34acb14a6f..1c2044e5b4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
 
 
-@defer.inlineCallbacks
 def sleep(seconds):
     d = defer.Deferred()
     reactor.callLater(seconds, d.callback, seconds)
-    with PreserveLoggingContext():
-        yield d
+    return preserve_context_over_deferred(d)
 
 
 def run_on_reactor():
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 9d9c350397..064c4a7a1e 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import PreserveLoggingContext
-
 from twisted.internet import defer
 
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred,
+)
+
+from synapse.util import unwrapFirstError
+
 import logging
 
 
@@ -93,7 +97,6 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -101,24 +104,28 @@ class Signal(object):
 
         Returns a Deferred that will complete when all the observers have
         completed."""
+
+        def do(observer):
+            def eb(failure):
+                logger.warning(
+                    "%s signal observer %s failed: %r",
+                    self.name, observer, failure,
+                    exc_info=(
+                        failure.type,
+                        failure.value,
+                        failure.getTracebackObject()))
+                if not self.suppress_failures:
+                    return failure
+            return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+
         with PreserveLoggingContext():
-            deferreds = []
-            for observer in self.observers:
-                d = defer.maybeDeferred(observer, *args, **kwargs)
-
-                def eb(failure):
-                    logger.warning(
-                        "%s signal observer %s failed: %r",
-                        self.name, observer, failure,
-                        exc_info=(
-                            failure.type,
-                            failure.value,
-                            failure.getTracebackObject()))
-                    if not self.suppress_failures:
-                        failure.raiseException()
-                deferreds.append(d.addErrback(eb))
-            results = []
-            for deferred in deferreds:
-                result = yield deferred
-                results.append(result)
-        defer.returnValue(results)
+            deferreds = [
+                do(observer)
+                for observer in self.observers
+            ]
+
+            d = defer.gatherResults(deferreds, consumeErrors=True)
+
+        d.addErrback(unwrapFirstError)
+
+        return preserve_context_over_deferred(d)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..a92d518b43 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 import threading
 import logging
 
@@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
         LoggingContext.thread_local.current_context = self.current_context
+
+        if self.current_context is not LoggingContext.sentinel:
+            if self.current_context.parent_context is None:
+                logger.warn(
+                    "Restoring dead context: %s",
+                    self.current_context,
+                )
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+    """Takes a function and invokes it with the given arguments, but removes
+    and restores the current logging context while doing so.
+
+    If the result is a deferred, call preserve_context_over_deferred before
+    returning it.
+    """
+    with PreserveLoggingContext():
+        res = fn(*args, **kwargs)
+
+    if isinstance(res, defer.Deferred):
+        return preserve_context_over_deferred(res)
+    else:
+        return res
+
+
+def preserve_context_over_deferred(deferred):
+    """Given a deferred wrap it such that any callbacks added later to it will
+    be invoked with the current context.
+    """
+    d = defer.Deferred()
+
+    current_context = LoggingContext.current_context()
+
+    def cb(res):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.callback(res)
+        return res
+
+    def eb(failure):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.errback(failure)
+        return res
+
+    if deferred.called:
+        return deferred
+
+    deferred.addCallbacks(cb, eb)
+    return d