diff options
Diffstat (limited to 'synapse/api/auth.py')
-rw-r--r-- | synapse/api/auth.py | 73 |
1 files changed, 71 insertions, 2 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e3b8c3099a..5c83aafa7d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,15 +14,19 @@ # limitations under the License. """This module contains classes for authenticating the user.""" +from nacl.exceptions import BadSignatureError from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.util.logutils import log_function from synapse.types import RoomID, UserID, EventID +from synapse.util.logutils import log_function +from synapse.util import third_party_invites +from unpaddedbase64 import decode_base64 import logging +import nacl.signing import pymacaroons logger = logging.getLogger(__name__) @@ -31,6 +35,7 @@ logger = logging.getLogger(__name__) AuthEventTypes = ( EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, + EventTypes.ThirdPartyInvite, ) @@ -341,7 +346,8 @@ class Auth(object): pass elif join_rule == JoinRules.INVITE: if not caller_in_room and not caller_invited: - raise AuthError(403, "You are not invited to this room.") + if not self._verify_third_party_invite(event, auth_events): + raise AuthError(403, "You are not invited to this room.") else: # TODO (erikj): may_join list # TODO (erikj): private rooms @@ -367,6 +373,61 @@ class Auth(object): return True + def _verify_third_party_invite(self, event, auth_events): + """ + Validates that the join event is authorized by a previous third-party invite. + + Checks that the public key, and keyserver, match those in the invite, + and that the join event has a signature issued using that public key. + + Args: + event: The m.room.member join event being validated. + auth_events: All relevant previous context events which may be used + for authorization decisions. + + Return: + True if the event fulfills the expectations of a previous third party + invite event. + """ + if not third_party_invites.join_has_third_party_invite(event.content): + return False + join_third_party_invite = event.content["third_party_invite"] + token = join_third_party_invite["token"] + invite_event = auth_events.get( + (EventTypes.ThirdPartyInvite, token,) + ) + if not invite_event: + logger.info("Failing 3pid invite because no invite found for token %s", token) + return False + try: + public_key = join_third_party_invite["public_key"] + key_validity_url = join_third_party_invite["key_validity_url"] + if invite_event.content["public_key"] != public_key: + logger.info( + "Failing 3pid invite because public key invite: %s != join: %s", + invite_event.content["public_key"], + public_key + ) + return False + if invite_event.content["key_validity_url"] != key_validity_url: + logger.info( + "Failing 3pid invite because key_validity_url invite: %s != join: %s", + invite_event.content["key_validity_url"], + key_validity_url + ) + return False + for _, signature_block in join_third_party_invite["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + return False + verify_key = nacl.signing.VerifyKey(decode_base64(public_key)) + signature = decode_base64(encoded_signature) + verify_key.verify(token, signature) + return True + return False + except (KeyError, BadSignatureError,): + return False + def _get_power_level_event(self, auth_events): key = (EventTypes.PowerLevels, "", ) return auth_events.get(key) @@ -646,6 +707,14 @@ class Auth(object): if e_type == Membership.JOIN: if member_event and not is_public: auth_ids.append(member_event.event_id) + if third_party_invites.join_has_third_party_invite(event.content): + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["token"] + ) + invite = current_state.get(key) + if invite: + auth_ids.append(invite.event_id) else: if member_event: auth_ids.append(member_event.event_id) |