summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py45
-rw-r--r--synapse/app/pusher.py16
-rw-r--r--synapse/app/synchrotron.py54
-rw-r--r--synapse/crypto/keyring.py28
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/federation/federation_client.py47
-rw-r--r--synapse/federation/transaction_queue.py381
-rw-r--r--synapse/handlers/__init__.py3
-rw-r--r--synapse/handlers/auth.py25
-rw-r--r--synapse/handlers/federation.py27
-rw-r--r--synapse/handlers/presence.py36
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/http/matrixfederationclient.py4
-rw-r--r--synapse/replication/resource.py21
-rw-r--r--synapse/replication/slave/storage/_base.py30
-rw-r--r--synapse/replication/slave/storage/directory.py2
-rw-r--r--synapse/rest/client/v1/admin.py8
-rw-r--r--synapse/rest/client/v1/base.py1
-rw-r--r--synapse/rest/client/v1/directory.py5
-rw-r--r--synapse/rest/client/v1/events.py11
-rw-r--r--synapse/rest/client/v1/initial_sync.py4
-rw-r--r--synapse/rest/client/v1/login.py162
-rw-r--r--synapse/rest/client/v1/profile.py12
-rw-r--r--synapse/rest/client/v1/register.py2
-rw-r--r--synapse/rest/client/v1/room.py48
-rw-r--r--synapse/rest/client/v2_alpha/sync.py2
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py4
-rw-r--r--synapse/server.py9
-rw-r--r--synapse/server.pyi4
-rw-r--r--synapse/storage/__init__.py8
-rw-r--r--synapse/storage/_base.py68
-rw-r--r--synapse/storage/directory.py37
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/presence.py32
-rw-r--r--synapse/storage/roommember.py12
-rw-r--r--synapse/storage/schema/delta/34/cache_stream.py46
-rw-r--r--synapse/util/logcontext.py1
-rw-r--r--synapse/util/metrics.py17
38 files changed, 707 insertions, 511 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 59db76debc..0db26fcfd7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -675,27 +675,18 @@ class Auth(object):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
 
-            user_prefix = "user_id = "
-            user = None
-            user_id = None
-            guest = False
-            for caveat in macaroon.caveats:
-                if caveat.caveat_id.startswith(user_prefix):
-                    user_id = caveat.caveat_id[len(user_prefix):]
-                    user = UserID.from_string(user_id)
-                elif caveat.caveat_id == "guest = true":
-                    guest = True
+            user_id = self.get_user_id_from_macaroon(macaroon)
+            user = UserID.from_string(user_id)
 
             self.validate_macaroon(
                 macaroon, rights, self.hs.config.expire_access_token,
                 user_id=user_id,
             )
 
-            if user is None:
-                raise AuthError(
-                    self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-                    errcode=Codes.UNKNOWN_TOKEN
-                )
+            guest = False
+            for caveat in macaroon.caveats:
+                if caveat.caveat_id == "guest = true":
+                    guest = True
 
             if guest:
                 ret = {
@@ -743,6 +734,29 @@ class Auth(object):
                 errcode=Codes.UNKNOWN_TOKEN
             )
 
+    def get_user_id_from_macaroon(self, macaroon):
+        """Retrieve the user_id given by the caveats on the macaroon.
+
+        Does *not* validate the macaroon.
+
+        Args:
+            macaroon (pymacaroons.Macaroon): The macaroon to validate
+
+        Returns:
+            (str) user id
+
+        Raises:
+            AuthError if there is no user_id caveat in the macaroon
+        """
+        user_prefix = "user_id = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(user_prefix):
+                return caveat.caveat_id[len(user_prefix):]
+        raise AuthError(
+            self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+            errcode=Codes.UNKNOWN_TOKEN
+        )
+
     def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.
@@ -754,6 +768,7 @@ class Auth(object):
             verify_expiry(bool): Whether to verify whether the macaroon has expired.
                 This should really always be True, but no clients currently implement
                 token refresh, so we can't enforce expiry yet.
+            user_id (str): The user_id required
         """
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index c8dde0fcb8..8d755a4b33 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -80,11 +80,6 @@ class PusherSlaveStore(
         DataStore.get_profile_displayname.__func__
     )
 
-    # XXX: This is a bit broken because we don't persist forgotten rooms
-    # in a way that they can be streamed. This means that we don't have a
-    # way to invalidate the forgotten rooms cache correctly.
-    # For now we expire the cache every 10 minutes.
-    BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
     who_forgot_in_room = (
         RoomMemberStore.__dict__["who_forgot_in_room"]
     )
@@ -168,7 +163,6 @@ class PusherServer(HomeServer):
         store = self.get_datastore()
         replication_url = self.config.worker_replication_url
         pusher_pool = self.get_pusherpool()
-        clock = self.get_clock()
 
         def stop_pusher(user_id, app_id, pushkey):
             key = "%s:%s" % (app_id, pushkey)
@@ -220,21 +214,11 @@ class PusherServer(HomeServer):
                     min_stream_id, max_stream_id, affected_room_ids
                 )
 
-        def expire_broken_caches():
-            store.who_forgot_in_room.invalidate_all()
-
-        next_expire_broken_caches_ms = 0
         while True:
             try:
                 args = store.stream_positions()
                 args["timeout"] = 30000
                 result = yield http_client.get_json(replication_url, args=args)
-                now_ms = clock.time_msec()
-                if now_ms > next_expire_broken_caches_ms:
-                    expire_broken_caches()
-                    next_expire_broken_caches_ms = (
-                        now_ms + store.BROKEN_CACHE_EXPIRY_MS
-                    )
                 yield store.process_replication(result)
                 poke_pushers(result)
             except:
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 215ccfd522..e3173533e2 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
 from synapse.http.server import JsonResource
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v1 import events
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
     BaseSlavedStore,
     ClientIpStore,  # After BaseSlavedStore because the constructor is different
 ):
-    # XXX: This is a bit broken because we don't persist forgotten rooms
-    # in a way that they can be streamed. This means that we don't have a
-    # way to invalidate the forgotten rooms cache correctly.
-    # For now we expire the cache every 10 minutes.
-    BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
     who_forgot_in_room = (
         RoomMemberStore.__dict__["who_forgot_in_room"]
     )
@@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
     get_presence_list_accepted = PresenceStore.__dict__[
         "get_presence_list_accepted"
     ]
+    get_presence_list_observers_accepted = PresenceStore.__dict__[
+        "get_presence_list_observers_accepted"
+    ]
+
 
 UPDATE_SYNCING_USERS_MS = 10 * 1000
 
 
 class SynchrotronPresence(object):
     def __init__(self, hs):
+        self.is_mine_id = hs.is_mine_id
         self.http_client = hs.get_simple_http_client()
         self.store = hs.get_datastore()
         self.user_to_num_current_syncs = {}
         self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
         self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
 
         active_presence = self.store.take_presence_startup_info()
         self.user_to_current_state = {
@@ -119,11 +121,13 @@ class SynchrotronPresence(object):
 
         reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
 
-    def set_state(self, user, state):
+    def set_state(self, user, state, ignore_status_msg=False):
         # TODO Hows this supposed to work?
         pass
 
     get_states = PresenceHandler.get_states.__func__
+    get_state = PresenceHandler.get_state.__func__
+    _get_interested_parties = PresenceHandler._get_interested_parties.__func__
     current_state_for_users = PresenceHandler.current_state_for_users.__func__
 
     @defer.inlineCallbacks
@@ -194,19 +198,39 @@ class SynchrotronPresence(object):
             self._need_to_send_sync = False
             yield self._send_syncing_users_now()
 
+    @defer.inlineCallbacks
+    def notify_from_replication(self, states, stream_id):
+        parties = yield self._get_interested_parties(
+            states, calculate_remote_hosts=False
+        )
+        room_ids_to_states, users_to_states, _ = parties
+
+        self.notifier.on_new_event(
+            "presence_key", stream_id, rooms=room_ids_to_states.keys(),
+            users=users_to_states.keys()
+        )
+
+    @defer.inlineCallbacks
     def process_replication(self, result):
         stream = result.get("presence", {"rows": []})
+        states = []
         for row in stream["rows"]:
             (
                 position, user_id, state, last_active_ts,
                 last_federation_update_ts, last_user_sync_ts, status_msg,
                 currently_active
             ) = row
-            self.user_to_current_state[user_id] = UserPresenceState(
+            state = UserPresenceState(
                 user_id, state, last_active_ts,
                 last_federation_update_ts, last_user_sync_ts, status_msg,
                 currently_active
             )
+            self.user_to_current_state[user_id] = state
+            states.append(state)
+
+        if states and "position" in stream:
+            stream_id = int(stream["position"])
+            yield self.notify_from_replication(states, stream_id)
 
 
 class SynchrotronTyping(object):
@@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
                 elif name == "client":
                     resource = JsonResource(self, canonical_json=False)
                     sync.register_servlets(self, resource)
+                    events.register_servlets(self, resource)
                     resources.update({
                         "/_matrix/client/r0": resource,
                         "/_matrix/client/unstable": resource,
                         "/_matrix/client/v2_alpha": resource,
+                        "/_matrix/client/api/v1": resource,
                     })
 
         root_resource = create_resource_tree(resources, Resource())
@@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
         http_client = self.get_simple_http_client()
         store = self.get_datastore()
         replication_url = self.config.worker_replication_url
-        clock = self.get_clock()
         notifier = self.get_notifier()
         presence_handler = self.get_presence_handler()
         typing_handler = self.get_typing_handler()
 
-        def expire_broken_caches():
-            store.who_forgot_in_room.invalidate_all()
-            store.get_presence_list_accepted.invalidate_all()
-
         def notify_from_stream(
             result, stream_name, stream_key, room=None, user=None
         ):
@@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
                 result, "typing", "typing_key", room="room_id"
             )
 
-        next_expire_broken_caches_ms = 0
         while True:
             try:
                 args = store.stream_positions()
                 args.update(typing_handler.stream_positions())
                 args["timeout"] = 30000
                 result = yield http_client.get_json(replication_url, args=args)
-                now_ms = clock.time_msec()
-                if now_ms > next_expire_broken_caches_ms:
-                    expire_broken_caches()
-                    next_expire_broken_caches_ms = (
-                        now_ms + store.BROKEN_CACHE_EXPIRY_MS
-                    )
                 yield store.process_replication(result)
                 typing_handler.process_replication(result)
-                presence_handler.process_replication(result)
+                yield presence_handler.process_replication(result)
                 notify(result)
             except:
                 logger.exception("Error replicating from %r", replication_url)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5012c10ee8..7cd11cfae7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -61,6 +61,10 @@ Attributes:
 """
 
 
