summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/filtering.py261
-rw-r--r--synapse/crypto/keyring.py39
-rw-r--r--synapse/federation/federation_client.py40
-rw-r--r--synapse/federation/transaction_queue.py216
-rw-r--r--synapse/federation/transport/client.py7
-rw-r--r--synapse/handlers/directory.py1
-rw-r--r--synapse/handlers/e2e_keys.py32
-rw-r--r--synapse/handlers/profile.py6
-rw-r--r--synapse/http/matrixfederationclient.py294
-rw-r--r--synapse/python_dependencies.py1
-rw-r--r--synapse/rest/client/v2_alpha/account.py2
-rw-r--r--synapse/util/__init__.py10
-rw-r--r--synapse/util/retryutils.py29
-rw-r--r--synapse/visibility.py7
14 files changed, 533 insertions, 412 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 47f0cf0fa9..83206348e5 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,10 +15,172 @@
 from synapse.api.errors import SynapseError
 from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, RoomID
-
 from twisted.internet import defer
 
 import ujson as json
+import jsonschema
+from jsonschema import FormatChecker
+
+FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        # TODO: We don't limit event type values but we probably should...
+        # check types are valid event types
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        }
+    }
+}
+
+ROOM_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "ephemeral": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "include_leave": {
+            "type": "boolean"
+        },
+        "state": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "timeline": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+    }
+}
+
+ROOM_EVENT_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "contains_url": {
+            "type": "boolean"
+        }
+    }
+}
+
+USER_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_user_id"
+    }
+}
+
+ROOM_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_room_id"
+    }
+}
+
+USER_FILTER_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-04/schema#",
+    "description": "schema for a Sync filter",
+    "type": "object",
+    "definitions": {
+        "room_id_array": ROOM_ID_ARRAY_SCHEMA,
+        "user_id_array": USER_ID_ARRAY_SCHEMA,
+        "filter": FILTER_SCHEMA,
+        "room_filter": ROOM_FILTER_SCHEMA,
+        "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+    },
+    "properties": {
+        "presence": {
+            "$ref": "#/definitions/filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/filter"
+        },
+        "room": {
+            "$ref": "#/definitions/room_filter"
+        },
+        "event_format": {
+            "type": "string",
+            "enum": ["client", "federation"]
+        },
+        "event_fields": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                # Don't allow '\\' in event field filters. This makes matching
+                # events a lot easier as we can then use a negative lookbehind
+                # assertion to split '\.' If we allowed \\ then it would
+                # incorrectly split '\\.' See synapse.events.utils.serialize_event
+                "pattern": "^((?!\\\).)*$"
+            }
+        }
+    },
+    "additionalProperties": False
+}
+
+
+@FormatChecker.cls_checks('matrix_room_id')
+def matrix_room_id_validator(room_id_str):
+    return RoomID.from_string(room_id_str)
+
+
+@FormatChecker.cls_checks('matrix_user_id')
+def matrix_user_id_validator(user_id_str):
+    return UserID.from_string(user_id_str)
 
 
 class Filtering(object):
@@ -53,98 +215,11 @@ class Filtering(object):
         # NB: Filters are the complete json blobs. "Definitions" are an
         # individual top-level key e.g. public_user_data. Filters are made of
         # many definitions.
