summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth.py168
-rw-r--r--synapse/api/constants.py1
-rw-r--r--synapse/api/errors.py1
-rw-r--r--synapse/api/filtering.py27
4 files changed, 103 insertions, 94 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 88445fe999..3e891a6193 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -24,7 +24,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
 from synapse.types import RoomID, UserID, EventID
 from synapse.util.logutils import log_function
-from synapse.util import third_party_invites
 from unpaddedbase64 import decode_base64
 
 import logging
@@ -49,6 +48,7 @@ class Auth(object):
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
         self._KNOWN_CAVEAT_PREFIXES = set([
             "gen = ",
+            "guest = ",
             "type = ",
             "time < ",
             "user_id = ",
@@ -183,15 +183,11 @@ class Auth(object):
         defer.returnValue(member)
 
     @defer.inlineCallbacks
-    def check_user_was_in_room(self, room_id, user_id, current_state=None):
+    def check_user_was_in_room(self, room_id, user_id):
         """Check if the user was in the room at some point.
         Args:
             room_id(str): The room to check.
             user_id(str): The user to check.
-            current_state(dict): Optional map of the current state of the room.
-                If provided then that map is used to check whether they are a
-                member of the room. Otherwise the current membership is
-                loaded from the database.
         Raises:
             AuthError if the user was never in the room.
         Returns:
@@ -199,17 +195,11 @@ class Auth(object):
             room. This will be the join event if they are currently joined to
             the room. This will be the leave event if they have left the room.
         """
-        if current_state:
-            member = current_state.get(
-                (EventTypes.Member, user_id),
-                None
-            )
-        else:
-            member = yield self.state.get_current_state(
-                room_id=room_id,
-                event_type=EventTypes.Member,
-                state_key=user_id
-            )
+        member = yield self.state.get_current_state(
+            room_id=room_id,
+            event_type=EventTypes.Member,
+            state_key=user_id
+        )
         membership = member.membership if member else None
 
         if membership not in (Membership.JOIN, Membership.LEAVE):
@@ -327,6 +317,11 @@ class Auth(object):
             }
         )
 
+        if Membership.INVITE == membership and "third_party_invite" in event.content:
+            if not self._verify_third_party_invite(event, auth_events):
+                raise AuthError(403, "You are not invited to this room.")
+            return True
+
         if Membership.JOIN != membership:
             if (caller_invited
                     and Membership.LEAVE == membership
@@ -370,8 +365,7 @@ class Auth(object):
                 pass
             elif join_rule == JoinRules.INVITE:
                 if not caller_in_room and not caller_invited:
-                    if not self._verify_third_party_invite(event, auth_events):
-                        raise AuthError(403, "You are not invited to this room.")
+                    raise AuthError(403, "You are not invited to this room.")
             else:
                 # TODO (erikj): may_join list
                 # TODO (erikj): private rooms
@@ -399,10 +393,10 @@ class Auth(object):
 
     def _verify_third_party_invite(self, event, auth_events):
         """
-        Validates that the join event is authorized by a previous third-party invite.
+        Validates that the invite event is authorized by a previous third-party invite.
 
-        Checks that the public key, and keyserver, match those in the invite,
-        and that the join event has a signature issued using that public key.
+        Checks that the public key, and keyserver, match those in the third party invite,
+        and that the invite event has a signature issued using that public key.
 
         Args:
             event: The m.room.member join event being validated.
@@ -413,35 +407,28 @@ class Auth(object):
             True if the event fulfills the expectations of a previous third party
             invite event.
         """
-        if not third_party_invites.join_has_third_party_invite(event.content):
+        if "third_party_invite" not in event.content:
             return False
-        join_third_party_invite = event.content["third_party_invite"]
-        token = join_third_party_invite["token"]
+        if "signed" not in event.content["third_party_invite"]:
+            return False
+        signed = event.content["third_party_invite"]["signed"]
+        for key in {"mxid", "token"}:
+            if key not in signed:
+                return False
+
+        token = signed["token"]
+
         invite_event = auth_events.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
         if not invite_event:
-            logger.info("Failing 3pid invite because no invite found for token %s", token)
+            return False
+
+        if event.user_id != invite_event.user_id:
             return False
         try:
-            public_key = join_third_party_invite["public_key"]
-            key_validity_url = join_third_party_invite["key_validity_url"]
-            if invite_event.content["public_key"] != public_key:
-                logger.info(
-                    "Failing 3pid invite because public key invite: %s != join: %s",
-                    invite_event.content["public_key"],
-                    public_key
-                )
-                return False
-            if invite_event.content["key_validity_url"] != key_validity_url:
-                logger.info(
-                    "Failing 3pid invite because key_validity_url invite: %s != join: %s",
-                    invite_event.content["key_validity_url"],
-                    key_validity_url
-                )
-                return False
-            signed = join_third_party_invite["signed"]
-            if signed["mxid"] != event.user_id:
+            public_key = invite_event.content["public_key"]
+            if signed["mxid"] != event.state_key:
                 return False
             if signed["token"] != token:
                 return False
@@ -454,6 +441,11 @@ class Auth(object):
                         decode_base64(public_key)
                     )
                     verify_signed_json(signed, server, verify_key)