+class KeyLookupError(ValueError):
+    pass
+
+
 class Keyring(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -363,7 +367,7 @@ class Keyring(object):
                     )
                 except Exception as e:
                     logger.info(
-                        "Unable to getting key %r for %r directly: %s %s",
+                        "Unable to get key %r for %r directly: %s %s",
                         key_ids, server_name,
                         type(e).__name__, str(e.message),
                     )
@@ -425,7 +429,7 @@ class Keyring(object):
         for response in responses:
             if (u"signatures" not in response
                     or perspective_name not in response[u"signatures"]):
-                raise ValueError(
+                raise KeyLookupError(
                     "Key response not signed by perspective server"
                     " %r" % (perspective_name,)
                 )
@@ -448,7 +452,7 @@ class Keyring(object):
                     list(response[u"signatures"][perspective_name]),
                     list(perspective_keys)
                 )
-                raise ValueError(
+                raise KeyLookupError(
                     "Response not signed with a known key for perspective"
                     " server %r" % (perspective_name,)
                 )
@@ -491,10 +495,10 @@ class Keyring(object):
 
             if (u"signatures" not in response
                     or server_name not in response[u"signatures"]):
-                raise ValueError("Key response not signed by remote server")
+                raise KeyLookupError("Key response not signed by remote server")
 
             if "tls_fingerprints" not in response:
-                raise ValueError("Key response missing TLS fingerprints")
+                raise KeyLookupError("Key response missing TLS fingerprints")
 
             certificate_bytes = crypto.dump_certificate(
                 crypto.FILETYPE_ASN1, tls_certificate
@@ -508,7 +512,7 @@ class Keyring(object):
                     response_sha256_fingerprints.add(fingerprint[u"sha256"])
 
             if sha256_fingerprint_b64 not in response_sha256_fingerprints:
-                raise ValueError("TLS certificate not allowed by fingerprints")
+                raise KeyLookupError("TLS certificate not allowed by fingerprints")
 
             response_keys = yield self.process_v2_response(
                 from_server=server_name,
@@ -560,14 +564,14 @@ class Keyring(object):
         server_name = response_json["server_name"]
         if only_from_server:
             if server_name != from_server:
-                raise ValueError(
+                raise KeyLookupError(
                     "Expected a response for server %r not %r" % (
                         from_server, server_name
                     )
                 )
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
-                raise ValueError(
+                raise KeyLookupError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
@@ -635,15 +639,15 @@ class Keyring(object):
 
         if ("signatures" not in response
                 or server_name not in response["signatures"]):
-            raise ValueError("Key response not signed by remote server")
+            raise KeyLookupError("Key response not signed by remote server")
 
         if "tls_certificate" not in response:
-            raise ValueError("Key response missing TLS certificate")
+            raise KeyLookupError("Key response missing TLS certificate")
 
         tls_certificate_b64 = response["tls_certificate"]
 
         if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
-            raise ValueError("TLS certificate doesn't match")
+            raise KeyLookupError("TLS certificate doesn't match")
 
         # Cache the result in the datastore.
 
@@ -659,7 +663,7 @@ class Keyring(object):
 
         for key_id in response["signatures"][server_name]:
             if key_id not in response["verify_keys"]:
-                raise ValueError(
+                raise KeyLookupError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index aab18d7f71..0e9fd902af 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -88,6 +88,8 @@ def prune_event(event):
 
     if "age_ts" in event.unsigned:
         allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
+    if "replaces_state" in event.unsigned:
+        allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
 
     return type(event)(
         allowed_fields,
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index da95c2ad6d..9ba3151713 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -51,10 +51,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
 sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
 
 
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
+
+
 class FederationClient(FederationBase):
     def __init__(self, hs):
         super(FederationClient, self).__init__(hs)
 
+        self.pdu_destination_tried = {}
+        self._clock.looping_call(
+            self._clear_tried_cache, 60 * 1000,
+        )
+
+    def _clear_tried_cache(self):
+        """Clear pdu_destination_tried cache"""
+        now = self._clock.time_msec()
+
+        old_dict = self.pdu_destination_tried
+        self.pdu_destination_tried = {}
+
+        for event_id, destination_dict in old_dict.items():
+            destination_dict = {
+                dest: time
+                for dest, time in destination_dict.items()
+                if time + PDU_RETRY_TIME_MS > now
+            }
+            if destination_dict:
+                self.pdu_destination_tried[event_id] = destination_dict
+
     def start_get_pdu_cache(self):
         self._get_pdu_cache = ExpiringCache(
             cache_name="get_pdu_cache",
@@ -240,8 +264,15 @@ class FederationClient(FederationBase):
             if ev:
                 defer.returnValue(ev)
 
+        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
         pdu = None
         for destination in destinations:
+            now = self._clock.time_msec()
+            last_attempt = pdu_attempts.get(destination, 0)
+            if last_attempt + PDU_RETRY_TIME_MS > now:
+                continue
+
             try:
                 limiter = yield get_retry_limiter(
                     destination,
@@ -269,25 +300,19 @@ class FederationClient(FederationBase):
 
                         break
 
-            except SynapseError as e:
-                logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
-                )
-                continue
-            except CodeMessageException as e:
-                if 400 <= e.code < 500:
-                    raise
+                pdu_attempts[destination] = now
 
+            except SynapseError as e:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
-                continue
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
             except Exception as e:
+                pdu_attempts[destination] = now
+
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
@@ -406,7 +431,7 @@ class FederationClient(FederationBase):
             events and the second is a list of event ids that we failed to fetch.
         """
         if return_local:
-            seen_events = yield self.store.get_events(event_ids)
+            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
             signed_events = seen_events.values()
         else:
             seen_events = yield self.store.have_events(event_ids)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 5787f854d4..cb2ef0210c 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -21,11 +21,11 @@ from .units import Transaction
 
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
 from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
+from synapse.util.metrics import measure_func
 import synapse.metrics
 
 import logging
@@ -51,7 +51,7 @@ class TransactionQueue(object):
 
         self.transport_layer = transport_layer
 
-        self._clock = hs.get_clock()
+        self.clock = hs.get_clock()
 
         # Is a mapping from destinations -> deferreds. Used to keep track
         # of which destinations have transactions in flight and when they are
@@ -82,7 +82,7 @@ class TransactionQueue(object):
         self.pending_failures_by_dest = {}
 
         # HACK to get unique tx id
-        self._next_txn_id = int(self._clock.time_msec())
+        self._next_txn_id = int(self.clock.time_msec())
 
     def can_send_to(self, destination):
         """Can we send messages to the given server?
@@ -119,266 +119,215 @@ class TransactionQueue(object):
         if not destinations:
             return
 
-        deferreds = []
-
         for destination in destinations:
-            deferred = defer.Deferred()
             self.pending_pdus_by_dest.setdefault(destination, []).append(
-                (pdu, deferred, order)
+                (pdu, order)
             )
 
-            def chain(failure):
-                if not deferred.called:
-                    deferred.errback(failure)
-
-            def log_failure(f):
-                logger.warn("Failed to send pdu to %s: %s", destination, f.value)
-
-            deferred.addErrback(log_failure)
-
-            with PreserveLoggingContext():
-                self._attempt_new_transaction(destination).addErrback(chain)
-
-            deferreds.append(deferred)
+            preserve_context_over_fn(
+                self._attempt_new_transaction, destination
+            )
 
-    # NO inlineCallbacks
     def enqueue_edu(self, edu):
         destination = edu.destination
 
         if not self.can_send_to(destination):
             return
 
-        deferred = defer.Deferred()
-        self.pending_edus_by_dest.setdefault(destination, []).append(
-            (edu, deferred)
-        )
+        self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
-        def chain(failure):
-            if not deferred.called:
-                deferred.errback(failure)
-
-        def log_failure(f):
-            logger.warn("Failed to send edu to %s: %s", destination, f.value)
-
-        deferred.addErrback(log_failure)
-
-        with PreserveLoggingContext():
-            self._attempt_new_transaction(destination).addErrback(chain)
-
-        return deferred
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
 
-    @defer.inlineCallbacks
     def enqueue_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
             return
 
-        deferred = defer.Deferred()
-
         if not self.can_send_to(destination):
             return
 
         self.pending_failures_by_dest.setdefault(
             destination, []
-        ).append(
-            (failure, deferred)
-        )
-
-        def chain(f):
-            if not deferred.called:
-                deferred.errback(f)
-
-        def log_failure(f):
-            logger.warn("Failed to send failure to %s: %s", destination, f.value)
-
-        deferred.addErrback(log_failure)
-
-        with PreserveLoggingContext():
-            self._attempt_new_transaction(destination).addErrback(chain)
+        ).append(failure)
 
-        yield deferred
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
 
     @defer.inlineCallbacks
-    @log_function
     def _attempt_new_transaction(self, destination):
         yield run_on_reactor()
+        while True:
+            # list of (pending_pdu, deferred, order)
+            if destination in self.pending_transactions:
+                # XXX: pending_transactions can get stuck on by a never-ending
+                # request at which point pending_pdus_by_dest just keeps growing.
+                # we need application-layer timeouts of some flavour of these
+                # requests
+                logger.debug(
+                    "TX [%s] Transaction already in progress",
+                    destination
+                )
+                return
 
-        # list of (pending_pdu, deferred, order)
-        if destination in self.pending_transactions:
-            # XXX: pending_transactions can get stuck on by a never-ending
-            # request at which point pending_pdus_by_dest just keeps growing.
-            # we need application-layer timeouts of some flavour of these
-            # requests
-            logger.debug(
-                "TX [%s] Transaction already in progress",
-                destination
-            )
-            return
-
-        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-        pending_edus = self.pending_edus_by_dest.pop(destination, [])
-        pending_failures = self.pending_failures_by_dest.pop(destination, [])
+            pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+            pending_edus = self.pending_edus_by_dest.pop(destination, [])
+            pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-        if pending_pdus:
-            logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                         destination, len(pending_pdus))
+            if pending_pdus:
+                logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                             destination, len(pending_pdus))
 
-        if not pending_pdus and not pending_edus and not pending_failures:
-            logger.debug("TX [%s] Nothing to send", destination)
-            return
+            if not pending_pdus and not pending_edus and not pending_failures:
+                logger.debug("TX [%s] Nothing to send", destination)
+                return
 
-        try:
-            self.pending_transactions[destination] = 1
+            yield self._send_new_transaction(
+                destination, pending_pdus, pending_edus, pending_failures
+            )
 
-            logger.debug("TX [%s] _attempt_new_transaction", destination)
+    @measure_func("_send_new_transaction")
+    @defer.inlineCallbacks
+    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
+                              pending_failures):
 
             # Sort based on the order field
-            pending_pdus.sort(key=lambda t: t[2])
-
+            pending_pdus.sort(key=lambda t: t[1])
             pdus = [x[0] for x in pending_pdus]
-            edus = [x[0] for x in pending_edus]
-            failures = [x[0].get_dict() for x in pending_failures]
-            deferreds = [
-                x[1]
-                for x in pending_pdus + pending_edus + pending_failures
-            ]
-
-            txn_id = str(self._next_txn_id)
-
-            limiter = yield get_retry_limiter(
-                destination,
-                self._clock,
-                self.store,
-            )
+            edus = pending_edus
+            failures = [x.get_dict() for x in pending_failures]
 
-            logger.debug(
-                "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
-                destination, txn_id,
-                len(pending_pdus),
-                len(pending_edus),
-                len(pending_failures)
-            )
+            try:
+                self.pending_transactions[destination] = 1
 
-            logger.debug("TX [%s] Persisting transaction...", destination)
+                logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self._clock.time_msec()),
-                transaction_id=txn_id,
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
+                txn_id = str(self._next_txn_id)
 
-            self._next_txn_id += 1
+                limiter = yield get_retry_limiter(
+                    destination,
+                    self.clock,
+                    self.store,
+                )
 
-            yield self.transaction_actions.prepare_to_send(transaction)
+                logger.debug(
+                    "TX [%s] {%s} Attempting new transaction"
+                    " (pdus: %d, edus: %d, failures: %d)",
+                    destination, txn_id,
+                    len(pending_pdus),
+                    len(pending_edus),
+                    len(pending_failures)
+                )
 
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
-                destination, txn_id,
-                transaction.transaction_id,
-                len(pending_pdus),
-                len(pending_edus),
-                len(pending_failures),
-            )
+                logger.debug("TX [%s] Persisting transaction...", destination)
 
-            with limiter:
-                # Actually send the transaction
-
-                # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                # keys work
-                def json_data_cb():
-                    data = transaction.get_dict()
-                    now = int(self._clock.time_msec())
-                    if "pdus" in data:
-                        for p in data["pdus"]:
-                            if "age_ts" in p:
-                                unsigned = p.setdefault("unsigned", {})
-                                unsigned["age"] = now - int(p["age_ts"])
-                                del p["age_ts"]
-                    return data
-
-                try:
-                    response = yield self.transport_layer.send_transaction(
-                        transaction, json_data_cb
-                    )
-                    code = 200
-
-                    if response:
-                        for e_id, r in response.get("pdus", {}).items():
-                            if "error" in r:
-                                logger.warn(
-                                    "Transaction returned error for %s: %s",
-                                    e_id, r,
-                                )
-                except HttpResponseException as e:
-                    code = e.code
-                    response = e.response
+                transaction = Transaction.create_new(
+                    origin_server_ts=int(self.clock.time_msec()),
+                    transaction_id=txn_id,
+                    origin=self.server_name,
+                    destination=destination,
+                    pdus=pdus,
+                    edus=edus,
+                    pdu_failures=failures,
+                )
+
+                self._next_txn_id += 1
+
+                yield self.transaction_actions.prepare_to_send(transaction)
 
+                logger.debug("TX [%s] Persisted transaction", destination)
                 logger.info(
-                    "TX [%s] {%s} got %d response",
-                    destination, txn_id, code
+                    "TX [%s] {%s} Sending transaction [%s],"
+                    " (PDUs: %d, EDUs: %d, failures: %d)",
+                    destination, txn_id,
+                    transaction.transaction_id,
+                    len(pending_pdus),
+                    len(pending_edus),
+                    len(pending_failures),
                 )
 
-                logger.debug("TX [%s] Sent transaction", destination)
-                logger.debug("TX [%s] Marking as delivered...", destination)
+                with limiter:
+                    # Actually send the transaction
+
+                    # FIXME (erikj): This is a bit of a hack to make the Pdu age
+                    # keys work
+                    def json_data_cb():
+                        data = transaction.get_dict()
+                        now = int(self.clock.time_msec())
+                        if "pdus" in data:
+                            for p in data["pdus"]:
+                                if "age_ts" in p:
+                                    unsigned = p.setdefault("unsigned", {})
+                                    unsigned["age"] = now - int(p["age_ts"])
+                                    del p["age_ts"]
+                        return data
+
+                    try:
+                        response = yield self.transport_layer.send_transaction(
+                            transaction, json_data_cb
+                        )
+                        code = 200
+
+                        if response:
+                            for e_id, r in response.get("pdus", {}).items():
+                                if "error" in r:
+                                    logger.warn(
+                                        "Transaction returned error for %s: %s",
+                                        e_id, r,
+                                    )
+                    except HttpResponseException as e:
+                        code = e.code
+                        response = e.response
+
+                    logger.info(
+                        "TX [%s] {%s} got %d response",
+                        destination, txn_id, code
+                    )
 
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
+                    logger.debug("TX [%s] Sent transaction", destination)
+                    logger.debug("TX [%s] Marking as delivered...", destination)
 
-            logger.debug("TX [%s] Marked as delivered", destination)
-
-            logger.debug("TX [%s] Yielding to callbacks...", destination)
-
-            for deferred in deferreds:
-                if code == 200:
-                    deferred.callback(None)
-                else:
-                    deferred.errback(RuntimeError("Got status %d" % code))
-
-                # Ensures we don't continue until all callbacks on that
-                # deferred have fired
-                try:
-                    yield deferred
-                except:
-                    pass
-
-            logger.debug("TX [%s] Yielded to callbacks", destination)
-        except NotRetryingDestination:
-            logger.info(
-                "TX [%s] not ready for retry yet - "
-                "dropping transaction for now",
-                destination,
-            )
-        except RuntimeError as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
+                yield self.transaction_actions.delivered(
+                    transaction, code, response
+                )
 
-            for deferred in deferreds:
-                if not deferred.called:
-                    deferred.errback(e)
+                logger.debug("TX [%s] Marked as delivered", destination)
+
+                if code != 200:
+                    for p in pdus:
+                        logger.info(
+                            "Failed to send event %s to %s", p.event_id, destination
+                        )
+            except NotRetryingDestination:
+                logger.info(
+                    "TX [%s] not ready for retry yet - "
+                    "dropping transaction for now",
+                    destination,
+                )
+            except RuntimeError as e:
+                # We capture this here as there as nothing actually listens
+                # for this finishing functions deferred.
+                logger.warn(
+                    "TX [%s] Problem in _attempt_transaction: %s",
+                    destination,
+                    e,
+                )
+
+                for p in pdus:
+                    logger.info("Failed to send event %s to %s", p.event_id, destination)
+            except Exception as e:
+                # We capture this here as there as nothing actually listens
+                # for this finishing functions deferred.
+                logger.warn(
+                    "TX [%s] Problem in _attempt_transaction: %s",
+                    destination,
+                    e,
+                )
 
-        finally:
-            # We want to be *very* sure we delete this after we stop processing
-            self.pending_transactions.pop(destination, None)
+                for p in pdus:
+                    logger.info("Failed to send event %s to %s", p.event_id, destination)
 
-            # Check to see if there is anything else to send.
-            self._attempt_new_transaction(destination)
+            finally:
+                # We want to be *very* sure we delete this after we stop processing
+                self.pending_transactions.pop(destination, None)
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 1a50a2ec98..63d05f2531 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -19,7 +19,6 @@ from .room import (
 )
 from .room_member import RoomMemberHandler
 from .message import MessageHandler
-from .events import EventStreamHandler, EventHandler
 from .federation import FederationHandler
 from .profile import ProfileHandler
 from .directory import DirectoryHandler
@@ -53,8 +52,6 @@ class Handlers(object):
         self.message_handler = MessageHandler(hs)
         self.room_creation_handler = RoomCreationHandler(hs)
         self.room_member_handler = RoomMemberHandler(hs)
-        self.event_stream_handler = EventStreamHandler(hs)
-        self.event_handler = EventHandler(hs)
         self.federation_handler = FederationHandler(hs)
         self.profile_handler = ProfileHandler(hs)
         self.directory_handler = DirectoryHandler(hs)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2e138f328f..a582d6334b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
             self.ldap_uri = hs.config.ldap_uri
             self.ldap_start_tls = hs.config.ldap_start_tls
             self.ldap_base = hs.config.ldap_base
-            self.ldap_filter = hs.config.ldap_filter
             self.ldap_attributes = hs.config.ldap_attributes
             if self.ldap_mode == LDAPMode.SEARCH:
                 self.ldap_bind_dn = hs.config.ldap_bind_dn
                 self.ldap_bind_password = hs.config.ldap_bind_password
+                self.ldap_filter = hs.config.ldap_filter
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.device_handler = hs.get_device_handler()
@@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
                 else:
                     logger.warn(
                         "ldap registration failed: unexpected (%d!=1) amount of results",
-                        len(result)
+                        len(conn.response)
                     )
                     defer.returnValue(False)
 
@@ -719,13 +719,14 @@ class AuthHandler(BaseHandler):
         return macaroon.serialize()
 
     def validate_short_term_login_token_and_get_user_id(self, login_token):
+        auth_api = self.hs.get_auth()
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            auth_api = self.hs.get_auth()
-            auth_api.validate_macaroon(macaroon, "login", True)
-            return self.get_user_from_macaroon(macaroon)
-        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
-            raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
+            user_id = auth_api.get_user_id_from_macaroon(macaroon)
+            auth_api.validate_macaroon(macaroon, "login", True, user_id)
+            return user_id
+        except Exception:
+            raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
     def _generate_base_macaroon(self, user_id):
         macaroon = pymacaroons.Macaroon(
@@ -736,16 +737,6 @@ class AuthHandler(BaseHandler):
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
 
-    def get_user_from_macaroon(self, macaroon):
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix):]
-        raise AuthError(
-            self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
-            errcode=Codes.UNKNOWN_TOKEN
-        )
-
     @defer.inlineCallbacks
     def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 618cb53629..ff6bb475b5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -249,7 +249,7 @@ class FederationHandler(BaseHandler):
                         if ev.type != EventTypes.Member:
                             continue
                         try:
-                            domain = UserID.from_string(ev.state_key).domain
+                            domain = get_domain_from_id(ev.state_key)
                         except:
                             continue
 
@@ -1093,16 +1093,17 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            # FIXME: This is a temporary work around where we occasionally
-            # return events slightly differently than when they were
-            # originally signed
-            event.signatures.update(
-                compute_event_signature(
-                    event,
-                    self.hs.hostname,
-                    self.hs.config.signing_key[0]
+            if self.hs.is_mine_id(event.event_id):
+                # FIXME: This is a temporary work around where we occasionally
+                # return events slightly differently than when they were
+                # originally signed
+                event.signatures.update(
+                    compute_event_signature(
+                        event,
+                        self.hs.hostname,
+                        self.hs.config.signing_key[0]
+                    )
                 )
-            )
 
             if do_auth:
                 in_room = yield self.auth.check_host_in_room(
@@ -1112,6 +1113,12 @@ class FederationHandler(BaseHandler):
                 if not in_room:
                     raise AuthError(403, "Host not in room.")
 
+                events = yield self._filter_events_for_server(
+                    origin, event.room_id, [event]
+                )
+
+                event = events[0]
+
             defer.returnValue(event)
         else:
             defer.returnValue(None)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6b70fa3817..6a1fe76c88 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -503,7 +503,7 @@ class PresenceHandler(object):
         defer.returnValue(states)
 
     @defer.inlineCallbacks
-    def _get_interested_parties(self, states):
+    def _get_interested_parties(self, states, calculate_remote_hosts=True):
         """Given a list of states return which entities (rooms, users, servers)
         are interested in the given states.
 
@@ -526,14 +526,15 @@ class PresenceHandler(object):
             users_to_states.setdefault(state.user_id, []).append(state)
 
         hosts_to_states = {}
-        for room_id, states in room_ids_to_states.items():
-            local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
-            if not local_states:
-                continue
+        if calculate_remote_hosts:
+            for room_id, states in room_ids_to_states.items():
+                local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+                if not local_states:
+                    continue
 
-            hosts = yield self.store.get_joined_hosts_for_room(room_id)
-            for host in hosts:
-                hosts_to_states.setdefault(host, []).extend(local_states)
+                hosts = yield self.store.get_joined_hosts_for_room(room_id)
+                for host in hosts:
+                    hosts_to_states.setdefault(host, []).extend(local_states)
 
         for user_id, states in users_to_states.items():
             local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
@@ -565,6 +566,16 @@ class PresenceHandler(object):
 
         self._push_to_remotes(hosts_to_states)
 
+    @defer.inlineCallbacks
+    def notify_for_states(self, state, stream_id):
+        parties = yield self._get_interested_parties([state])
+        room_ids_to_states, users_to_states, hosts_to_states = parties
+
+        self.notifier.on_new_event(
+            "presence_key", stream_id, rooms=room_ids_to_states.keys(),
+            users=[UserID.from_string(u) for u in users_to_states.keys()]
+        )
+
     def _push_to_remotes(self, hosts_to_states):
         """Sends state updates to remote servers.
 
@@ -672,7 +683,7 @@ class PresenceHandler(object):
             ])
 
     @defer.inlineCallbacks
-    def set_state(self, target_user, state):
+    def set_state(self, target_user, state, ignore_status_msg=False):
         """Set the presence state of the user.
         """
         status_msg = state.get("status_msg", None)
@@ -689,10 +700,13 @@ class PresenceHandler(object):
         prev_state = yield self.current_state_for_user(user_id)
 
         new_fields = {
-            "state": presence,
-            "status_msg": status_msg if presence != PresenceState.OFFLINE else None
+            "state": presence
         }
 
+        if not ignore_status_msg:
+            msg = status_msg if presence != PresenceState.OFFLINE else None
+            new_fields["status_msg"] = msg
+
         if presence == PresenceState.ONLINE:
             new_fields["last_active_ts"] = self.clock.time_msec()
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8cec8fc4ed..4709112a0c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -141,7 +141,7 @@ class RoomMemberHandler(BaseHandler):
             third_party_signed=None,
             ratelimit=True,
     ):