-
-        top_level_definitions = [
-            "presence", "account_data"
-        ]
-
-        room_level_definitions = [
-            "state", "timeline", "ephemeral", "account_data"
-        ]
-
-        for key in top_level_definitions:
-            if key in user_filter_json:
-                self._check_definition(user_filter_json[key])
-
-        if "room" in user_filter_json:
-            self._check_definition_room_lists(user_filter_json["room"])
-            for key in room_level_definitions:
-                if key in user_filter_json["room"]:
-                    self._check_definition(user_filter_json["room"][key])
-
-        if "event_fields" in user_filter_json:
-            if type(user_filter_json["event_fields"]) != list:
-                raise SynapseError(400, "event_fields must be a list of strings")
-            for field in user_filter_json["event_fields"]:
-                if not isinstance(field, basestring):
-                    raise SynapseError(400, "Event field must be a string")
-                # Don't allow '\\' in event field filters. This makes matching
-                # events a lot easier as we can then use a negative lookbehind
-                # assertion to split '\.' If we allowed \\ then it would
-                # incorrectly split '\\.' See synapse.events.utils.serialize_event
-                if r'\\' in field:
-                    raise SynapseError(
-                        400, r'The escape character \ cannot itself be escaped'
-                    )
-
-    def _check_definition_room_lists(self, definition):
-        """Check that "rooms" and "not_rooms" are lists of room ids if they
-        are present
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # check rooms are valid room IDs
-        room_id_keys = ["rooms", "not_rooms"]
-        for key in room_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for room_id in definition[key]:
-                    RoomID.from_string(room_id)
-
-    def _check_definition(self, definition):
-        """Check if the provided definition is valid.
-
-        This inspects not only the types but also the values to make sure they
-        make sense.
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # NB: Filters are the complete json blobs. "Definitions" are an
-        # individual top-level key e.g. public_user_data. Filters are made of
-        # many definitions.
-        if type(definition) != dict:
-            raise SynapseError(
-                400, "Expected JSON object, not %s" % (definition,)
-            )
-
-        self._check_definition_room_lists(definition)
-
-        # check senders are valid user IDs
-        user_id_keys = ["senders", "not_senders"]
-        for key in user_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for user_id in definition[key]:
-                    UserID.from_string(user_id)
-
-        # TODO: We don't limit event type values but we probably should...
-        # check types are valid event types
-        event_keys = ["types", "not_types"]
-        for key in event_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for event_type in definition[key]:
-                    if not isinstance(event_type, basestring):
-                        raise SynapseError(400, "Event type should be a string")
+        try:
+            jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
+                                format_checker=FormatChecker())
+        except jsonschema.ValidationError as e:
+            raise SynapseError(400, e.message)
 
 
 class FilterCollection(object):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 0a21392a62..1bb27edc0f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,7 +15,6 @@
 
 from synapse.crypto.keyclient import fetch_server_key
 from synapse.api.errors import SynapseError, Codes
-from synapse.util.retryutils import get_retry_limiter
 from synapse.util import unwrapFirstError
 from synapse.util.async import ObservableDeferred
 from synapse.util.logcontext import (
@@ -382,30 +381,24 @@ class Keyring(object):
     def get_keys_from_server(self, server_name_and_key_ids):
         @defer.inlineCallbacks
         def get_key(server_name, key_ids):
-            limiter = yield get_retry_limiter(
-                server_name,
-                self.clock,
-                self.store,
-            )
-            with limiter:
-                keys = None
-                try:
-                    keys = yield self.get_server_verify_key_v2_direct(
-                        server_name, key_ids
-                    )
-                except Exception as e:
-                    logger.info(
-                        "Unable to get key %r for %r directly: %s %s",
-                        key_ids, server_name,
-                        type(e).__name__, str(e.message),
-                    )
+            keys = None
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(
+                    server_name, key_ids
+                )
+            except Exception as e:
+                logger.info(
+                    "Unable to get key %r for %r directly: %s %s",
+                    key_ids, server_name,
+                    type(e).__name__, str(e.message),
+                )
 
-                if not keys:
-                    keys = yield self.get_server_verify_key_v1_direct(
-                        server_name, key_ids
-                    )
+            if not keys:
+                keys = yield self.get_server_verify_key_v1_direct(
+                    server_name, key_ids
+                )
 
-                    keys = {server_name: keys}
+                keys = {server_name: keys}
 
             defer.returnValue(keys)
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 5dcd4eecce..deee0f4904 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent, builder
 import synapse.metrics
 
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.util.retryutils import NotRetryingDestination
 
 import copy
 import itertools
@@ -88,7 +88,7 @@ class FederationClient(FederationBase):
 
     @log_function
     def make_query(self, destination, query_type, args,
-                   retry_on_dns_fail=False):
+                   retry_on_dns_fail=False, ignore_backoff=False):
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
@@ -98,6 +98,8 @@ class FederationClient(FederationBase):
                 handler name used in register_query_handler().
             args (dict): Mapping of strings to strings containing the details
                 of the query request.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
 
         Returns:
             a Deferred which will eventually yield a JSON object from the
@@ -106,7 +108,8 @@ class FederationClient(FederationBase):
         sent_queries_counter.inc(query_type)
 
         return self.transport_layer.make_query(
-            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )
 
     @log_function
@@ -234,31 +237,24 @@ class FederationClient(FederationBase):
                 continue
 
             try:
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self._clock,
-                    self.store,
+                transaction_data = yield self.transport_layer.get_event(
+                    destination, event_id, timeout=timeout,
                 )
 