+
+                    # We got the public key from the invite, so we know that the
+                    # correct server signed the signed bundle.
+                    # The caller is responsible for checking that the signing
+                    # server has not revoked that public key.
                     return True
             return False
         except (KeyError, SignatureVerifyException,):
@@ -497,7 +489,7 @@ class Auth(object):
             return default
 
     @defer.inlineCallbacks
-    def get_user_by_req(self, request):
+    def get_user_by_req(self, request, allow_guest=False):
         """ Get a registered user's ID.
 
         Args:
@@ -535,7 +527,7 @@ class Auth(object):
 
                 request.authenticated_entity = user_id
 
-                defer.returnValue((UserID.from_string(user_id), ""))
+                defer.returnValue((UserID.from_string(user_id), "", False))
                 return
             except KeyError:
                 pass  # normal users won't have the user_id query parameter set.
@@ -543,6 +535,7 @@ class Auth(object):
             user_info = yield self._get_user_by_access_token(access_token)
             user = user_info["user"]
             token_id = user_info["token_id"]
+            is_guest = user_info["is_guest"]
 
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
@@ -557,9 +550,14 @@ class Auth(object):
                     user_agent=user_agent
                 )
 
+            if is_guest and not allow_guest:
+                raise AuthError(
+                    403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+                )
+
             request.authenticated_entity = user.to_string()
 
-            defer.returnValue((user, token_id,))
+            defer.returnValue((user, token_id, is_guest,))
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -592,31 +590,45 @@ class Auth(object):
             self._validate_macaroon(macaroon)
 
             user_prefix = "user_id = "
+            user = None
+            guest = False
             for caveat in macaroon.caveats:
                 if caveat.caveat_id.startswith(user_prefix):
                     user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
-                    # This codepath exists so that we can actually return a
-                    # token ID, because we use token IDs in place of device
-                    # identifiers throughout the codebase.
-                    # TODO(daniel): Remove this fallback when device IDs are
-                    # properly implemented.
-                    ret = yield self._look_up_user_by_access_token(macaroon_str)
-                    if ret["user"] != user:
-                        logger.error(
-                            "Macaroon user (%s) != DB user (%s)",
-                            user,
-                            ret["user"]
-                        )
-                        raise AuthError(
-                            self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                            "User mismatch in macaroon",
-                            errcode=Codes.UNKNOWN_TOKEN
-                        )
-                    defer.returnValue(ret)
-            raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-                errcode=Codes.UNKNOWN_TOKEN
-            )
+                elif caveat.caveat_id == "guest = true":
+                    guest = True
+
+            if user is None:
+                raise AuthError(
+                    self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+                    errcode=Codes.UNKNOWN_TOKEN
+                )
+
+            if guest:
+                ret = {
+                    "user": user,
+                    "is_guest": True,
+                    "token_id": None,
+                }
+            else:
+                # This codepath exists so that we can actually return a
+                # token ID, because we use token IDs in place of device
+                # identifiers throughout the codebase.
+                # TODO(daniel): Remove this fallback when device IDs are
+                # properly implemented.
+                ret = yield self._look_up_user_by_access_token(macaroon_str)
+                if ret["user"] != user:
+                    logger.error(
+                        "Macaroon user (%s) != DB user (%s)",
+                        user,
+                        ret["user"]
+                    )
+                    raise AuthError(
+                        self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                        "User mismatch in macaroon",
+                        errcode=Codes.UNKNOWN_TOKEN
+                    )
+            defer.returnValue(ret)
         except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
@@ -629,6 +641,7 @@ class Auth(object):
         v.satisfy_exact("type = access")
         v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_general(self._verify_expiry)
+        v.satisfy_exact("guest = true")
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
         v = pymacaroons.Verifier()
@@ -666,6 +679,7 @@ class Auth(object):
         user_info = {
             "user": UserID.from_string(ret.get("name")),
             "token_id": ret.get("token_id", None),
+            "is_guest": False,
         }
         defer.returnValue(user_info)
 
@@ -738,17 +752,19 @@ class Auth(object):
             if e_type == Membership.JOIN:
                 if member_event and not is_public:
                     auth_ids.append(member_event.event_id)
-                if third_party_invites.join_has_third_party_invite(event.content):
+            else:
+                if member_event:
+                    auth_ids.append(member_event.event_id)
+
+            if e_type == Membership.INVITE:
+                if "third_party_invite" in event.content:
                     key = (
                         EventTypes.ThirdPartyInvite,
                         event.content["third_party_invite"]["token"]
                     )
-                    invite = current_state.get(key)
-                    if invite:
-                        auth_ids.append(invite.event_id)
-            else:
-                if member_event:
-                    auth_ids.append(member_event.event_id)
+                    third_party_invite = current_state.get(key)
+                    if third_party_invite:
+                        auth_ids.append(third_party_invite.event_id)
         elif member_event:
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 41125e8719..c2450b771a 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -68,6 +68,7 @@ class EventTypes(object):
     RoomHistoryVisibility = "m.room.history_visibility"
     CanonicalAlias = "m.room.canonical_alias"
     RoomAvatar = "m.room.avatar"
+    GuestAccess = "m.room.guest_access"
 
     # These are used for validation
     Message = "m.room.message"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index b3fea27d0e..d4037b3d55 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -33,6 +33,7 @@ class Codes(object):
     NOT_FOUND = "M_NOT_FOUND"
     MISSING_TOKEN = "M_MISSING_TOKEN"
     UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
+    GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
     LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
     CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
     CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index eb15d8c54a..aaa2433cae 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -50,11 +50,11 @@ class Filtering(object):
         # many definitions.
 
         top_level_definitions = [
-            "public_user_data", "private_user_data", "server_data"
+            "presence"
         ]
 
         room_level_definitions = [
-            "state", "timeline", "ephemeral"
+            "state", "timeline", "ephemeral", "private_user_data"
         ]
 
         for key in top_level_definitions:
@@ -114,22 +114,6 @@ class Filtering(object):
                     if not isinstance(event_type, basestring):
                         raise SynapseError(400, "Event type should be a string")
 
-        if "format" in definition:
-            event_format = definition["format"]
-            if event_format not in ["federation", "events"]:
-                raise SynapseError(400, "Invalid format: %s" % (event_format,))
-
-        if "select" in definition:
-            event_select_list = definition["select"]
-            for select_key in event_select_list:
-                if select_key not in ["event_id", "origin_server_ts",
-                                      "thread_id", "content", "content.body"]:
-                    raise SynapseError(400, "Bad select: %s" % (select_key,))
-
-        if ("bundle_updates" in definition and
-                type(definition["bundle_updates"]) != bool):
-            raise SynapseError(400, "Bad bundle_updates: expected bool.")
-
 
 class FilterCollection(object):
     def __init__(self, filter_json):
@@ -147,6 +131,10 @@ class FilterCollection(object):
             self.filter_json.get("room", {}).get("ephemeral", {})
         )
 
+        self.room_private_user_data = Filter(
+            self.filter_json.get("room", {}).get("private_user_data", {})
+        )
+
         self.presence_filter = Filter(
             self.filter_json.get("presence", {})
         )
@@ -172,6 +160,9 @@ class FilterCollection(object):
     def filter_room_ephemeral(self, events):
         return self.room_ephemeral_filter.filter(events)
 
+    def filter_room_private_user_data(self, events):
+        return self.room_private_user_data.filter(events)
+
 
 class Filter(object):
     def __init__(self, filter_json):