diff --git a/CHANGES.rst b/CHANGES.rst
index 317846d2a2..9d40b2ac1e 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,18 @@
+Changes in synapse <unreleased>
+===============================
+
+Potentially breaking change:
+
+* Make Client-Server API return 401 for invalid token (PR #3161).
+
+ This changes the Client-server spec to return a 401 error code instead of 403
+ when the access token is unrecognised. This is the behaviour required by the
+ specification, but some clients may be relying on the old, incorrect
+ behaviour.
+
+ Thanks to @NotAFile for fixing this.
+
+
Changes in synapse v0.28.1 (2018-05-01)
=======================================
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 89bd47c3f7..fbba94e633 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -16,6 +16,7 @@
from itertools import chain
import logging
+import re
logger = logging.getLogger(__name__)
@@ -56,8 +57,7 @@ class BaseMetric(object):
return not len(self.labels)
def _render_labelvalue(self, value):
- # TODO: escape backslashes, quotes and newlines
- return '"%s"' % (value)
+ return '"%s"' % (_escape_label_value(value),)
def _render_key(self, values):
if self.is_scalar():
@@ -299,3 +299,29 @@ class MemoryUsageMetric(object):
"process_psutil_rss:total %d" % sum_rss,
"process_psutil_rss:count %d" % len_rss,
]
+
+
+def _escape_character(m):
+ """Replaces a single character with its escape sequence.
+
+ Args:
+ m (re.MatchObject): A match object whose first group is the single
+ character to replace
+
+ Returns:
+ str
+ """
+ c = m.group(1)
+ if c == "\\":
+ return "\\\\"
+ elif c == "\"":
+ return "\\\""
+ elif c == "\n":
+ return "\\n"
+ return c
+
+
+def _escape_label_value(value):
+ """Takes a label value and escapes quotes, newlines and backslashes
+ """
+ return re.sub(r"([\n\"\\])", _escape_character, value)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..197335d7aa 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
+ # This subclass was presumably created to allow the auth for the v1
+ # protocol version to be different, however this behaviour was removed.
+ # it may no longer be necessary
+
def __init__(self, hs):
"""
Args:
@@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 1819a560cb..0206e664c1 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index cd0c1a51be..ebdea6b0c4 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -105,7 +105,6 @@ class HomeServer(object):
'federation_client',
'federation_server',
'handlers',
- 'v1auth',
'auth',
'state_handler',
'state_resolution_handler',
@@ -225,15 +224,6 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
- def build_v1auth(self):
- orf = Auth(self)
- # Matrix spec makes no reference to what HTTP status code is returned,
- # but the V1 API uses 403 where it means 401, and the webclient
- # relies on this behaviour, so V1 gets its own copy of the auth
- # with backwards compat behaviour.
- orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
- return orf
-
def build_state_handler(self):
return StateHandler(self)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5fe4a0e56c..05cde96afc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -22,7 +22,6 @@ import logging
import simplejson as json
from twisted.internet import defer
-
from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.async import ObservableDeferred
from synapse.util.frozenutils import frozendict_json_encoder
@@ -425,7 +424,9 @@ class EventsStore(EventsWorkerStore):
)
current_state = yield self._get_new_state_after_events(
room_id,
- ev_ctx_rm, new_latest_event_ids,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
)
if current_state is not None:
current_state_for_room[room_id] = current_state
@@ -513,7 +514,8 @@ class EventsStore(EventsWorkerStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+ new_latest_event_ids):
"""Calculate the current state dict after adding some new events to
a room
@@ -524,6 +526,9 @@ class EventsStore(EventsWorkerStore):
events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
new_latest_event_ids (iterable[str]):
the new forward extremities for the room.
@@ -534,64 +539,89 @@ class EventsStore(EventsWorkerStore):
"""
if not new_latest_event_ids:
- defer.returnValue({})
+ return
# map from state_group to ((type, key) -> event_id) state map
- state_groups = {}
- missing_event_ids = []
- was_updated = False
+ state_groups_map = {}
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # I don't think this can happen, but let's double-check
+ raise Exception(
+ "Context for new extremity event %s has no state "
+ "group" % (ev.event_id, ),
+ )
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding,
- # and then use the current state from that
+ # First search in the list of new events we're adding.
for ev, ctx in events_context:
if event_id == ev.event_id:
- if ctx.current_state_ids is None:
- raise Exception("Unknown current state")
-
- if ctx.state_group is None:
- # I don't think this can happen, but let's double-check
- raise Exception(
- "Context for new extremity event %s has no state "
- "group" % (event_id, ),
- )
-
- # If we've already seen the state group don't bother adding
- # it to the state sets again
- if ctx.state_group not in state_groups:
- state_groups[ctx.state_group] = ctx.current_state_ids
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
+ event_id_to_state_group[event_id] = ctx.state_group
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
- was_updated = True
- missing_event_ids.append(event_id)
-
- if not was_updated:
- return
+ missing_event_ids.add(event_id)
if missing_event_ids:
- # Now pull out the state for any missing events from DB
+ # Now pull out the state groups for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
+
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
- groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+ # If they old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return
- if groups:
- group_to_state = yield self._get_state_for_groups(groups)
- state_groups.update(group_to_state)
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
- if len(state_groups) == 1:
+ if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- defer.returnValue(state_groups.values()[0])
+ defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
def get_events(ev_ids):
return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
+
+ state_groups = {
+ sg: state_groups_map[sg] for sg in new_state_groups
+ }
+
events_map = {ev.event_id: ev for ev, _ in events_context}
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index 39bde6e3f8..069c0be762 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -16,7 +16,8 @@
from tests import unittest
from synapse.metrics.metric import (
- CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
+ CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
+ _escape_label_value,
)
@@ -171,3 +172,21 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:size{name="cache_name"} 1',
'cache:evicted_size{name="cache_name"} 2',
])
+
+
+class LabelValueEscapeTestCase(unittest.TestCase):
+ def test_simple(self):
+ string = "safjhsdlifhyskljfksdfh"
+ self.assertEqual(string, _escape_label_value(string))
+
+ def test_escape(self):
+ self.assertEqual(
+ "abc\\\"def\\nghi\\\\",
+ _escape_label_value("abc\"def\nghi\\"),
+ )
+
+ def test_sequence_of_escapes(self):
+ self.assertEqual(
+ "abc\\\"def\\nghi\\\\\\n",
+ _escape_label_value("abc\"def\nghi\\\n"),
+ )
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index a8d09600bd..f5a7258e68 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -148,11 +148,16 @@ class EventStreamPermissionsTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_basic_permissions(self):
- # invalid token, expect 403
+ # invalid token, expect 401
+ # note: this is in violation of the original v1 spec, which expected
+ # 403. However, since the v1 spec no longer exists and the v1
+ # implementation is now part of the r0 implementation, the newer
+ # behaviour is used instead to be consistent with the r0 spec.
+ # see issue #2602
(code, response) = yield self.mock_resource.trigger_get(
"/events?access_token=%s" % ("invalid" + self.token, )
)
- self.assertEquals(403, code, msg=str(response))
+ self.assertEquals(401, code, msg=str(response))
# valid token, expect content
(code, response) = yield self.mock_resource.trigger_get(
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index deac7f100c..dc94b8bd19 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -52,7 +52,7 @@ class ProfileTestCase(unittest.TestCase):
def _get_user_by_req(request=None, allow_guest=False):
return synapse.types.create_requester(myid)
- hs.get_v1auth().get_user_by_req = _get_user_by_req
+ hs.get_auth().get_user_by_req = _get_user_by_req
profile.register_servlets(hs, self.mock_resource)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d763400eaf..61d737725b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -60,7 +60,7 @@ class RoomPermissionsTestCase(RestTestCase):
"token_id": 1,
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -70,7 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
# create some rooms under the name rmcreator_id
self.uncreated_rmid = "!aa:test"
@@ -425,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase):
"token_id": 1,
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -507,7 +507,7 @@ class RoomsCreateTestCase(RestTestCase):
"token_id": 1,
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -597,7 +597,7 @@ class RoomTopicTestCase(RestTestCase):
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -711,7 +711,7 @@ class RoomMemberStateTestCase(RestTestCase):
"token_id": 1,
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
"token_id": 1,
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -945,7 +945,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"token_id": 1,
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -1017,7 +1017,7 @@ class RoomMessageListTestCase(RestTestCase):
"token_id": 1,
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 2ec4ecab5b..fe161ee5cb 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -68,7 +68,7 @@ class RoomTypingTestCase(RestTestCase):
"is_guest": False,
}
- hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+ hs.get_auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
|