-        key = (target, room_id,)
+        key = (room_id,)
 
         with (yield self.member_linearizer.queue(key)):
             result = yield self._update_membership(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c3589534f8..f93093dd85 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
                             time_out=timeout / 1000. if timeout else 60,
                         )
 
-                    response = yield preserve_context_over_fn(
-                        send_request,
-                    )
+                    response = yield preserve_context_over_fn(send_request)
 
                     log_result = "%d %s" % (response.code, response.phrase,)
                     break
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 8c2d487ff4..84993b33b3 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -41,6 +41,7 @@ STREAM_NAMES = (
     ("push_rules",),
     ("pushers",),
     ("state",),
+    ("caches",),
 )
 
 
@@ -70,6 +71,7 @@ class ReplicationResource(Resource):
     * "backfill": Old events that have been backfilled from other servers.
     * "push_rules": Per user changes to push rules.
     * "pushers": Per user changes to their pushers.
+    * "caches": Cache invalidations.
 
     The API takes two additional query parameters:
 
@@ -129,6 +131,7 @@ class ReplicationResource(Resource):
         push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
         pushers_token = self.store.get_pushers_stream_token()
         state_token = self.store.get_state_stream_token()
+        caches_token = self.store.get_cache_stream_token()
 
         defer.returnValue(_ReplicationToken(
             room_stream_token,
@@ -140,6 +143,7 @@ class ReplicationResource(Resource):
             push_rules_token,
             pushers_token,
             state_token,
+            caches_token,
         ))
 
     @request_handler()