-                with limiter:
-                    transaction_data = yield self.transport_layer.get_event(
-                        destination, event_id, timeout=timeout,
-                    )
-
-                    logger.debug("transaction_data %r", transaction_data)
+                logger.debug("transaction_data %r", transaction_data)
 
-                    pdu_list = [
-                        self.event_from_pdu_json(p, outlier=outlier)
-                        for p in transaction_data["pdus"]
-                    ]
+                pdu_list = [
+                    self.event_from_pdu_json(p, outlier=outlier)
+                    for p in transaction_data["pdus"]
+                ]
 
-                    if pdu_list and pdu_list[0]:
-                        pdu = pdu_list[0]
+                if pdu_list and pdu_list[0]:
+                    pdu = pdu_list[0]
 
-                        # Check signatures are correct.
-                        signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                    # Check signatures are correct.
+                    signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
 
-                        break
+                    break
 
                 pdu_attempts[destination] = now
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index c802dd67a3..d7ecefcc64 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import datetime
 
 from twisted.internet import defer
 
@@ -22,9 +22,7 @@ from .units import Transaction, Edu
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
 from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.retryutils import (
-    get_retry_limiter, NotRetryingDestination,
-)
+from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.metrics import measure_func
 from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state
@@ -312,13 +310,6 @@ class TransactionQueue(object):
             yield run_on_reactor()
 
             while True:
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self.clock,
-                    self.store,
-                    backoff_on_404=True,  # If we get a 404 the other side has gone
-                )
-
                 device_message_edus, device_stream_id, dev_list_id = (
                     yield self._get_new_device_messages(destination)
                 )
@@ -374,7 +365,6 @@ class TransactionQueue(object):
 
                 success = yield self._send_new_transaction(
                     destination, pending_pdus, pending_edus, pending_failures,
-                    limiter=limiter,
                 )
                 if success:
                     # Remove the acknowledged device messages from the database
@@ -392,12 +382,24 @@ class TransactionQueue(object):
                     self.last_device_list_stream_id_by_dest[destination] = dev_list_id
                 else:
                     break
-        except NotRetryingDestination:
+        except NotRetryingDestination as e:
             logger.debug(
-                "TX [%s] not ready for retry yet - "
+                "TX [%s] not ready for retry yet (next retry at %s) - "
                 "dropping transaction for now",
                 destination,
+                datetime.datetime.fromtimestamp(
+                    (e.retry_last_ts + e.retry_interval) / 1000.0
+                ),
+            )
+        except Exception as e:
+            logger.warn(
+                "TX [%s] Failed to send transaction: %s",
+                destination,
+                e,
             )
+            for p in pending_pdus:
+                logger.info("Failed to send event %s to %s", p.event_id,
+                            destination)
         finally:
             # We want to be *very* sure we delete this after we stop processing
             self.pending_transactions.pop(destination, None)
@@ -437,7 +439,7 @@ class TransactionQueue(object):
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, limiter):
+                              pending_failures):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -447,132 +449,104 @@ class TransactionQueue(object):
 
         success = True
 
-        try:
-            logger.debug("TX [%s] _attempt_new_transaction", destination)
+        logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-            txn_id = str(self._next_txn_id)
+        txn_id = str(self._next_txn_id)
 
-            logger.debug(
-                "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
-                destination, txn_id,
-                len(pdus),
-                len(edus),
-                len(failures)
-            )
+        logger.debug(
+            "TX [%s] {%s} Attempting new transaction"
+            " (pdus: %d, edus: %d, failures: %d)",
+            destination, txn_id,
+            len(pdus),
+            len(edus),
+            len(failures)
+        )
 
-            logger.debug("TX [%s] Persisting transaction...", destination)
+        logger.debug("TX [%s] Persisting transaction...", destination)
 
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self.clock.time_msec()),
-                transaction_id=txn_id,
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
+        transaction = Transaction.create_new(
+            origin_server_ts=int(self.clock.time_msec()),
+            transaction_id=txn_id,
+            origin=self.server_name,
+            destination=destination,
+            pdus=pdus,
+            edus=edus,
+            pdu_failures=failures,
+        )
 
-            self._next_txn_id += 1
+        self._next_txn_id += 1
 
