From a2a6c1c22f270047fe23a96011b3675366ed6d96 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Mon, 21 Nov 2016 13:13:55 +0000 Subject: Fail with a coherent error message if `/sync?filter=` is invalid --- synapse/api/errors.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/api') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0041646858..921c457738 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -39,6 +39,7 @@ class Codes(object): CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" MISSING_PARAM = "M_MISSING_PARAM" + INVALID_PARAM = "M_INVALID_PARAM" TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" -- cgit 1.4.1 From e90fcd9edd66b78e520b407e6ffa3a66be2c9178 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Mon, 21 Nov 2016 15:18:18 +0000 Subject: Add filter_event_fields and filter_field to FilterCollection --- synapse/api/filtering.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) (limited to 'synapse/api') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 3b3ef70750..27f8b99e3d 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -18,6 +18,7 @@ from synapse.types import UserID, RoomID from twisted.internet import defer import ujson as json +import re class Filtering(object): @@ -71,6 +72,21 @@ class Filtering(object): if key in user_filter_json["room"]: self._check_definition(user_filter_json["room"][key]) + if "event_fields" in user_filter_json: + if type(user_filter_json["event_fields"]) != list: + raise SynapseError(400, "event_fields must be a list of strings") + for field in user_filter_json["event_fields"]: + if not isinstance(field, basestring): + raise SynapseError(400, "Event field must be a string") + # Don't allow '\\' in event field filters. This makes matching + # events a lot easier as we can then use a negative lookbehind + # assertion to split '\.' If we allowed \\ then it would + # incorrectly split '\\.' + if r'\\' in field: + raise SynapseError( + 400, r'The escape character \ cannot itself be escaped' + ) + def _check_definition_room_lists(self, definition): """Check that "rooms" and "not_rooms" are lists of room ids if they are present @@ -152,6 +168,11 @@ class FilterCollection(object): self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) + self._event_fields = filter_json.get("event_fields", []) + # Negative lookbehind assertion for '\' + # (?" % (json.dumps(self._filter_json),) @@ -186,6 +207,54 @@ class FilterCollection(object): def filter_room_account_data(self, events): return self._room_account_data.filter(self._room_filter.filter(events)) + def filter_event_fields(self, event): + """Remove fields from an event in accordance with the 'event_fields' of a filter. + + If there are no event fields specified then all fields are included. + The entries may include '.' charaters to indicate sub-fields. + So ['content.body'] will include the 'body' field of the 'content' object. + A literal '.' character in a field name may be escaped using a '\'. + + Args: + event(dict): The raw event to filter + Returns: + dict: The same event with some fields missing, if required. + """ + for field in self._event_fields: + self.filter_field(event, field) + return event + + def filter_field(self, dictionary, field): + """Filter the given field from the given dictionary. + + Args: + dictionary(dict): The dictionary to remove the field from. + field(str): The key to remove. + Returns: + dict: The same dictionary with the field removed. + """ + # "content.body.thing\.with\.dots" => ["content", "body", "thing\.with\.dots"] + sub_fields = self._split_field_regex.split(field) + # remove escaping so we can use the right key names when deleting + sub_fields = [f.replace(r'\.', r'.') for f in sub_fields] + + # common case e.g. 'origin_server_ts' + if len(sub_fields) == 1: + dictionary.pop(sub_fields[0], None) + # nested field e.g. 'content.body' + elif len(sub_fields) > 1: + # Pop the last field as that's the key to delete and we need the + # parent dict in order to remove the key. Drill down to the right dict. + key_to_delete = sub_fields.pop(-1) + sub_dict = dictionary + for sub_field in sub_fields: + if sub_field in sub_dict and type(sub_dict[sub_field]) == dict: + sub_dict = sub_dict[sub_field] + else: + return dictionary + sub_dict.pop(key_to_delete, None) + return dictionary + class Filter(object): def __init__(self, filter_json): -- cgit 1.4.1 From f97511a1f3197c6011b5ef7a363885dde9939d6b Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Mon, 21 Nov 2016 17:42:16 +0000 Subject: Move event_fields filtering to serialize_event Also make it an inclusive not exclusive filter, as the spec demands. --- synapse/api/filtering.py | 56 +------------------------ synapse/events/utils.py | 101 +++++++++++++++++++++++++++++++++++++++++++-- tests/events/test_utils.py | 21 ++++++++++ 3 files changed, 119 insertions(+), 59 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 27f8b99e3d..4fd0e2d9fa 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -18,7 +18,6 @@ from synapse.types import UserID, RoomID from twisted.internet import defer import ujson as json -import re class Filtering(object): @@ -81,7 +80,7 @@ class Filtering(object): # Don't allow '\\' in event field filters. This makes matching # events a lot easier as we can then use a negative lookbehind # assertion to split '\.' If we allowed \\ then it would - # incorrectly split '\\.' + # incorrectly split '\\.' See synapse.events.utils.serialize_event if r'\\' in field: raise SynapseError( 400, r'The escape character \ cannot itself be escaped' @@ -168,11 +167,6 @@ class FilterCollection(object): self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) - self._event_fields = filter_json.get("event_fields", []) - # Negative lookbehind assertion for '\' - # (?" % (json.dumps(self._filter_json),) @@ -207,54 +201,6 @@ class FilterCollection(object): def filter_room_account_data(self, events): return self._room_account_data.filter(self._room_filter.filter(events)) - def filter_event_fields(self, event): - """Remove fields from an event in accordance with the 'event_fields' of a filter. - - If there are no event fields specified then all fields are included. - The entries may include '.' charaters to indicate sub-fields. - So ['content.body'] will include the 'body' field of the 'content' object. - A literal '.' character in a field name may be escaped using a '\'. - - Args: - event(dict): The raw event to filter - Returns: - dict: The same event with some fields missing, if required. - """ - for field in self._event_fields: - self.filter_field(event, field) - return event - - def filter_field(self, dictionary, field): - """Filter the given field from the given dictionary. - - Args: - dictionary(dict): The dictionary to remove the field from. - field(str): The key to remove. - Returns: - dict: The same dictionary with the field removed. - """ - # "content.body.thing\.with\.dots" => ["content", "body", "thing\.with\.dots"] - sub_fields = self._split_field_regex.split(field) - # remove escaping so we can use the right key names when deleting - sub_fields = [f.replace(r'\.', r'.') for f in sub_fields] - - # common case e.g. 'origin_server_ts' - if len(sub_fields) == 1: - dictionary.pop(sub_fields[0], None) - # nested field e.g. 'content.body' - elif len(sub_fields) > 1: - # Pop the last field as that's the key to delete and we need the - # parent dict in order to remove the key. Drill down to the right dict. - key_to_delete = sub_fields.pop(-1) - sub_dict = dictionary - for sub_field in sub_fields: - if sub_field in sub_dict and type(sub_dict[sub_field]) == dict: - sub_dict = sub_dict[sub_field] - else: - return dictionary - sub_dict.pop(key_to_delete, None) - return dictionary - class Filter(object): def __init__(self, filter_json): diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 0e9fd902af..4febd98f43 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -16,6 +16,15 @@ from synapse.api.constants import EventTypes from . import EventBase +import re + +# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' +# (?): List of keys to drill down to in 'src'. + """ + if len(field) == 0: # this should be impossible + return + if len(field) == 1: # common case e.g. 'origin_server_ts' + if field[0] in src: + dst[field[0]] = src[field[0]] + return + + # Else is a nested field e.g. 'content.body' + # Pop the last field as that's the key to move across and we need the + # parent dict in order to access the data. Drill down to the right dict. + key_to_move = field.pop(-1) + sub_dict = src + for sub_field in field: # e.g. sub_field => "content" + if sub_field in sub_dict and type(sub_dict[sub_field]) == dict: + sub_dict = sub_dict[sub_field] + else: + return + + if key_to_move not in sub_dict: + return + + # Insert the key into the output dictionary, creating nested objects + # as required. We couldn't do this any earlier or else we'd need to delete + # the empty objects if the key didn't exist. + sub_out_dict = dst + for sub_field in field: + if sub_field not in sub_out_dict: + sub_out_dict[sub_field] = {} + sub_out_dict = sub_out_dict[sub_field] + sub_out_dict[key_to_move] = sub_dict[key_to_move] + + +def only_fields(dictionary, fields): + """Return a new dict with only the fields in 'dictionary' which are present + in 'fields'. + + If there are no event fields specified then all fields are included. + The entries may include '.' charaters to indicate sub-fields. + So ['content.body'] will include the 'body' field of the 'content' object. + A literal '.' character in a field name may be escaped using a '\'. + + Args: + dictionary(dict): The dictionary to read from. + fields(list): A list of fields to copy over. Only shallow refs are + taken. + Returns: + dict: A new dictionary with only the given fields. If fields was empty, + the same dictionary is returned. + """ + if len(fields) == 0: + return dictionary + + # for each field, convert it: + # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]] + split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] + + # for each element of the output array of arrays: + # remove escaping so we can use the right key names. This purposefully avoids + # using list comprehensions to avoid needless allocations as this may be called + # on a lot of events. + for field_array in split_fields: + for i, field in enumerate(field_array): + field_array[i] = field.replace(r'\.', r'.') + + output = {} + for field_array in split_fields: + _copy_field(dictionary, output, field_array) + return output + + def format_event_raw(d): return d @@ -137,7 +227,7 @@ def format_event_for_client_v2_without_room_id(d): def serialize_event(e, time_now_ms, as_client_event=True, event_format=format_event_for_client_v1, - token_id=None): + token_id=None, event_fields=None): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e @@ -164,6 +254,9 @@ def serialize_event(e, time_now_ms, as_client_event=True, d["unsigned"]["transaction_id"] = txn_id if as_client_event: - return event_format(d) - else: - return d + d = event_format(d) + + if isinstance(event_fields, list): + d = only_fields(d, event_fields) + + return d diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index fb0953c4ec..b9f55d174d 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -114,3 +114,24 @@ class PruneEventTestCase(unittest.TestCase): 'unsigned': {}, } ) + + +class SerializeEventTestCase(unittest.TestCase): + + def test_event_fields_works_with_keys(self): + pass + + def test_event_fields_works_with_nested_keys(self): + pass + + def test_event_fields_works_with_dot_keys(self): + pass + + def test_event_fields_works_with_nested_dot_keys(self): + pass + + def test_event_fields_nops_with_unknown_keys(self): + pass + + def test_event_fields_nops_with_non_dict_keys(self): + pass -- cgit 1.4.1 From cea4e4e7b2534f85abbb90a0cc07125db0aa1727 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 22 Nov 2016 10:14:05 +0000 Subject: Glue only_event_fields into the sync rest servlet --- synapse/api/filtering.py | 1 + synapse/events/utils.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 23 +++++++++++++---------- 3 files changed, 15 insertions(+), 11 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 4fd0e2d9fa..6f16e4d406 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -167,6 +167,7 @@ class FilterCollection(object): self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) + self.event_fields = filter_json.get("event_fields", []) def __repr__(self): return "" % (json.dumps(self._filter_json),) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9a700d39bb..ce4fe55204 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -258,7 +258,7 @@ def serialize_event(e, time_now_ms, as_client_event=True, if as_client_event: d = event_format(d) - if (isinstance(only_event_fields, list) and + if (only_event_fields and isinstance(only_event_fields, list) and all(isinstance(f, basestring) for f in only_event_fields)): d = only_fields(d, only_event_fields) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 6fc63715aa..7199ec883a 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() joined = self.encode_joined( - sync_result.joined, time_now, requester.access_token_id + sync_result.joined, time_now, requester.access_token_id, filter.event_fields ) invited = self.encode_invited( @@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id + sync_result.archived, time_now, requester.access_token_id, filter.event_fields ) response_content = { @@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": formatted} - def encode_joined(self, rooms, time_now, token_id): + def encode_joined(self, rooms, time_now, token_id, event_fields): """ Encode the joined rooms in a sync result @@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: the joined rooms list, in our response format @@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id + room, time_now, token_id, only_fields=event_fields ) return joined @@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, time_now, token_id): + def encode_archived(self, rooms, time_now, token_id, event_fields): """ Encode the archived rooms in a sync result @@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: The invited rooms list, in our response format @@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id, joined=False + room, time_now, token_id, joined=False, only_fields=event_fields ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True): + def encode_room(room, time_now, token_id, joined=True, only_fields=None): """ Args: room (JoinedSyncResult|ArchivedSyncResult): sync result for a @@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet): of transaction IDs joined (bool): True if the user is joined to this room - will mean we handle ephemeral events - + only_fields(list): Optional. The list of event fields to include. Returns: dict[str, object]: the room, encoded in our response format """ @@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet): return serialize_event( event, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, + only_event_fields=only_fields, ) state_dict = room.state -- cgit 1.4.1 From 83bcdcee616806ad4b39582b1015a37679b82b9a Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 22 Nov 2016 16:38:35 +0000 Subject: Return early on /sync code paths if a '*' filter is used This is currently very conservative in that it only does this if there is no `since` token. This limits the risk to clients likely to be doing one-off syncs (like bridges), but does mean that normal human clients won't benefit from the time savings here. If the savings are large enough, I would consider generalising this to just check the filter. --- synapse/api/filtering.py | 29 +++++++++++++++++++++++++++++ synapse/handlers/sync.py | 31 ++++++++++++++++++++++--------- 2 files changed, 51 insertions(+), 9 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 6f16e4d406..fb291d7fb9 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -202,6 +202,26 @@ class FilterCollection(object): def filter_room_account_data(self, events): return self._room_account_data.filter(self._room_filter.filter(events)) + def blocks_all_presence(self): + return ( + self._presence_filter.filters_all_types() or + self._presence_filter.filters_all_senders() + ) + + def blocks_all_room_ephemeral(self): + return ( + self._room_ephemeral_filter.filters_all_types() or + self._room_ephemeral_filter.filters_all_senders() or + self._room_ephemeral_filter.filters_all_rooms() + ) + + def blocks_all_room_timeline(self): + return ( + self._room_timeline_filter.filters_all_types() or + self._room_timeline_filter.filters_all_senders() or + self._room_timeline_filter.filters_all_rooms() + ) + class Filter(object): def __init__(self, filter_json): @@ -218,6 +238,15 @@ class Filter(object): self.contains_url = self.filter_json.get("contains_url", None) + def filters_all_types(self): + return "*" in self.not_types + + def filters_all_senders(self): + return "*" in self.not_senders + + def filters_all_rooms(self): + return "*" in self.not_rooms + def check(self, event): """Checks whether the filter matches the given event. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1f910ff814..a86996689c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -277,6 +277,7 @@ class SyncHandler(object): """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() + block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True @@ -293,7 +294,7 @@ class SyncHandler(object): else: recents = [] - if not limited: + if not limited or block_all_timeline: defer.returnValue(TimelineBatch( events=recents, prev_batch=now_token, @@ -531,9 +532,14 @@ class SyncHandler(object): ) newly_joined_rooms, newly_joined_users = res - yield self._generate_sync_entry_for_presence( - sync_result_builder, newly_joined_rooms, newly_joined_users + block_all_presence_data = ( + since_token is None and + sync_config.filter_collection.blocks_all_presence() ) + if not block_all_presence_data: + yield self._generate_sync_entry_for_presence( + sync_result_builder, newly_joined_rooms, newly_joined_users + ) yield self._generate_sync_entry_for_to_device(sync_result_builder) @@ -709,13 +715,20 @@ class SyncHandler(object): `(newly_joined_rooms, newly_joined_users)` """ user_id = sync_result_builder.sync_config.user.to_string() - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_result_builder.sync_config, - now_token=sync_result_builder.now_token, - since_token=sync_result_builder.since_token, + block_all_room_ephemeral = ( + sync_result_builder.since_token is None and + sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) - sync_result_builder.now_token = now_token + + if block_all_room_ephemeral: + ephemeral_by_room = {} + else: + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_result_builder.sync_config, + now_token=sync_result_builder.now_token, + since_token=sync_result_builder.since_token, + ) + sync_result_builder.now_token = now_token ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id, -- cgit 1.4.1 From e1d7c96814a106cac2c2652d58825537e606a07a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 24 Nov 2016 12:38:17 +0000 Subject: Remove redundant list of known caveat prefixes Also add some comments. --- synapse/api/auth.py | 34 ++++++++-------------------------- 1 file changed, 8 insertions(+), 26 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 69b3392735..1ab27da941 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -51,17 +51,6 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 - # Docs for these currently lives at - # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst - # In addition, we have type == delete_pusher which grants access only to - # delete pushers. - self._KNOWN_CAVEAT_PREFIXES = set([ - "gen = ", - "guest = ", - "type = ", - "time < ", - "user_id = ", - ]) @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): @@ -801,11 +790,17 @@ class Auth(object): type_string(str): The kind of token required (e.g. "access", "refresh", "delete_pusher") verify_expiry(bool): Whether to verify whether the macaroon has expired. - This should really always be True, but no clients currently implement - token refresh, so we can't enforce expiry yet. + This should really always be True, but there exist access tokens + in the wild which expire when they should not, so we can't + enforce expiry yet. user_id (str): The user_id required """ v = pymacaroons.Verifier() + + # the verifier runs a test for every caveat on the macaroon, to check + # that it is met for the current request. Each caveat must match at + # least one of the predicates specified by satisfy_exact or + # specify_general. v.satisfy_exact("gen = 1") v.satisfy_exact("type = " + type_string) v.satisfy_exact("user_id = %s" % user_id) @@ -817,10 +812,6 @@ class Auth(object): v.verify(macaroon, self.hs.config.macaroon_secret_key) - v = pymacaroons.Verifier() - v.satisfy_general(self._verify_recognizes_caveats) - v.verify(macaroon, self.hs.config.macaroon_secret_key) - def _verify_expiry(self, caveat): prefix = "time < " if not caveat.startswith(prefix): @@ -829,15 +820,6 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - def _verify_recognizes_caveats(self, caveat): - first_space = caveat.find(" ") - if first_space < 0: - return False - second_space = caveat.find(" ", first_space + 1) - if second_space < 0: - return False - return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES - @defer.inlineCallbacks def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) -- cgit 1.4.1 From 7f02e4d0085b7df0429a298162f7c22b6e19c095 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 25 Nov 2016 15:25:30 +0000 Subject: Give guest users a device_id We need to create devices for guests so that they can use e2e, but we don't have anywhere to store it, so just use a fixed one. --- synapse/api/auth.py | 6 +++++- synapse/rest/client/v2_alpha/register.py | 19 +++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 69b3392735..4321ec26f1 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -39,6 +39,9 @@ AuthEventTypes = ( EventTypes.ThirdPartyInvite, ) +# guests always get this device id. +GUEST_DEVICE_ID = "guest_device" + class Auth(object): """ @@ -728,7 +731,8 @@ class Auth(object): "user": user, "is_guest": True, "token_id": None, - "device_id": None, + # all guests get the same device id + "device_id": GUEST_DEVICE_ID, } elif rights == "delete_pusher": # We don't store these tokens in the database diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6cfb20866b..7fff2d4bf6 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -15,6 +15,7 @@ from twisted.internet import defer +import synapse from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError @@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() + body = parse_json_object_from_request(request) + kind = "user" if "kind" in request.args: kind = request.args["kind"][0] if kind == "guest": - ret = yield self._do_guest_registration() + ret = yield self._do_guest_registration(body) defer.returnValue(ret) return elif kind != "user": @@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - body = parse_json_object_from_request(request) - # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None @@ -421,13 +422,22 @@ class RegisterRestServlet(RestServlet): ) @defer.inlineCallbacks - def _do_guest_registration(self): + def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: defer.returnValue((403, "Guest access is disabled")) user_id, _ = yield self.registration_handler.register( generate_token=False, make_guest=True ) + + # we don't allow guests to specify their own device_id, because + # we have nowhere to store it. + device_id = synapse.api.auth.GUEST_DEVICE_ID + initial_display_name = params.get("initial_device_display_name") + self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + access_token = self.auth_handler.generate_access_token( user_id, ["guest = true"] ) @@ -435,6 +445,7 @@ class RegisterRestServlet(RestServlet): # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, + "device_id": device_id, "access_token": access_token, "home_server": self.hs.hostname, })) -- cgit 1.4.1 From 1c4f05db41eab20f8be4ac2dac0f7e86b0b7e1fd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 28 Nov 2016 09:55:21 +0000 Subject: Stop putting a time caveat on access tokens The 'time' caveat on the access tokens was something of a lie, since we weren't enforcing it; more pertinently its presence stops us ever adding useful time caveats. Let's move in the right direction by not lying in our caveats. --- synapse/api/auth.py | 4 ++++ synapse/config/registration.py | 6 ------ synapse/handlers/auth.py | 11 ++++++----- synapse/handlers/register.py | 5 ++--- synapse/rest/client/v1/register.py | 12 ------------ tests/handlers/test_auth.py | 6 +++--- tests/handlers/test_register.py | 6 ++---- 7 files changed, 17 insertions(+), 33 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1ab27da941..77ff55cddf 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -810,6 +810,10 @@ class Auth(object): else: v.satisfy_general(lambda c: c.startswith("time < ")) + # access_tokens and refresh_tokens include a nonce for uniqueness: any + # value is acceptable + v.satisfy_general(lambda c: c.startswith("nonce = ")) + v.verify(macaroon, self.hs.config.macaroon_secret_key) def _verify_expiry(self, caveat): diff --git a/synapse/config/registration.py b/synapse/config/registration.py index cc3f879857..87e500c97a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -32,7 +32,6 @@ class RegistrationConfig(Config): ) self.registration_shared_secret = config.get("registration_shared_secret") - self.user_creation_max_duration = int(config["user_creation_max_duration"]) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] @@ -55,11 +54,6 @@ class RegistrationConfig(Config): # secret, even if registration is otherwise disabled. registration_shared_secret: "%(registration_shared_secret)s" - # Sets the expiry for the short term user creation in - # milliseconds. For instance the bellow duration is two weeks - # in milliseconds. - user_creation_max_duration: 1209600000 - # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. # The default number of rounds is 12. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a2866af431..20aaec36a4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -538,14 +538,15 @@ class AuthHandler(BaseHandler): device_id) defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None, - duration_in_ms=(60 * 60 * 1000)): + def generate_access_token(self, user_id, extra_caveats=None): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7e119f13b1..886fec8701 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, duration_in_ms, + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token( - user_id, None, duration_in_ms) + token = self.auth_handler().generate_access_token(user_id) if need_register: yield self.store.register( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index b5a76fefac..ecf7e311a9 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet): def __init__(self, hs): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() - self.direct_user_creation_max_duration = hs.config.user_creation_max_duration self.handlers = hs.get_handlers() @defer.inlineCallbacks @@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet): if "displayname" not in user_json: raise SynapseError(400, "Expected 'displayname' key.") - if "duration_seconds" not in user_json: - raise SynapseError(400, "Expected 'duration_seconds' key.") - localpart = user_json["localpart"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8") - duration_seconds = 0 - try: - duration_seconds = int(user_json["duration_seconds"]) - except ValueError: - raise SynapseError(400, "Failed to parse 'duration_seconds'") - if duration_seconds > self.direct_user_creation_max_duration: - duration_seconds = self.direct_user_creation_max_duration password_hash = user_json["password_hash"].encode("utf-8") \ if user_json.get("password_hash") else None @@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet): requester=requester, localpart=localpart, displayname=displayname, - duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 4a8cd19acf..9d013e5ca7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -61,14 +61,14 @@ class AuthTestCase(unittest.TestCase): def verify_type(caveat): return caveat == "type = access" - def verify_expiry(caveat): - return caveat == "time < 8600000" + def verify_nonce(caveat): + return caveat.startswith("nonce =") v = pymacaroons.Verifier() v.satisfy_general(verify_gen) v.satisfy_general(verify_user) v.satisfy_general(verify_type) - v.satisfy_general(verify_expiry) + v.satisfy_general(verify_nonce) v.verify(macaroon, self.hs.config.macaroon_secret_key) def test_short_term_login_token_gives_user_id(self): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 9c9d144690..a4380c48b4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -53,13 +53,12 @@ class RegistrationTestCase(unittest.TestCase): @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - duration_ms = 200 local_part = "someone" display_name = "someone" user_id = "@someone:test" requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name, duration_ms) + requester, local_part, display_name) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -71,12 +70,11 @@ class RegistrationTestCase(unittest.TestCase): user_id=frank.to_string(), token="jkv;g498752-43gj['eamb!-5", password_hash=None) - duration_ms = 200 local_part = "frank" display_name = "Frank" user_id = "@frank:test" requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name, duration_ms) + requester, local_part, display_name) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') -- cgit 1.4.1 From 4febfe47f03a97578e186fa6cae28c29ad8327cb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 30 Nov 2016 07:36:32 +0000 Subject: Comments Update comments in verify_macaroon --- synapse/api/auth.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 77ff55cddf..b8c2917f21 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -790,9 +790,6 @@ class Auth(object): type_string(str): The kind of token required (e.g. "access", "refresh", "delete_pusher") verify_expiry(bool): Whether to verify whether the macaroon has expired. - This should really always be True, but there exist access tokens - in the wild which expire when they should not, so we can't - enforce expiry yet. user_id (str): The user_id required """ v = pymacaroons.Verifier() @@ -805,6 +802,15 @@ class Auth(object): v.satisfy_exact("type = " + type_string) v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("guest = true") + + # verify_expiry should really always be True, but there exist access + # tokens in the wild which expire when they should not, so we can't + # enforce expiry yet (so we have to allow any caveat starting with + # 'time < ' in access tokens). + # + # On the other hand, short-term login tokens (as used by CAS login, for + # example) have an expiry time which we do want to enforce. + if verify_expiry: v.satisfy_general(self._verify_expiry) else: -- cgit 1.4.1 From aa09d6b8f0a8f3f006f08b8816b3f2a0fe7eb167 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 30 Nov 2016 17:40:18 +0000 Subject: Rip out more refresh_token code We might as well treat all refresh_tokens as invalid. Just return a 403 from /tokenrefresh, so that we don't have a load of dead, untestable code hanging around. Still TODO: removing the table from the schema. --- synapse/api/auth.py | 5 +-- synapse/handlers/auth.py | 10 ----- synapse/rest/client/v2_alpha/register.py | 2 - synapse/rest/client/v2_alpha/tokenrefresh.py | 26 ++--------- synapse/storage/__init__.py | 1 - synapse/storage/registration.py | 66 ---------------------------- tests/storage/test_registration.py | 55 ----------------------- 7 files changed, 5 insertions(+), 160 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b17025c7ce..ddab210718 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -791,7 +791,7 @@ class Auth(object): Args: macaroon(pymacaroons.Macaroon): The macaroon to validate - type_string(str): The kind of token required (e.g. "access", "refresh", + type_string(str): The kind of token required (e.g. "access", "delete_pusher") verify_expiry(bool): Whether to verify whether the macaroon has expired. user_id (str): The user_id required @@ -820,8 +820,7 @@ class Auth(object): else: v.satisfy_general(lambda c: c.startswith("time < ")) - # access_tokens and refresh_tokens include a nonce for uniqueness: any - # value is acceptable + # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) v.verify(macaroon, self.hs.config.macaroon_secret_key) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 91e7e725b9..9d8e6f19bc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -539,16 +539,6 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_refresh_token(self, user_id): - m = self._generate_base_macaroon(user_id) - m.add_first_party_caveat("type = refresh") - # Important to add a nonce, because otherwise every refresh token for a - # user will be the same. - m.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - return m.serialize() - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index bc2ec95ddd..d5e6ec8b92 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -440,8 +440,6 @@ class RegisterRestServlet(RestServlet): access_token = self.auth_handler.generate_access_token( user_id, ["guest = true"] ) - # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok - # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, "device_id": device_id, diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 0d312c91d4..6e76b9e9c2 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -15,8 +15,8 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - self.hs = hs - self.store = hs.get_datastore() @defer.inlineCallbacks def on_POST(self, request): - body = parse_json_object_from_request(request) - try: - old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_auth_handler() - refresh_result = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token - ) - (user_id, new_refresh_token, device_id) = refresh_result - new_access_token = yield auth_handler.issue_access_token( - user_id, device_id - ) - defer.returnValue((200, { - "access_token": new_access_token, - "refresh_token": new_refresh_token, - })) - except KeyError: - raise SynapseError(400, "Missing required key 'refresh_token'.") - except StoreError: - raise AuthError(403, "Did not recognize refresh token") + raise AuthError(403, "tokenrefresh is no longer supported.") def register_servlets(hs, http_server): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 9996f195a0..db146ed348 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore, self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index e404fa72de..983a8ec52b 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): desc="add_access_token_to_user", ) - @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token, device_id=None): - """Adds a refresh token for the given user. - - Args: - user_id (str): The user ID. - token (str): The new refresh token to add. - device_id (str): ID of the device to associate with the access - token - Raises: - StoreError if there was a problem adding this. - """ - next_id = self._refresh_tokens_id_gen.get_next() - - yield self._simple_insert( - "refresh_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token, - "device_id": device_id, - }, - desc="add_refresh_token_to_user", - ) - def register(self, user_id, token=None, password_hash=None, was_guest=False, make_guest=False, appservice_id=None, create_profile_with_localpart=None, admin=False): @@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): token ) - def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new one. - - Doing so invalidates the old refresh token - refresh tokens are single - use. - - Args: - refresh_token (str): The refresh token of a user. - token_generator (fn: str -> str): Function which, when given a - user ID, returns a unique refresh token for that user. This - function must never return the same value twice. - Returns: - tuple of (user_id, new_refresh_token, device_id) - Raises: - StoreError if no user was found with that refresh token. - """ - return self.runInteraction( - "exchange_refresh_token", - self._exchange_refresh_token, - refresh_token, - token_generator - ) - - def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" - txn.execute(sql, (old_token,)) - rows = self.cursor_to_dict(txn) - if not rows: - raise StoreError(403, "Did not recognize refresh token") - user_id = rows[0]["user_id"] - device_id = rows[0]["device_id"] - - # TODO(danielwh): Maybe perform a validation on the macaroon that - # macaroon.user_id == user_id. - - new_token = token_generator(user_id) - sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" - txn.execute(sql, (new_token, old_token,)) - - return user_id, new_token, device_id - @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index f7d74dea8e..db0faa7fcb 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -80,64 +80,12 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertTrue("token_id" in result) - @defer.inlineCallbacks - def test_exchange_refresh_token_valid(self): - uid = stringutils.random_string(32) - device_id = stringutils.random_string(16) - generator = TokenGenerator() - last_token = generator.generate(uid) - - self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token, device_id) " - "VALUES(?,?,?)", - (uid, last_token, device_id)) - - (found_user_id, refresh_token, device_id) = \ - yield self.store.exchange_refresh_token(last_token, - generator.generate) - self.assertEqual(uid, found_user_id) - - rows = yield self.db_pool.runQuery( - "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", - (uid, )) - self.assertEqual([(refresh_token, device_id)], rows) - # We issued token 1, then exchanged it for token 2 - expected_refresh_token = u"%s-%d" % (uid, 2,) - self.assertEqual(expected_refresh_token, refresh_token) - - @defer.inlineCallbacks - def test_exchange_refresh_token_none(self): - uid = stringutils.random_string(32) - generator = TokenGenerator() - last_token = generator.generate(uid) - - with self.assertRaises(StoreError): - yield self.store.exchange_refresh_token(last_token, generator.generate) - - @defer.inlineCallbacks - def test_exchange_refresh_token_invalid(self): - uid = stringutils.random_string(32) - generator = TokenGenerator() - last_token = generator.generate(uid) - wrong_token = "%s-wrong" % (last_token,) - - self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, wrong_token,)) - - with self.assertRaises(StoreError): - yield self.store.exchange_refresh_token(last_token, generator.generate) - @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - generator = TokenGenerator() - refresh_token = generator.generate(self.user_id) yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], self.device_id) - yield self.store.add_refresh_token_to_user(self.user_id, refresh_token, - self.device_id) # now delete some yield self.store.user_delete_access_tokens( @@ -146,9 +94,6 @@ class RegistrationStoreTestCase(unittest.TestCase): # check they were deleted user = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertIsNone(user, "access token was not deleted by device_id") - with self.assertRaises(StoreError): - yield self.store.exchange_refresh_token(refresh_token, - generator.generate) # check the one not associated with the device was not deleted user = yield self.store.get_user_by_access_token(self.tokens[0]) -- cgit 1.4.1 From 1529c196758ec4106f4b3a0f89f5b41bc5205c7b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 6 Dec 2016 15:31:37 +0000 Subject: Prevent user tokens being used as guest tokens (#1675) Make sure that a user cannot pretend to be a guest by adding 'guest = True' caveats. --- synapse/api/auth.py | 51 +++++++++++++++++------- synapse/handlers/register.py | 2 +- tests/api/test_auth.py | 93 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 115 insertions(+), 31 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ddab210718..a99986714d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -677,31 +677,28 @@ class Auth(object): @defer.inlineCallbacks def get_user_by_access_token(self, token, rights="access"): - """ Get a registered user's ID. + """ Validate access token and get user_id from it Args: token (str): The access token to get the user by. + rights (str): The operation being performed; the access token must + allow this. Returns: dict : dict that includes the user and the ID of their access token. Raises: AuthError if no user by that token exists or the token is invalid. """ try: - ret = yield self.get_user_from_macaroon(token, rights) - except AuthError: - # TODO(daniel): Remove this fallback when all existing access tokens - # have been re-issued as macaroons. - if self.hs.config.expire_access_token: - raise - ret = yield self._look_up_user_by_access_token(token) - - defer.returnValue(ret) + macaroon = pymacaroons.Macaroon.deserialize(token) + except Exception: # deserialize can throw more-or-less anything + # doesn't look like a macaroon: treat it as an opaque token which + # must be in the database. + # TODO: it would be nice to get rid of this, but apparently some + # people use access tokens which aren't macaroons + r = yield self._look_up_user_by_access_token(token) + defer.returnValue(r) - @defer.inlineCallbacks - def get_user_from_macaroon(self, macaroon_str, rights="access"): try: - macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - user_id = self.get_user_id_from_macaroon(macaroon) user = UserID.from_string(user_id) @@ -716,6 +713,30 @@ class Auth(object): guest = True if guest: + # Guest access tokens are not stored in the database (there can + # only be one access token per guest, anyway). + # + # In order to prevent guest access tokens being used as regular + # user access tokens (and hence getting around the invalidation + # process), we look up the user id and check that it is indeed + # a guest user. + # + # It would of course be much easier to store guest access + # tokens in the database as well, but that would break existing + # guest tokens. + stored_user = yield self.store.get_user_by_id(user_id) + if not stored_user: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Unknown user_id %s" % user_id, + errcode=Codes.UNKNOWN_TOKEN + ) + if not stored_user["is_guest"]: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Guest access token used for regular user", + errcode=Codes.UNKNOWN_TOKEN + ) ret = { "user": user, "is_guest": True, @@ -743,7 +764,7 @@ class Auth(object): # macaroon. They probably should be. # TODO: build the dictionary from the macaroon once the # above are fixed - ret = yield self._look_up_user_by_access_token(macaroon_str) + ret = yield self._look_up_user_by_access_token(token) if ret["user"] != user: logger.error( "Macaroon user (%s) != DB user (%s)", diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 886fec8701..286f0cef0a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -81,7 +81,7 @@ class RegistrationHandler(BaseHandler): "User ID already taken.", errcode=Codes.USER_IN_USE, ) - user_data = yield self.auth.get_user_from_macaroon(guest_access_token) + user_data = yield self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: raise AuthError( 403, diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 2cf262bb46..4575dd9834 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -12,17 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest -from twisted.internet import defer +import pymacaroons from mock import Mock +from twisted.internet import defer +import synapse.handlers.auth from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver, mock_getRawHeaders -import pymacaroons + +class TestHandlers(object): + def __init__(self, hs): + self.auth_handler = synapse.handlers.auth.AuthHandler(hs) class AuthTestCase(unittest.TestCase): @@ -34,14 +39,17 @@ class AuthTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver(handlers=None) self.hs.get_datastore = Mock(return_value=self.store) + self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) self.test_user = "@foo:bar" self.test_token = "_test_token_" + # this is overridden for the appservice tests + self.store.get_app_service_by_token = Mock(return_value=None) + @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, "token_id": "ditto", @@ -56,7 +64,6 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) @@ -66,7 +73,6 @@ class AuthTestCase(unittest.TestCase): self.failureResultOf(d, AuthError) def test_get_user_by_req_user_missing_token(self): - self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, "token_id": "ditto", @@ -158,7 +164,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -168,6 +174,10 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): + self.store.get_user_by_id = Mock(return_value={ + "is_guest": True, + }) + user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, @@ -179,11 +189,12 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth.get_user_from_macaroon(serialized) + user_info = yield self.auth.get_user_by_access_token(serialized) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) self.assertTrue(is_guest) + self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks def test_get_user_from_macaroon_user_db_mismatch(self): @@ -200,7 +211,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("User mismatch", cm.exception.msg) @@ -220,7 +231,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("type = access") with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("No user caveat", cm.exception.msg) @@ -242,7 +253,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = %s" % (user,)) with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -265,7 +276,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("cunning > fox") with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -293,12 +304,12 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True - # yield self.auth.get_user_from_macaroon(macaroon.serialize()) + # yield self.auth.get_user_by_access_token(macaroon.serialize()) # TODO(daniel): Turn on the check that we validate expiration, when we # validate expiration (and remove the above line, which will start # throwing). with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -327,6 +338,58 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True - user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + + @defer.inlineCallbacks + def test_cannot_use_regular_token_as_guest(self): + USER_ID = "@percy:matrix.org" + self.store.add_access_token_to_user = Mock() + + token = yield self.hs.handlers.auth_handler.issue_access_token( + USER_ID, "DEVICE" + ) + self.store.add_access_token_to_user.assert_called_with( + USER_ID, token, "DEVICE" + ) + + def get_user(tok): + if token != tok: + return None + return { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + self.store.get_user_by_access_token = get_user + self.store.get_user_by_id = Mock(return_value={ + "is_guest": False, + }) + + # check the token works + request = Mock(args={}) + request.args["access_token"] = [token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + self.assertEqual(UserID.from_string(USER_ID), requester.user) + self.assertFalse(requester.is_guest) + + # add an is_guest caveat + mac = pymacaroons.Macaroon.deserialize(token) + mac.add_first_party_caveat("guest = true") + guest_tok = mac.serialize() + + # the token should *not* work now + request = Mock(args={}) + request.args["access_token"] = [guest_tok] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + with self.assertRaises(AuthError) as cm: + yield self.auth.get_user_by_req(request, allow_guest=True) + + self.assertEqual(401, cm.exception.code) + self.assertEqual("Guest access token used for regular user", cm.exception.msg) + + self.store.get_user_by_id.assert_called_with(USER_ID) -- cgit 1.4.1