@@ -188,6 +192,7 @@ class ReplicationResource(Resource):
         yield self.push_rules(writer, current_token, limit, request_streams)
         yield self.pushers(writer, current_token, limit, request_streams)
         yield self.state(writer, current_token, limit, request_streams)
+        yield self.caches(writer, current_token, limit, request_streams)
         self.streams(writer, current_token, request_streams)
 
         logger.info("Replicated %d rows", writer.total)
@@ -379,6 +384,20 @@ class ReplicationResource(Resource):
                 "position", "type", "state_key", "event_id"
             ))
 
+    @defer.inlineCallbacks
+    def caches(self, writer, current_token, limit, request_streams):
+        current_position = current_token.caches
+
+        caches = request_streams.get("caches")
+
+        if caches is not None:
+            updated_caches = yield self.store.get_all_updated_caches(
+                caches, current_position, limit
+            )
+            writer.write_header_and_rows("caches", updated_caches, (
+                "position", "cache_func", "keys", "invalidation_ts"
+            ))
+
 
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -407,7 +426,7 @@ class _Writer(object):
 
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules", "pushers", "state"
+    "push_rules", "pushers", "state", "caches",
 ))):
     __slots__ = []
 
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 46e43ce1c7..d839d169ab 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,15 +14,43 @@
 # limitations under the License.
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
 from twisted.internet import defer
 