-            yield self.transaction_actions.prepare_to_send(transaction)
+        yield self.transaction_actions.prepare_to_send(transaction)
 
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
-                destination, txn_id,
-                transaction.transaction_id,
-                len(pdus),
-                len(edus),
-                len(failures),
-            )
+        logger.debug("TX [%s] Persisted transaction", destination)
+        logger.info(
+            "TX [%s] {%s} Sending transaction [%s],"
+            " (PDUs: %d, EDUs: %d, failures: %d)",
+            destination, txn_id,
+            transaction.transaction_id,
+            len(pdus),
+            len(edus),
+            len(failures),
+        )
 
-            with limiter:
-                # Actually send the transaction
-
-                # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                # keys work
-                def json_data_cb():
-                    data = transaction.get_dict()
-                    now = int(self.clock.time_msec())
-                    if "pdus" in data:
-                        for p in data["pdus"]:
-                            if "age_ts" in p:
-                                unsigned = p.setdefault("unsigned", {})
-                                unsigned["age"] = now - int(p["age_ts"])
-                                del p["age_ts"]
-                    return data
-
-                try:
-                    response = yield self.transport_layer.send_transaction(
-                        transaction, json_data_cb
-                    )
-                    code = 200
-
-                    if response:
-                        for e_id, r in response.get("pdus", {}).items():
-                            if "error" in r:
-                                logger.warn(
-                                    "Transaction returned error for %s: %s",
-                                    e_id, r,
-                                )
-                except HttpResponseException as e:
-                    code = e.code
-                    response = e.response
-
-                    if e.code in (401, 404, 429) or 500 <= e.code:
-                        logger.info(
-                            "TX [%s] {%s} got %d response",
-                            destination, txn_id, code
+        # Actually send the transaction
+
+        # FIXME (erikj): This is a bit of a hack to make the Pdu age
+        # keys work
+        def json_data_cb():
+            data = transaction.get_dict()
+            now = int(self.clock.time_msec())
+            if "pdus" in data:
+                for p in data["pdus"]:
+                    if "age_ts" in p:
+                        unsigned = p.setdefault("unsigned", {})
+                        unsigned["age"] = now - int(p["age_ts"])
+                        del p["age_ts"]
+            return data
+
+        try:
+            response = yield self.transport_layer.send_transaction(
+                transaction, json_data_cb
+            )
+            code = 200
+
+            if response:
+                for e_id, r in response.get("pdus", {}).items():
+                    if "error" in r:
+                        logger.warn(
+                            "Transaction returned error for %s: %s",
+                            e_id, r,
                         )
-                        raise e
+        except HttpResponseException as e:
+            code = e.code
+            response = e.response
 
+            if e.code in (401, 404, 429) or 500 <= e.code:
                 logger.info(
                     "TX [%s] {%s} got %d response",
                     destination, txn_id, code
                 )
+                raise e
 
-                logger.debug("TX [%s] Sent transaction", destination)
-                logger.debug("TX [%s] Marking as delivered...", destination)
-
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
+        logger.info(
+            "TX [%s] {%s} got %d response",
+            destination, txn_id, code
+        )
 
-            logger.debug("TX [%s] Marked as delivered", destination)
+        logger.debug("TX [%s] Sent transaction", destination)
+        logger.debug("TX [%s] Marking as delivered...", destination)
 
-            if code != 200:
-                for p in pdus:
-                    logger.info(
-                        "Failed to send event %s to %s", p.event_id, destination
-                    )
-                success = False
-        except RuntimeError as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
+        yield self.transaction_actions.delivered(
+            transaction, code, response
+        )
 
-            success = False
+        logger.debug("TX [%s] Marked as delivered", destination)
 
+        if code != 200:
             for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
+                logger.info(
+                    "Failed to send event %s to %s", p.event_id, destination
+                )
             success = False
 
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-
         defer.returnValue(success)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index f49e8a2cc4..15a03378f5 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -163,6 +163,7 @@ class TransportLayerClient(object):
             data=json_data,
             json_data_callback=json_data_callback,
             long_retries=True,
+            backoff_on_404=True,  # If we get a 404 the other side has gone
         )
 
         logger.debug(
@@ -174,7 +175,8 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def make_query(self, destination, query_type, args, retry_on_dns_fail):
+    def make_query(self, destination, query_type, args, retry_on_dns_fail,
+                   ignore_backoff=False):
         path = PREFIX + "/query/%s" % query_type
 
         content = yield self.client.get_json(
@@ -183,6 +185,7 @@ class TransportLayerClient(object):
             args=args,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=10000,
+            ignore_backoff=ignore_backoff,
         )
 
         defer.returnValue(content)
@@ -242,6 +245,7 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=content,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -269,6 +273,7 @@ class TransportLayerClient(object):
             destination=remote_server,
             path=path,
             args=args,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1b5317edf5..943554ce98 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
                         "room_alias": room_alias.to_string(),
                     },
                     retry_on_dns_fail=False,
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 logging.warn("Error retrieving alias")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c02d41a74c..c2b38d72a9 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, CodeMessageException
 from synapse.types import get_domain_from_id
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -121,15 +121,11 @@ class E2eKeysHandler(object):
         def do_remote_query(destination):
             destination_query = remote_queries_not_in_cache[destination]
             try:
