| diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index dfbbc5a1cd..3e891a6193 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -24,7 +24,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
 from synapse.types import RoomID, UserID, EventID
 from synapse.util.logutils import log_function
-from synapse.util import third_party_invites
 from unpaddedbase64 import decode_base64
 
 import logging
@@ -318,6 +317,11 @@ class Auth(object):
             }
         )
 
+        if Membership.INVITE == membership and "third_party_invite" in event.content:
+            if not self._verify_third_party_invite(event, auth_events):
+                raise AuthError(403, "You are not invited to this room.")
+            return True
+
         if Membership.JOIN != membership:
             if (caller_invited
                     and Membership.LEAVE == membership
@@ -361,8 +365,7 @@ class Auth(object):
                 pass
             elif join_rule == JoinRules.INVITE:
                 if not caller_in_room and not caller_invited:
-                    if not self._verify_third_party_invite(event, auth_events):
-                        raise AuthError(403, "You are not invited to this room.")
+                    raise AuthError(403, "You are not invited to this room.")
             else:
                 # TODO (erikj): may_join list
                 # TODO (erikj): private rooms
@@ -390,10 +393,10 @@ class Auth(object):
 
     def _verify_third_party_invite(self, event, auth_events):
         """
-        Validates that the join event is authorized by a previous third-party invite.
+        Validates that the invite event is authorized by a previous third-party invite.
 
-        Checks that the public key, and keyserver, match those in the invite,
-        and that the join event has a signature issued using that public key.
+        Checks that the public key, and keyserver, match those in the third party invite,
+        and that the invite event has a signature issued using that public key.
 
         Args:
             event: The m.room.member join event being validated.
@@ -404,35 +407,28 @@ class Auth(object):
             True if the event fulfills the expectations of a previous third party
             invite event.
         """
-        if not third_party_invites.join_has_third_party_invite(event.content):
+        if "third_party_invite" not in event.content:
+            return False
+        if "signed" not in event.content["third_party_invite"]:
             return False
-        join_third_party_invite = event.content["third_party_invite"]
-        token = join_third_party_invite["token"]
+        signed = event.content["third_party_invite"]["signed"]
+        for key in {"mxid", "token"}:
+            if key not in signed:
+                return False
+
+        token = signed["token"]
+
         invite_event = auth_events.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
         if not invite_event:
-            logger.info("Failing 3pid invite because no invite found for token %s", token)
+            return False
+
+        if event.user_id != invite_event.user_id:
             return False
         try:
-            public_key = join_third_party_invite["public_key"]
-            key_validity_url = join_third_party_invite["key_validity_url"]
-            if invite_event.content["public_key"] != public_key:
-                logger.info(
-                    "Failing 3pid invite because public key invite: %s != join: %s",
-                    invite_event.content["public_key"],
-                    public_key
-                )
-                return False
-            if invite_event.content["key_validity_url"] != key_validity_url:
-                logger.info(
-                    "Failing 3pid invite because key_validity_url invite: %s != join: %s",
-                    invite_event.content["key_validity_url"],
-                    key_validity_url
-                )
-                return False
-            signed = join_third_party_invite["signed"]
-            if signed["mxid"] != event.user_id:
+            public_key = invite_event.content["public_key"]
+            if signed["mxid"] != event.state_key:
                 return False
             if signed["token"] != token:
                 return False
@@ -445,6 +441,11 @@ class Auth(object):
                         decode_base64(public_key)
                     )
                     verify_signed_json(signed, server, verify_key)
+
+                    # We got the public key from the invite, so we know that the
+                    # correct server signed the signed bundle.
+                    # The caller is responsible for checking that the signing
+                    # server has not revoked that public key.
                     return True
             return False
         except (KeyError, SignatureVerifyException,):
@@ -751,17 +752,19 @@ class Auth(object):
             if e_type == Membership.JOIN:
                 if member_event and not is_public:
                     auth_ids.append(member_event.event_id)
-                if third_party_invites.join_has_third_party_invite(event.content):
+            else:
+                if member_event:
+                    auth_ids.append(member_event.event_id)
+
+            if e_type == Membership.INVITE:
+                if "third_party_invite" in event.content:
                     key = (
                         EventTypes.ThirdPartyInvite,
                         event.content["third_party_invite"]["token"]
                     )
-                    invite = current_state.get(key)
-                    if invite:
-                        auth_ids.append(invite.event_id)
-            else:
-                if member_event:
-                    auth_ids.append(member_event.event_id)
+                    third_party_invite = current_state.get(key)
+                    if third_party_invite:
+                        auth_ids.append(third_party_invite.event_id)
         elif member_event:
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
 |