+from ._slaved_id_tracker import SlavedIdTracker
+
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class BaseSlavedStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
         super(BaseSlavedStore, self).__init__(hs)
+        if isinstance(self.database_engine, PostgresEngine):
+            self._cache_id_gen = SlavedIdTracker(
+                db_conn, "cache_invalidation_stream", "stream_id",
+            )
+        else:
+            self._cache_id_gen = None
 
     def stream_positions(self):
-        return {}
+        pos = {}
+        if self._cache_id_gen:
+            pos["caches"] = self._cache_id_gen.get_current_token()
+        return pos
 
     def process_replication(self, result):
+        stream = result.get("caches")
+        if stream:
+            for row in stream["rows"]:
+                (
+                    position, cache_func, keys, invalidation_ts,
+                ) = row
+
+                try:
+                    getattr(self, cache_func).invalidate(tuple(keys))
+                except AttributeError:
+                    logger.warn("Got unexpected cache_func: %r", cache_func)
+            self._cache_id_gen.advance(int(stream["position"]))
         return defer.succeed(None)
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 5fbe3a303a..7301d885f2 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
 class DirectoryStore(BaseSlavedStore):
     get_aliases_for_room = DirectoryStore.__dict__[
         "get_aliases_for_room"
-    ].orig
+    ]
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index b0cb31a448..af21661d7c 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
 class WhoisRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
 