-                limiter = yield get_retry_limiter(
-                    destination, self.clock, self.store
+                remote_result = yield self.federation.query_client_keys(
+                    destination,
+                    {"device_keys": destination_query},
+                    timeout=timeout
                 )
-                with limiter:
-                    remote_result = yield self.federation.query_client_keys(
-                        destination,
-                        {"device_keys": destination_query},
-                        timeout=timeout
-                    )
 
                 for user_id, keys in remote_result["device_keys"].items():
                     if user_id in destination_query:
@@ -239,18 +235,14 @@ class E2eKeysHandler(object):
         def claim_client_keys(destination):
             device_keys = remote_queries[destination]
             try:
-                limiter = yield get_retry_limiter(
-                    destination, self.clock, self.store
+                remote_result = yield self.federation.claim_client_keys(
+                    destination,
+                    {"one_time_keys": device_keys},
+                    timeout=timeout
                 )
-                with limiter:
-                    remote_result = yield self.federation.claim_client_keys(
-                        destination,
-                        {"one_time_keys": device_keys},
-                        timeout=timeout
-                    )
-                    for user_id, keys in remote_result["one_time_keys"].items():
-                        if user_id in device_keys:
-                            json_result[user_id] = keys
+                for user_id, keys in remote_result["one_time_keys"].items():
+                    if user_id in device_keys:
+                        json_result[user_id] = keys
             except CodeMessageException as e:
                 failures[destination] = {
                     "status": e.code, "message": e.message
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index abd1fb28cb..9bf638f818 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
                     args={
                         "user_id": target_user.to_string(),
                         "field": "displayname",
-                    }
+                    },
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 if e.code != 404:
@@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
                     args={
                         "user_id": target_user.to_string(),
                         "field": "avatar_url",
-                    }
+                    },
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 if e.code != 404:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 82586e3dea..f9e32ef03d 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -12,8 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+import synapse.util.retryutils
 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web.client import readBody, HTTPConnectionPool, Agent
@@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
             reactor, MatrixFederationEndpointFactory(hs), pool=pool
         )
         self.clock = hs.get_clock()
+        self._store = hs.get_datastore()
         self.version_string = hs.version_string
         self._next_id = 1
 
@@ -103,129 +103,151 @@ class MatrixFederationHttpClient(object):
         )
 
     @defer.inlineCallbacks
-    def _create_request(self, destination, method, path_bytes,
-                        body_callback, headers_dict={}, param_bytes=b"",
-                        query_bytes=b"", retry_on_dns_fail=True,
-                        timeout=None, long_retries=False):
-        """ Creates and sends a request to the given url
+    def _request(self, destination, method, path,
+                 body_callback, headers_dict={}, param_bytes=b"",
+                 query_bytes=b"", retry_on_dns_fail=True,
+                 timeout=None, long_retries=False,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
+        """ Creates and sends a request to the given server
+        Args:
+            destination (str): The remote server to send the HTTP request to.
+            method (str): HTTP method
+            path (str): The HTTP path
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): Back off if we get a 404
 
         Returns:
             Deferred: resolves with the http response object on success.
 
             Fails with ``HTTPRequestException``: if we get an HTTP response
