-rw-r--r--   CHANGES.rst                           |  15
-rw-r--r--   synapse/metrics/metric.py             |  30
-rw-r--r--   synapse/rest/client/v1/base.py        |   6
-rw-r--r--   synapse/rest/client/v1/pusher.py      |   2
-rw-r--r--   synapse/server.py                     |  10
-rw-r--r--   synapse/storage/events.py             | 104
-rw-r--r--   tests/metrics/test_metric.py          |  21
-rw-r--r--   tests/rest/client/v1/test_events.py   |   9
-rw-r--r--   tests/rest/client/v1/test_profile.py  |   2
-rw-r--r--   tests/rest/client/v1/test_rooms.py    |  18
-rw-r--r--   tests/rest/client/v1/test_typing.py   |   2
11 files changed, 154 insertions(+), 65 deletions(-)
diff --git a/CHANGES.rst b/CHANGES.rst
index 317846d2a2..9d40b2ac1e 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,18 @@
+Changes in synapse <unreleased>
+===============================
+
+Potentially breaking change:
+
+* Make Client-Server API return 401 for invalid token (PR #3161).
+
+  This changes the Client-server spec to return a 401 error code instead of 403
+  when the access token is unrecognised. This is the behaviour required by the
+  specification, but some clients may be relying on the old, incorrect
+  behaviour.
+
+  Thanks to @NotAFile for fixing this.
+
+
 Changes in synapse v0.28.1 (2018-05-01)
 =======================================
 
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 89bd47c3f7..fbba94e633 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -16,6 +16,7 @@
 
 from itertools import chain
 import logging
+import re
 
 logger = logging.getLogger(__name__)
 
@@ -56,8 +57,7 @@ class BaseMetric(object):
         return not len(self.labels)
 
     def _render_labelvalue(self, value):
-        # TODO: escape backslashes, quotes and newlines
-        return '"%s"' % (value)
+        return '"%s"' % (_escape_label_value(value),)
 
     def _render_key(self, values):
         if self.is_scalar():
@@ -299,3 +299,29 @@ class MemoryUsageMetric(object):
             "process_psutil_rss:total %d" % sum_rss,
             "process_psutil_rss:count %d" % len_rss,
         ]
+
+
+def _escape_character(m):
+    """Replaces a single character with its escape sequence.
+
+    Args:
+        m (re.MatchObject): A match object whose first group is the single
+            character to replace
+
+    Returns:
+        str
+    """
+    c = m.group(1)
+    if c == "\\":
+        return "\\\\"
+    elif c == "\"":
+        return "\\\""
+    elif c == "\n":
+        return "\\n"
+    return c
+
+
+def _escape_label_value(value):
+    """Takes a label value and escapes quotes, newlines and backslashes
+    """
+    return re.sub(r"([\n\"\\])", _escape_character, value)
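Not part of the diff, but to make the escaping concrete: a condensed sketch of the two helpers added above, together with the exposition line they would produce for an awkward label value (the metric name cache:hits and the label value are invented for this example).

import re


def _escape_character(m):
    # Map a single matched character to its escape sequence.
    c = m.group(1)
    return {"\\": "\\\\", "\"": "\\\"", "\n": "\\n"}.get(c, c)


def _escape_label_value(value):
    # Escape quotes, newlines and backslashes in a label value.
    return re.sub(r"([\n\"\\])", _escape_character, value)


# Before this change a value like 'weird"name\n' produced a corrupt line;
# now it renders as a single, parseable Prometheus exposition line:
print('cache:hits{name="%s"} 1' % _escape_label_value('weird"name\n'))
# -> cache:hits{name="weird\"name\n"} 1

The dict lookup is just a compact stand-in for the if/elif chain in the diff; the behaviour for the three escaped characters is the same.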
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..197335d7aa 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
     """A base Synapse REST Servlet for the client version 1 API.
     """
 
+    # This subclass was presumably created to allow the auth for the v1
+    # protocol version to be different, however this behaviour was removed.
+    # it may no longer be necessary
+
     def __init__(self, hs):
         """
         Args:
@@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet):
         """
         self.hs = hs
         self.builder_factory = hs.get_event_builder_factory()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 1819a560cb..0206e664c1 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet):
         super(RestServlet, self).__init__()
         self.hs = hs
         self.notifier = hs.get_notifier()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
         self.pusher_pool = self.hs.get_pusherpool()
 
     @defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index cd0c1a51be..ebdea6b0c4 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -105,7 +105,6 @@ class HomeServer(object):
         'federation_client',
         'federation_server',
         'handlers',
-        'v1auth',
         'auth',
         'state_handler',
         'state_resolution_handler',
@@ -225,15 +224,6 @@ class HomeServer(object):
     def build_simple_http_client(self):
         return SimpleHttpClient(self)
 
-    def build_v1auth(self):
-        orf = Auth(self)
-        # Matrix spec makes no reference to what HTTP status code is returned,
-        # but the V1 API uses 403 where it means 401, and the webclient
-        # relies on this behaviour, so V1 gets its own copy of the auth
-        # with backwards compat behaviour.
-        orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
-        return orf
-
     def build_state_handler(self):
         return StateHandler(self)
 
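A rough sketch (not code from the diff) of what the deleted build_v1auth was doing, and of what clients see now that every servlet shares the default Auth. The Auth class below is a stand-in; only TOKEN_NOT_FOUND_HTTP_STATUS and the 403 override come from the removed code above.

class Auth(object):
    # Default status for an unrecognised access token: what r0 endpoints,
    # and after this change the old v1 endpoints too, report.
    TOKEN_NOT_FOUND_HTTP_STATUS = 401


def build_v1auth_old_behaviour():
    # What HomeServer.build_v1auth() used to do: the same Auth, but
    # reporting 403 rather than 401 for backwards compatibility with the
    # old webclient.
    auth = Auth()
    auth.TOKEN_NOT_FOUND_HTTP_STATUS = 403
    return auth


# With build_v1auth gone, hs.get_auth() is used everywhere, so an invalid
# token on a v1 endpoint now yields something like:
#   401 {"errcode": "M_UNKNOWN_TOKEN", "error": "Unrecognised access token."}
assert Auth().TOKEN_NOT_FOUND_HTTP_STATUS == 401
assert build_v1auth_old_behaviour().TOKEN_NOT_FOUND_HTTP_STATUS == 403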
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5fe4a0e56c..05cde96afc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -22,7 +22,6 @@ import logging
 import simplejson as json
 
 from twisted.internet import defer
-
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.util.async import ObservableDeferred
 from synapse.util.frozenutils import frozendict_json_encoder
@@ -425,7 +424,9 @@ class EventsStore(EventsWorkerStore):
                         )
                         current_state = yield self._get_new_state_after_events(
                             room_id,
-                            ev_ctx_rm, new_latest_event_ids,
+                            ev_ctx_rm,
+                            latest_event_ids,
+                            new_latest_event_ids,
                         )
                         if current_state is not None:
                             current_state_for_room[room_id] = current_state
@@ -513,7 +514,8 @@ class EventsStore(EventsWorkerStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+                                    new_latest_event_ids):
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -524,6 +526,9 @@ class EventsStore(EventsWorkerStore):
             events_context (list[(EventBase, EventContext)]):
                 events and contexts which are being added to the room
 
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
             new_latest_event_ids (iterable[str]):
                 the new forward extremities for the room.
 
@@ -534,64 +539,89 @@ class EventsStore(EventsWorkerStore):
         """
 
         if not new_latest_event_ids:
-            defer.returnValue({})
+            return
 
         # map from state_group to ((type, key) -> event_id) state map
-        state_groups = {}
-        missing_event_ids = []
-        was_updated = False
+        state_groups_map = {}
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # I don't think this can happen, but let's double-check
+                raise Exception(
+                    "Context for new extremity event %s has no state "
+                    "group" % (ev.event_id, ),
+                )
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
         for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding,
-            # and then use the current state from that
+            # First search in the list of new events we're adding.
             for ev, ctx in events_context:
                 if event_id == ev.event_id:
-                    if ctx.current_state_ids is None:
-                        raise Exception("Unknown current state")
-
-                    if ctx.state_group is None:
-                        # I don't think this can happen, but let's double-check
-                        raise Exception(
-                            "Context for new extremity event %s has no state "
-                            "group" % (event_id, ),
-                        )
-
-                    # If we've already seen the state group don't bother adding
-                    # it to the state sets again
-                    if ctx.state_group not in state_groups:
-                        state_groups[ctx.state_group] = ctx.current_state_ids
-                        if ctx.delta_ids or hasattr(ev, "state_key"):
-                            was_updated = True
+                    event_id_to_state_group[event_id] = ctx.state_group
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
                 # the state from the database
-                was_updated = True
-                missing_event_ids.append(event_id)
-
-        if not was_updated:
-            return
+                missing_event_ids.add(event_id)
 
         if missing_event_ids:
-            # Now pull out the state for any missing events from DB
+            # Now pull out the state groups for any missing events from DB
            event_to_groups = yield self._get_state_group_for_events(
                 missing_event_ids,
             )
+            event_id_to_state_group.update(event_to_groups)
+
+        # State groups of old_latest_event_ids
+        old_state_groups = set(
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        )
+
+        # State groups of new_latest_event_ids
+        new_state_groups = set(
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        )
 
-            groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+        # If they old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
+            return
 
-            if groups:
-                group_to_state = yield self._get_state_for_groups(groups)
-                state_groups.update(group_to_state)
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = yield self._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
 
-        if len(state_groups) == 1:
+        if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            defer.returnValue(state_groups.values()[0])
+            defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
 
         def get_events(ev_ids):
             return self.get_events(
                 ev_ids, get_prev_content=False, check_redacted=False,
             )
+
+        state_groups = {
+            sg: state_groups_map[sg] for sg in new_state_groups
+        }
+
         events_map = {ev.event_id: ev for ev, _ in events_context}
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
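The control flow of the rewritten _get_new_state_after_events is easier to see without Twisted. Below is a minimal, self-contained sketch of the same decision logic; the plain tuples, dicts and the resolve_states callable are stand-ins for EventContext, _get_state_group_for_events, _get_state_for_groups and the state resolution handler, none of which appear here verbatim.

def new_state_after_events(events_context, old_latest_event_ids,
                           new_latest_event_ids, db_event_to_group,
                           db_group_to_state, resolve_states):
    """Sketch of the decision logic in _get_new_state_after_events.

    events_context: list of (event_id, state_group, state_ids) tuples for
        the events being persisted (stand-in for EventBase/EventContext).
    db_event_to_group: dict event_id -> state_group, standing in for
        _get_state_group_for_events.
    db_group_to_state: dict state_group -> state dict, standing in for
        _get_state_for_groups.
    resolve_states: callable over {state_group: state dict}, standing in
        for the state resolution handler.

    Returns None when the current state does not need recalculating.
    """
    if not new_latest_event_ids:
        return None

    # State groups we already know about, from the events being persisted.
    state_groups_map = {sg: state for _, sg, state in events_context}

    # Map every old and new forward extremity to its state group.
    event_id_to_state_group = dict(db_event_to_group)
    for ev_id, sg, _ in events_context:
        event_id_to_state_group[ev_id] = sg

    old_groups = {event_id_to_state_group[e] for e in old_latest_event_ids}
    new_groups = {event_id_to_state_group[e] for e in new_latest_event_ids}

    # If the extremities' state groups are unchanged, the current state
    # cannot have changed either, so there is nothing to do.
    if old_groups == new_groups:
        return None

    # Fetch state for any new group not seen among the persisted events.
    for sg in new_groups - set(state_groups_map):
        state_groups_map[sg] = db_group_to_state[sg]

    if len(new_groups) == 1:
        # A single state group: its state is the current state.
        return state_groups_map[next(iter(new_groups))]

    # Several state groups: hand them to state resolution.
    return resolve_states({sg: state_groups_map[sg] for sg in new_groups})

The key change is that the old extremities' state groups are compared with the new ones, so the expensive resolution step is skipped whenever the change of extremities cannot have affected the current state - a more precise check than the old was_updated flag.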
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index 39bde6e3f8..069c0be762 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -16,7 +16,8 @@
 from tests import unittest
 
 from synapse.metrics.metric import (
-    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
+    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
+    _escape_label_value,
 )
 
 
@@ -171,3 +172,21 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:size{name="cache_name"} 1',
             'cache:evicted_size{name="cache_name"} 2',
         ])
+
+
+class LabelValueEscapeTestCase(unittest.TestCase):
+    def test_simple(self):
+        string = "safjhsdlifhyskljfksdfh"
+        self.assertEqual(string, _escape_label_value(string))
+
+    def test_escape(self):
+        self.assertEqual(
+            "abc\\\"def\\nghi\\\\",
+            _escape_label_value("abc\"def\nghi\\"),
+        )
+
+    def test_sequence_of_escapes(self):
+        self.assertEqual(
+            "abc\\\"def\\nghi\\\\\\n",
+            _escape_label_value("abc\"def\nghi\\\n"),
+        )
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index a8d09600bd..f5a7258e68 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -148,11 +148,16 @@ class EventStreamPermissionsTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_basic_permissions(self):
-        # invalid token, expect 403
+        # invalid token, expect 401
+        # note: this is in violation of the original v1 spec, which expected
+        # 403. However, since the v1 spec no longer exists and the v1
+        # implementation is now part of the r0 implementation, the newer
+        # behaviour is used instead to be consistent with the r0 spec.
+        # see issue #2602
         (code, response) = yield self.mock_resource.trigger_get(
             "/events?access_token=%s" % ("invalid" + self.token, )
         )
-        self.assertEquals(403, code, msg=str(response))
+        self.assertEquals(401, code, msg=str(response))
 
         # valid token, expect content
         (code, response) = yield self.mock_resource.trigger_get(
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index deac7f100c..dc94b8bd19 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -52,7 +52,7 @@ class ProfileTestCase(unittest.TestCase):
         def _get_user_by_req(request=None, allow_guest=False):
             return synapse.types.create_requester(myid)
 
-        hs.get_v1auth().get_user_by_req = _get_user_by_req
+        hs.get_auth().get_user_by_req = _get_user_by_req
 
         profile.register_servlets(hs, self.mock_resource)
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d763400eaf..61d737725b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -60,7 +60,7 @@ class RoomPermissionsTestCase(RestTestCase):
             "token_id": 1,
             "is_guest": False,
         }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -70,7 +70,7 @@
 
         synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
 
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
 
         # create some rooms under the name rmcreator_id
         self.uncreated_rmid = "!aa:test"
@@ -425,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase):
             "token_id": 1,
             "is_guest": False,
         }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -507,7 +507,7 @@ class RoomsCreateTestCase(RestTestCase):
             "token_id": 1,
             "is_guest": False,
         }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
        def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -597,7 +597,7 @@ class RoomTopicTestCase(RestTestCase):
             "is_guest": False,
         }
 
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -711,7 +711,7 @@ class RoomMemberStateTestCase(RestTestCase):
             "token_id": 1,
             "is_guest": False,
         }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
             "token_id": 1,
             "is_guest": False,
         }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -945,7 +945,7 @@ class RoomInitialSyncTestCase(RestTestCase):
             "token_id": 1,
             "is_guest": False,
         }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -1017,7 +1017,7 @@ class RoomMessageListTestCase(RestTestCase):
             "token_id": 1,
             "is_guest": False,
         }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 2ec4ecab5b..fe161ee5cb 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -68,7 +68,7 @@ class RoomTypingTestCase(RestTestCase):
             "is_guest": False,
         }
 
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
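Since the CHANGES entry warns that some clients may depend on the old 403, here is a hedged example of client-side handling that tolerates both status codes during the transition. The homeserver URL and token are placeholders, and the /account/whoami endpoint is just one convenient way to probe token validity.

import requests

HOMESERVER = "https://matrix.example.com"   # placeholder
ACCESS_TOKEN = "syt_example_token"          # placeholder


def whoami(token):
    # A cheap request whose only purpose is to check the access token.
    resp = requests.get(
        HOMESERVER + "/_matrix/client/r0/account/whoami",
        params={"access_token": token},
    )
    if resp.status_code in (401, 403):
        # Synapse now returns 401 for an unrecognised token, as the spec
        # requires; older v1 endpoints used to return 403. Treat both as
        # "token is no longer valid".
        if resp.json().get("errcode") == "M_UNKNOWN_TOKEN":
            raise RuntimeError("access token rejected - log in again")
    resp.raise_for_status()
    return resp.json()["user_id"]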