+    def __init__(self, hs):
+        super(WhoisRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         target_user = UserID.from_string(user_id)
@@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
         "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
     )
 
+    def __init__(self, hs):
+        super(PurgeHistoryRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_id):
         requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index 96b49b01f2..c2a8447860 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
             hs (synapse.server.HomeServer):
         """
         self.hs = hs
-        self.handlers = hs.get_handlers()
         self.builder_factory = hs.get_event_builder_factory()
         self.auth = hs.get_v1auth()
         self.txns = HttpTransactionStore()
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 8ac09419dc..09d0831594 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
 class ClientDirectoryServer(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
 
+    def __init__(self, hs):
+        super(ClientDirectoryServer, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, room_alias):
         room_alias = RoomAlias.from_string(room_alias)
@@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
     def __init__(self, hs):
         super(ClientDirectoryListServer, self).__init__(hs)
         self.store = hs.get_datastore()
+        self.handlers = hs.get_handlers()
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 498bb9e18a..701b6f549b 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
 
     DEFAULT_LONGPOLL_TIME_MS = 30000
 
+    def __init__(self, hs):
+        super(EventStreamRestServlet, self).__init__(hs)
+        self.event_stream_handler = hs.get_event_stream_handler()
+
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(
@@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
         if "room_id" in request.args:
             room_id = request.args["room_id"][0]
 
-        handler = self.handlers.event_stream_handler
         pagin_config = PaginationConfig.from_request(request)
         timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
         if "timeout" in request.args:
@@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
 
         as_client_event = "raw" not in request.args
 
-        chunk = yield handler.get_stream(
+        chunk = yield self.event_stream_handler.get_stream(
             requester.user.to_string(),
             pagin_config,
             timeout=timeout,
@@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(EventRestServlet, self).__init__(hs)
         self.clock = hs.get_clock()
+        self.event_handler = hs.get_event_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request, event_id):
         requester = yield self.auth.get_user_by_req(request)
-        handler = self.handlers.event_handler
-        event = yield handler.get_event(requester.user, event_id)
+        event = yield self.event_handler.get_event(requester.user, event_id)
 
         time_now = self.clock.time_msec()
         if event:
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 36c3520567..113a49e539 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
 class InitialSyncRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/initialSync$")
 
+    def __init__(self, hs):
+        super(InitialSyncRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 92fcae674a..6c0eec8fb3 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet):
         self.jwt_secret = hs.config.jwt_secret
         self.jwt_algorithm = hs.config.jwt_algorithm
         self.cas_enabled = hs.config.cas_enabled
-        self.cas_server_url = hs.config.cas_server_url
-        self.cas_required_attributes = hs.config.cas_required_attributes
-        self.servername = hs.config.server_name
-        self.http_client = hs.get_simple_http_client()
         self.auth_handler = self.hs.get_auth_handler()
         self.device_handler = self.hs.get_device_handler()
+        self.handlers = hs.get_handlers()
 
     def on_GET(self, request):
         flows = []
@@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet):
                                        LoginRestServlet.JWT_TYPE):
                 result = yield self.do_jwt_login(login_submission)
                 defer.returnValue(result)
-            # TODO Delete this after all CAS clients switch to token login instead
-            elif self.cas_enabled and (login_submission["type"] ==
-                                       LoginRestServlet.CAS_TYPE):
-                uri = "%s/proxyValidate" % (self.cas_server_url,)
-                args = {
-                    "ticket": login_submission["ticket"],
-                    "service": login_submission["service"]
-                }
-                body = yield self.http_client.get_raw(uri, args)
-                result = yield self.do_cas_login(body)
-                defer.returnValue(result)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 result = yield self.do_token_login(login_submission)
                 defer.returnValue(result)
@@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet):
 
         defer.returnValue((200, result))
 
-    # TODO Delete this after all CAS clients switch to token login instead
-    @defer.inlineCallbacks
-    def do_cas_login(self, cas_response_body):
-        user, attributes = self.parse_cas_response(cas_response_body)
-
-        for required_attribute, required_value in self.cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in attributes:
-                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-        user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.auth_handler
-        registered_user_id = yield auth_handler.check_user_exists(user_id)
-        if registered_user_id:
-            access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(
-                    registered_user_id
-                )
-            )
-            result = {
-                "user_id": registered_user_id,  # may have changed
-                "access_token": access_token,
-                "refresh_token": refresh_token,
-                "home_server": self.hs.hostname,
-            }
-
-        else:
-            user_id, access_token = (
-                yield self.handlers.registration_handler.register(localpart=user)
-            )
-            result = {
-                "user_id": user_id,  # may have changed
-                "access_token": access_token,
-                "home_server": self.hs.hostname,
-            }
-
-        defer.returnValue((200, result))
-
     @defer.inlineCallbacks
     def do_jwt_login(self, login_submission):
         token = login_submission.get("token", None)
@@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet):
 
         defer.returnValue((200, result))
 
-    # TODO Delete this after all CAS clients switch to token login instead
-    def parse_cas_response(self, cas_response_body):
-        root = ET.fromstring(cas_response_body)
-        if not root.tag.endswith("serviceResponse"):
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not root[0].tag.endswith("authenticationSuccess"):
-            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
-        for child in root[0]:
-            if child.tag.endswith("user"):
-                user = child.text
-            if child.tag.endswith("attributes"):
-                attributes = {}
-                for attribute in child:
-                    # ElementTree library expands the namespace in attribute tags
-                    # to the full URL of the namespace.
-                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
-                    # We don't care about namespace here and it will always be encased in
-                    # curly braces, so we remove them.
-                    if "}" in attribute.tag:
-                        attributes[attribute.tag.split("}")[1]] = attribute.text
-                    else:
-                        attributes[attribute.tag] = attribute.text
-        if user is None or attributes is None:
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
-        return (user, attributes)
-
     def _register_device(self, user_id, login_submission):
         """Register a device for a user.
 
@@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(SAML2RestServlet, self).__init__(hs)
         self.sp_config = hs.config.saml2_config_path
+        self.handlers = hs.get_handlers()
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet):
         defer.returnValue((200, {"status": "not_authenticated"}))
 
 
-# TODO Delete this after all CAS clients switch to token login instead
-class CasRestServlet(ClientV1RestServlet):
-    PATTERNS = client_path_patterns("/login/cas", releases=())
-
-    def __init__(self, hs):
-        super(CasRestServlet, self).__init__(hs)
-        self.cas_server_url = hs.config.cas_server_url
-
-    def on_GET(self, request):
-        return (200, {"serverUrl": self.cas_server_url})
-
-
 class CasRedirectServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
 
@@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet):
         self.cas_server_url = hs.config.cas_server_url
         self.cas_service_url = hs.config.cas_service_url
         self.cas_required_attributes = hs.config.cas_required_attributes
+        self.auth_handler = hs.get_auth_handler()
+        self.handlers = hs.get_handlers()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet):
         return urlparse.urlunparse(url_parts)
 
     def parse_cas_response(self, cas_response_body):
-        root = ET.fromstring(cas_response_body)
-        if not root.tag.endswith("serviceResponse"):
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not root[0].tag.endswith("authenticationSuccess"):
-            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
-        for child in root[0]:
-            if child.tag.endswith("user"):
-                user = child.text
-            if child.tag.endswith("attributes"):
-                attributes = {}
-                for attribute in child:
-                    # ElementTree library expands the namespace in attribute tags
-                    # to the full URL of the namespace.
-                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
-                    # We don't care about namespace here and it will always be encased in
-                    # curly braces, so we remove them.
-                    if "}" in attribute.tag:
-                        attributes[attribute.tag.split("}")[1]] = attribute.text
-                    else:
-                        attributes[attribute.tag] = attribute.text
-        if user is None or attributes is None:
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
-        return (user, attributes)
+        user = None
+        attributes = None
+        try:
+            root = ET.fromstring(cas_response_body)
+            if not root.tag.endswith("serviceResponse"):
+                raise Exception("root of CAS response is not serviceResponse")
+            success = (root[0].tag.endswith("authenticationSuccess"))
+            for child in root[0]:
+                if child.tag.endswith("user"):
+                    user = child.text
+                if child.tag.endswith("attributes"):
+                    attributes = {}
+                    for attribute in child:
+                        # ElementTree library expands the namespace in
+                        # attribute tags to the full URL of the namespace.
+                        # We don't care about namespace here and it will always
+                        # be encased in curly braces, so we remove them.
+                        tag = attribute.tag
+                        if "}" in tag:
+                            tag = tag.split("}")[1]
+                        attributes[tag] = attribute.text
+            if user is None:
+                raise Exception("CAS response does not contain user")
+            if attributes is None:
+                raise Exception("CAS response does not contain attributes")
+        except Exception:
+            logger.error("Error parsing CAS response", exc_info=1)
+            raise LoginError(401, "Invalid CAS response",
+                             errcode=Codes.UNAUTHORIZED)
+        if not success:
+            raise LoginError(401, "Unsuccessful CAS response",
+                             errcode=Codes.UNAUTHORIZED)
+        return user, attributes
 
 
 def register_servlets(hs, http_server):
@@ -512,5 +426,3 @@ def register_servlets(hs, http_server):
     if hs.config.cas_enabled:
         CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-        CasRestServlet(hs).register(http_server)
-    # TODO PasswordResetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 65c4e2ebef..355e82474b 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
 class ProfileDisplaynameRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
 
+    def __init__(self, hs):
+        super(ProfileDisplaynameRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
@@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
 class ProfileAvatarURLRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
 
+    def __init__(self, hs):
+        super(ProfileAvatarURLRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
@@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
 class ProfileRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
 
+    def __init__(self, hs):
+        super(ProfileRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 2383b9df86..71d58c8e8d 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
         self.sessions = {}
         self.enable_registration = hs.config.enable_registration
         self.auth_handler = hs.get_auth_handler()
+        self.handlers = hs.get_handlers()
 
     def on_GET(self, request):
         if self.hs.config.enable_registration_captcha:
@@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
         super(CreateUserRestServlet, self).__init__(hs)
         self.store = hs.get_datastore()
         self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
+        self.handlers = hs.get_handlers()
 
     @defer.inlineCallbacks
     def on_POST(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 866a1e9120..89c3895118 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
 class RoomCreateRestServlet(ClientV1RestServlet):
     # No PATTERN; we have custom dispatch rules here
 
+    def __init__(self, hs):
+        super(RoomCreateRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     def register(self, http_server):
         PATTERNS = "/createRoom"
         register_txn_path(self, PATTERNS, http_server)
@@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
 
 # TODO: Needs unit testing for generic events
 class RoomStateEventRestServlet(ClientV1RestServlet):
+    def __init__(self, hs):
+        super(RoomStateEventRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     def register(self, http_server):
         # /room/$roomid/state/$eventtype
         no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 # TODO: Needs unit testing for generic events + feedback
 class RoomSendEventRestServlet(ClientV1RestServlet):
 
+    def __init__(self, hs):
+        super(RoomSendEventRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     def register(self, http_server):
         # /rooms/$roomid/send/$event_type[/$txn_id]
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 
 # TODO: Needs unit testing for room ID + alias joins
 class JoinRoomAliasServlet(ClientV1RestServlet):
+    def __init__(self, hs):
+        super(JoinRoomAliasServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
 
     def register(self, http_server):
         # /join/$room_identifier[/$txn_id]
@@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
 class RoomMemberListRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
 
+    def __init__(self, hs):
+        super(RoomMemberListRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
         # TODO support Pagination stream API (limit/tokens)
@@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
 class RoomMessageListRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
 
+    def __init__(self, hs):
+        super(RoomMessageListRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
 class RoomStateRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
 
+    def __init__(self, hs):
+        super(RoomStateRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
 class RoomInitialSyncRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
 
+    def __init__(self, hs):
+        super(RoomInitialSyncRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomEventContext, self).__init__(hs)
         self.clock = hs.get_clock()
+        self.handlers = hs.get_handlers()
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id, event_id):
@@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet):
 
 
 class RoomForgetRestServlet(ClientV1RestServlet):
+    def __init__(self, hs):
+        super(RoomForgetRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
         register_txn_path(self, PATTERNS, http_server)
@@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
 # TODO: Needs unit testing
 class RoomMembershipRestServlet(ClientV1RestServlet):
 
+    def __init__(self, hs):
+        super(RoomMembershipRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     def register(self, http_server):
         # /rooms/$roomid/[invite|join|leave]
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 
 
 class RoomRedactEventRestServlet(ClientV1RestServlet):
+    def __init__(self, hs):
+        super(RoomRedactEventRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
         register_txn_path(self, PATTERNS, http_server)
@@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet):
         "/search$"
     )
 
+    def __init__(self, hs):
+        super(SearchRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
     @defer.inlineCallbacks
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 43d8e0bf39..b11acdbea7 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
         affect_presence = set_presence != PresenceState.OFFLINE
 
         if affect_presence:
-            yield self.presence_handler.set_state(user, {"presence": set_presence})
+            yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
 
         context = yield self.presence_handler.user_syncing(
             user.to_string(), affect_presence=affect_presence,
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 7209d5a37d..9fe2013657 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,6 +15,7 @@
 from synapse.http.server import request_handler, respond_with_json_bytes
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.api.errors import SynapseError, Codes
+from synapse.crypto.keyring import KeyLookupError
 
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
@@ -210,9 +211,10 @@ class RemoteKey(Resource):
                     yield self.keyring.get_server_verify_key_v2_direct(
                         server_name, key_ids
                     )
+                except KeyLookupError as e:
+                    logger.info("Failed to fetch key: %s", e)
                 except:
                     logger.exception("Failed to get key for %r", server_name)
-                    pass
             yield self.query_keys(
                 request, query, query_remote_on_cache_miss=False
             )
diff --git a/synapse/server.py b/synapse/server.py
index 6bb4988309..af3246504b 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
 from synapse.handlers.room import RoomListHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
+from synapse.handlers.events import EventHandler, EventStreamHandler
 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.notifier import Notifier
@@ -94,6 +95,8 @@ class HomeServer(object):
         'auth_handler',
         'device_handler',
         'e2e_keys_handler',
+        'event_handler',
+        'event_stream_handler',
         'application_service_api',
         'application_service_scheduler',
         'application_service_handler',
@@ -214,6 +217,12 @@ class HomeServer(object):
     def build_application_service_handler(self):
         return ApplicationServicesHandler(self)
 
+    def build_event_handler(self):
+        return EventHandler(self)
+
+    def build_event_stream_handler(self):
+        return EventStreamHandler(self)
+
     def build_event_sources(self):
         return EventSources(self)
 
diff --git a/synapse/server.pyi b/synapse/server.pyi
index c0aa868c4f..9570df5537 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,4 @@
+import synapse.api.auth
 import synapse.handlers
 import synapse.handlers.auth
 import synapse.handlers.device
@@ -6,6 +7,9 @@ import synapse.storage
 import synapse.state
 
 class HomeServer(object):
+    def get_auth(self) -> synapse.api.auth.Auth:
+        pass
+
     def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
         pass
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 73fb334dd6..7efc5bfeef 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -50,6 +50,7 @@ from .openid import OpenIdStore
 from .client_ips import ClientIpStore
 
 from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
+from .engines import PostgresEngine
 
 from synapse.api.constants import PresenceState
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
             extra_tables=[("deleted_pushers", "stream_id")],
         )
 
+        if isinstance(self.database_engine, PostgresEngine):
+            self._cache_id_gen = StreamIdGenerator(
+                db_conn, "cache_invalidation_stream", "stream_id",
+            )
+        else:
+            self._cache_id_gen = None
+
         events_max = self._stream_id_gen.get_current_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
             db_conn, "events",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0117fdc639..b0923a9cad 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
 from synapse.util.caches import intern_dict
+from synapse.storage.engines import PostgresEngine
 import synapse.metrics
 
 
@@ -305,13 +306,14 @@ class SQLBaseStore(object):
                     func, *args, **kwargs
                 )
 
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
-
-        for after_callback, after_args in after_callbacks:
-            after_callback(*after_args)
+        try:
+            with PreserveLoggingContext():
+                result = yield self._db_pool.runWithConnection(
+                    inner_func, *args, **kwargs
+                )
+        finally:
+            for after_callback, after_args in after_callbacks:
+                after_callback(*after_args)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -860,6 +862,58 @@ class SQLBaseStore(object):
 
         return cache, min_val
 
+    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+        """Invalidates the cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+
+        This should only be used to invalidate caches where slaves won't
+        otherwise know from other replication streams that the cache should
+        be invalidated.
+        """
+        txn.call_after(cache_func.invalidate, keys)
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # get_next() returns a context manager which is designed to wrap
+            # the transaction. However, we want to only get an ID when we want
+            # to use it, here, so we need to call __enter__ manually, and have
+            # __exit__ called after the transaction finishes.
+            ctx = self._cache_id_gen.get_next()
+            stream_id = ctx.__enter__()
+            txn.call_after(ctx.__exit__, None, None, None)
+
+            self._simple_insert_txn(
+                txn,
+                table="cache_invalidation_stream",
+                values={
+                    "stream_id": stream_id,
+                    "cache_func": cache_func.__name__,
+                    "keys": list(keys),
+                    "invalidation_ts": self.clock.time_msec(),
+                }
+            )
+
+    def get_all_updated_caches(self, last_id, current_id, limit):
+        def get_all_updated_caches_txn(txn):
+            # We purposefully don't bound by the current token, as we want to
+            # send across cache invalidations as quickly as possible. Cache
+            # invalidations are idempotent, so duplicates are fine.
+            sql = (
+                "SELECT stream_id, cache_func, keys, invalidation_ts"
+                " FROM cache_invalidation_stream"
+                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, limit,))
+            return txn.fetchall()
+        return self.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
+
+    def get_cache_stream_token(self):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
+
 
 class _RollbackButIsFineException(Exception):
     """ This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index ef231a04dc..9caaf81f2c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
         Returns:
             Deferred
         """
-        try:
-            yield self._simple_insert(
+        def alias_txn(txn):
+            self._simple_insert_txn(
+                txn,
                 "room_aliases",
                 {
                     "room_alias": room_alias.to_string(),
                     "room_id": room_id,
                     "creator": creator,
                 },
-                desc="create_room_alias_association",
-            )
-        except self.database_engine.module.IntegrityError:
-            raise SynapseError(
-                409, "Room alias %s already exists" % room_alias.to_string()
             )
 
-        for server in servers:
-            # TODO(erikj): Fix this to bulk insert
-            yield self._simple_insert(
-                "room_alias_servers",
-                {
+            self._simple_insert_many_txn(
+                txn,
+                table="room_alias_servers",
+                values=[{
                     "room_alias": room_alias.to_string(),
                     "server": server,
-                },
-                desc="create_room_alias_association",
+                } for server in servers],
             )
-        self.get_aliases_for_room.invalidate((room_id,))
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_aliases_for_room, (room_id,)
+            )
+
+        try:
+            ret = yield self.runInteraction(
+                "create_room_alias_association", alias_txn
+            )
+        except self.database_engine.module.IntegrityError:
+            raise SynapseError(
+                409, "Room alias %s already exists" % room_alias.to_string()
+            )
+        defer.returnValue(ret)
 
     def get_room_alias_creator(self, room_alias):
         return self._simple_select_one_onecol(
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 8801669a6b..b94ce7bea1 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 33
+SCHEMA_VERSION = 34
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index d03f7c541e..21d0696640 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
             desc="add_presence_list_pending",
         )
 
-    @defer.inlineCallbacks
     def set_presence_list_accepted(self, observer_localpart, observed_userid):
-        result = yield self._simple_update_one(
-            table="presence_list",
-            keyvalues={"user_id": observer_localpart,
-                       "observed_user_id": observed_userid},
-            updatevalues={"accepted": True},
-            desc="set_presence_list_accepted",
+        def update_presence_list_txn(txn):
+            result = self._simple_update_one_txn(
+                txn,
+                table="presence_list",
+                keyvalues={
+                    "user_id": observer_localpart,
+                    "observed_user_id": observed_userid
+                },
+                updatevalues={"accepted": True},
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_presence_list_accepted, (observer_localpart,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_presence_list_observers_accepted, (observed_userid,)
+            )
+
+            return result
+
+        return self.runInteraction(
+            "set_presence_list_accepted", update_presence_list_txn,
         )
-        self.get_presence_list_accepted.invalidate((observer_localpart,))
-        self.get_presence_list_observers_accepted.invalidate((observed_userid,))
-        defer.returnValue(result)
 
     def get_presence_list(self, observer_localpart, accepted=None):
         if accepted:
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8bd693be72..a422ddf633 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
             user_id, membership_list=[Membership.JOIN],
         )
 
-    @defer.inlineCallbacks
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
         def f(txn):
@@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
                 "  room_id = ?"
             )
             txn.execute(sql, (user_id, room_id))
-        yield self.runInteraction("forget_membership", f)
-        self.was_forgotten_at.invalidate_all()
-        self.who_forgot_in_room.invalidate_all()
-        self.did_forget.invalidate((user_id, room_id))
+
+            txn.call_after(self.was_forgotten_at.invalidate_all)
+            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+            self._invalidate_cache_and_stream(
+                txn, self.who_forgot_in_room, (room_id,)
+            )
+        return self.runInteraction("forget_membership", f)
 
     @cachedInlineCallbacks(num_args=2)
     def did_forget(self, user_id, room_id):
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py
new file mode 100644
index 0000000000..3b63a1562d
--- /dev/null
+++ b/synapse/storage/schema/delta/34/cache_stream.py
@@ -0,0 +1,46 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# This stream is used to notify replication slaves that some caches have
+# been invalidated that they cannot infer from the other streams.
+CREATE_TABLE = """
+CREATE TABLE cache_invalidation_stream (
+    stream_id       BIGINT,
+    cache_func      TEXT,
+    keys            TEXT[],
+    invalidation_ts BIGINT
+);
+
+CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if not isinstance(database_engine, PostgresEngine):
+        return
+
+    for statement in get_statements(CREATE_TABLE.splitlines()):
+        cur.execute(statement)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 5316259d15..7a87045f87 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -317,7 +317,6 @@ def preserve_fn(f):
     def g(*args, **kwargs):
         with PreserveLoggingContext(current):
             return f(*args, **kwargs)
-
     return g
 
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 0b944d3e63..76f301f549 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
 
 from synapse.util.logcontext import LoggingContext
 import synapse.metrics
 
+from functools import wraps
 import logging
 
 
@@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
 )
 
 
+def measure_func(name):
+    def wrapper(func):
+        @wraps(func)
+        @defer.inlineCallbacks
+        def measured_func(self, *args, **kwargs):
+            with Measure(self.clock, name):
+                r = yield func(self, *args, **kwargs)
+            defer.returnValue(r)
+        return measured_func
+    return wrapper
+
+
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
@@ -64,7 +78,6 @@ class Measure(object):
         self.start = self.clock.time_msec()
         self.start_context = LoggingContext.current_context()
         if not self.start_context:
-            logger.warn("Entered Measure without log context: %s", self.name)
             self.start_context = LoggingContext("Measure")
             self.start_context.__enter__()
             self.created_context = True
@@ -85,7 +98,7 @@ class Measure(object):
         if context != self.start_context:
             logger.warn(
                 "Context has unexpectedly changed from '%s' to '%s'. (%r)",
-                context, self.start_context, self.name
+                self.start_context, context, self.name
             )
             return