-            code >= 300.
+                code >= 300.
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+                to retry this server.
         """
-        headers_dict[b"User-Agent"] = [self.version_string]
-        headers_dict[b"Host"] = [destination]
-
-        url_bytes = self._create_url(
-            destination, path_bytes, param_bytes, query_bytes
+        limiter = yield synapse.util.retryutils.get_retry_limiter(
+            destination,
+            self.clock,
+            self._store,
+            backoff_on_404=backoff_on_404,
+            ignore_backoff=ignore_backoff,
         )
 
-        txn_id = "%s-O-%s" % (method, self._next_id)
-        self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+        destination = destination.encode("ascii")
+        path_bytes = path.encode("ascii")
+        with limiter:
+            headers_dict[b"User-Agent"] = [self.version_string]
+            headers_dict[b"Host"] = [destination]
 
-        outbound_logger.info(
-            "{%s} [%s] Sending request: %s %s",
-            txn_id, destination, method, url_bytes
-        )
+            url_bytes = self._create_url(
+                destination, path_bytes, param_bytes, query_bytes
+            )
 
-        # XXX: Would be much nicer to retry only at the transaction-layer
-        # (once we have reliable transactions in place)
-        if long_retries:
-            retries_left = MAX_LONG_RETRIES
-        else:
-            retries_left = MAX_SHORT_RETRIES
+            txn_id = "%s-O-%s" % (method, self._next_id)
+            self._next_id = (self._next_id + 1) % (sys.maxint - 1)
 
-        http_url_bytes = urlparse.urlunparse(
-            ("", "", path_bytes, param_bytes, query_bytes, "")
-        )
+            outbound_logger.info(
+                "{%s} [%s] Sending request: %s %s",
+                txn_id, destination, method, url_bytes
+            )
 
-        log_result = None
-        try:
-            while True:
-                producer = None
-                if body_callback:
-                    producer = body_callback(method, http_url_bytes, headers_dict)
-
-                try:
-                    def send_request():
-                        request_deferred = preserve_context_over_fn(
-                            self.agent.request,
-                            method,
-                            url_bytes,
-                            Headers(headers_dict),
-                            producer
-                        )
+            # XXX: Would be much nicer to retry only at the transaction-layer
+            # (once we have reliable transactions in place)
+            if long_retries:
+                retries_left = MAX_LONG_RETRIES
+            else:
+                retries_left = MAX_SHORT_RETRIES
 
-                        return self.clock.time_bound_deferred(
-                            request_deferred,
-                            time_out=timeout / 1000. if timeout else 60,
-                        )
+            http_url_bytes = urlparse.urlunparse(
+                ("", "", path_bytes, param_bytes, query_bytes, "")
+            )
 
-                    response = yield preserve_context_over_fn(send_request)
+            log_result = None
+            try:
+                while True:
+                    producer = None
+                    if body_callback:
+                        producer = body_callback(method, http_url_bytes, headers_dict)
+
+                    try:
+                        def send_request():
+                            request_deferred = preserve_context_over_fn(
+                                self.agent.request,
+                                method,
+                                url_bytes,
+                                Headers(headers_dict),
+                                producer
+                            )
+
+                            return self.clock.time_bound_deferred(
+                                request_deferred,
+                                time_out=timeout / 1000. if timeout else 60,
+                            )
+
+                        response = yield preserve_context_over_fn(send_request)
+
+                        log_result = "%d %s" % (response.code, response.phrase,)
+                        break
+                    except Exception as e:
+                        if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+                            logger.warn(
+                                "DNS Lookup failed to %s with %s",
+                                destination,
+                                e
+                            )
+                            log_result = "DNS Lookup failed to %s with %s" % (
+                                destination, e
+                            )
+                            raise
 
-                    log_result = "%d %s" % (response.code, response.phrase,)
-                    break
-                except Exception as e:
-                    if not retry_on_dns_fail and isinstance(e, DNSLookupError):
                         logger.warn(
-                            "DNS Lookup failed to %s with %s",
+                            "{%s} Sending request failed to %s: %s %s: %s - %s",
+                            txn_id,
                             destination,
-                            e
+                            method,
+                            url_bytes,
+                            type(e).__name__,
+                            _flatten_response_never_received(e),
                         )
-                        log_result = "DNS Lookup failed to %s with %s" % (
-                            destination, e
+
+                        log_result = "%s - %s" % (
+                            type(e).__name__, _flatten_response_never_received(e),
                         )
-                        raise
-
-                    logger.warn(
-                        "{%s} Sending request failed to %s: %s %s: %s - %s",
-                        txn_id,
-                        destination,
-                        method,
-                        url_bytes,
-                        type(e).__name__,
-                        _flatten_response_never_received(e),
-                    )
-
-                    log_result = "%s - %s" % (
-                        type(e).__name__, _flatten_response_never_received(e),
-                    )
-
-                    if retries_left and not timeout:
-                        if long_retries:
-                            delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
-                            delay = min(delay, 60)
-                            delay *= random.uniform(0.8, 1.4)
+
+                        if retries_left and not timeout:
+                            if long_retries:
+                                delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+                                delay = min(delay, 60)
+                                delay *= random.uniform(0.8, 1.4)
+                            else:
+                                delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+                                delay = min(delay, 2)
+                                delay *= random.uniform(0.8, 1.4)
+
+                            yield sleep(delay)
+                            retries_left -= 1
                         else:
-                            delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
-                            delay = min(delay, 2)
-                            delay *= random.uniform(0.8, 1.4)
-
-                        yield sleep(delay)
-                        retries_left -= 1
-                    else:
-                        raise
-        finally:
-            outbound_logger.info(
-                "{%s} [%s] Result: %s",
-                txn_id,
-                destination,
-                log_result,
-            )
+                            raise
+            finally:
+                outbound_logger.info(
+                    "{%s} [%s] Result: %s",
+                    txn_id,
+                    destination,
+                    log_result,
+                )
 
-        if 200 <= response.code < 300:
-            pass
-        else:
-            # :'(
-            # Update transactions table?
-            body = yield preserve_context_over_fn(readBody, response)
-            raise HttpResponseException(
-                response.code, response.phrase, body
-            )
+            if 200 <= response.code < 300:
+                pass
+            else:
+                # :'(
+                # Update transactions table?
+                body = yield preserve_context_over_fn(readBody, response)
+                raise HttpResponseException(
+                    response.code, response.phrase, body
+                )
 
-        defer.returnValue(response)
+            defer.returnValue(response)
 
     def sign_request(self, destination, method, url_bytes, headers_dict,
                      content=None):
@@ -254,7 +276,9 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False, timeout=None):
+                 long_retries=False, timeout=None,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
         """ Sends the specifed json data using PUT
 
         Args:
@@ -269,11 +293,19 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): True if we should count a 404 response as
+                a failure of the server (and should therefore back off future
+                requests)
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body. On a 4xx or 5xx error response a
             CodeMessageException is raised.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
 
         if not json_data_callback:
@@ -288,14 +320,16 @@ class MatrixFederationHttpClient(object):
             producer = _JsonProducer(json_data)
             return producer
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "PUT",
-            path.encode("ascii"),
+            path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
+            backoff_on_404=backoff_on_404,
         )
 
         if 200 <= response.code < 300:
@@ -307,7 +341,7 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def post_json(self, destination, path, data={}, long_retries=False,
-                  timeout=None):
+                  timeout=None, ignore_backoff=False):
         """ Sends the specifed json data using POST
 
         Args:
@@ -320,11 +354,15 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
-
+            ignore_backoff (bool): true to ignore the historical backoff data and
+                try the request anyway.
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body. On a 4xx or 5xx error response a
             CodeMessageException is raised.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
 
         def body_callback(method, url_bytes, headers_dict):
@@ -333,14 +371,15 @@ class MatrixFederationHttpClient(object):
             )
             return _JsonProducer(data)
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "POST",
-            path.encode("ascii"),
+            path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
         )
 
         if 200 <= response.code < 300:
@@ -353,7 +392,7 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
-                 timeout=None):
+                 timeout=None, ignore_backoff=False):
         """ GETs some json from the given host homeserver and path
 
         Args:
@@ -365,11 +404,16 @@ class MatrixFederationHttpClient(object):
             timeout (int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout and that the request will
                 be retried.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
         Returns:
             Deferred: Succeeds when we get *any* HTTP response.
 
             The result of the deferred is a tuple of `(code, response)`,
             where `response` is a dict representing the decoded JSON body.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
         logger.debug("get_json args: %s", args)
 
@@ -386,14 +430,15 @@ class MatrixFederationHttpClient(object):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
+            path,
             query_bytes=query_bytes,
             body_callback=body_callback,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
         )
 
         if 200 <= response.code < 300:
@@ -406,19 +451,25 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def get_file(self, destination, path, output_stream, args={},
-                 retry_on_dns_fail=True, max_size=None):
+                 retry_on_dns_fail=True, max_size=None,
+                 ignore_backoff=False):
         """GETs a file from a given homeserver
         Args:
             destination (str): The remote server to send the HTTP request to.
             path (str): The HTTP path to GET.
             output_stream (file): File to write the response body to.
             args (dict): Optional dictionary used to create the query string.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
         Returns:
             Deferred: resolves with an (int,dict) tuple of the file length and
             a dict of the response headers.
 
             Fails with ``HTTPRequestException`` if we get an HTTP response code
             >= 300
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
 
         encoded_args = {}
@@ -434,13 +485,14 @@ class MatrixFederationHttpClient(object):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
+            path,
             query_bytes=query_bytes,
             body_callback=body_callback,
-            retry_on_dns_fail=retry_on_dns_fail
+            retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )
 
         headers = dict(response.headers.getAllRawHeaders())
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c4777b2a2b..ed7f1c89ad 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -19,6 +19,7 @@ from distutils.version import LooseVersion
 logger = logging.getLogger(__name__)
 
 REQUIREMENTS = {
+    "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
     "frozendict>=0.4": ["frozendict"],
     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index aac76edf1c..4990b22b9f 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -268,7 +268,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         if existingUid is not None:
             raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
 
-        ret = yield self.identity_handler.requestEmailToken(**body)
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
         defer.returnValue((200, ret))
 
 
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 30fc480108..98a5a26ac5 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class DeferredTimedOutError(SynapseError):
     def __init__(self):
-        super(SynapseError).__init__(504, "Timed out")
+        super(SynapseError, self).__init__(504, "Timed out")
 
 
 def unwrapFirstError(failure):
@@ -93,8 +93,10 @@ class Clock(object):
         ret_deferred = defer.Deferred()
 
         def timed_out_fn():
+            e = DeferredTimedOutError()
+
             try:
-                ret_deferred.errback(DeferredTimedOutError())
+                ret_deferred.errback(e)
             except:
                 pass
 
@@ -114,7 +116,7 @@ class Clock(object):
 
         ret_deferred.addBoth(cancel)
 
-        def sucess(res):
+        def success(res):
             try:
                 ret_deferred.callback(res)
             except:
@@ -128,7 +130,7 @@ class Clock(object):
             except:
                 pass
 
-        given_deferred.addCallbacks(callback=sucess, errback=err)
+        given_deferred.addCallbacks(callback=success, errback=err)
 
         timer = self.call_later(time_out, timed_out_fn)
 
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index b68e8c4e9f..4fa9d1a03c 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
 
 
 @defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False,
+                      **kwargs):
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
     that will mark the destination as down if an exception is thrown (excluding
     CodeMessageException with code < 500)
 
+    Args:
+        destination (str): name of homeserver
+        clock (synapse.util.clock): timing source
+        store (synapse.storage.transactions.TransactionStore): datastore
+        ignore_backoff (bool): true to ignore the historical backoff data and
+            try the request anyway. We will still update the next
+            retry_interval on success/failure.
+
     Example usage:
 
         try:
@@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
 
         now = int(clock.time_msec())
 
-        if retry_last_ts + retry_interval > now:
+        if not ignore_backoff and retry_last_ts + retry_interval > now:
             raise NotRetryingDestination(
                 retry_last_ts=retry_last_ts,
                 retry_interval=retry_interval,
@@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         valid_err_code = False
-        if exc_type is not None and issubclass(exc_type, CodeMessageException):
+        if exc_type is None:
+            valid_err_code = True
+        elif not issubclass(exc_type, Exception):
+            # avoid treating exceptions which don't derive from Exception as
+            # failures; this is mostly so as not to catch defer._DefGen.
+            valid_err_code = True
+        elif issubclass(exc_type, CodeMessageException):
             # Some error codes are perfectly fine for some APIs, whereas other
             # APIs may expect to never received e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
@@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
             else:
                 valid_err_code = False
 
-        if exc_type is None or valid_err_code:
+        if valid_err_code:
             # We connected successfully.
             if not self.retry_interval:
                 return
 
+            logger.debug("Connection to %s was successful; clearing backoff",
+                         self.destination)
             retry_last_ts = 0
             self.retry_interval = 0
         else:
@@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
             else:
                 self.retry_interval = self.min_retry_interval
 
+            logger.debug(
+                "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
+                self.destination, exc_type, exc_val, self.retry_interval
+            )
             retry_last_ts = int(self.clock.time_msec())
 
         @defer.inlineCallbacks
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 199b16d827..31659156ae 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
             if prev_membership not in MEMBERSHIP_PRIORITY:
                 prev_membership = "leave"
 
+            # Always allow the user to see their own leave events, otherwise
+            # they won't see the room disappear if they reject the invite
+            if membership == "leave" and (
+                prev_membership == "join" or prev_membership == "invite"
+            ):
+                return True
+
             new_priority = MEMBERSHIP_PRIORITY.index(membership)
             old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
             if old_priority < new_priority: