Diffstat (limited to 'synapse')
-rw-r--r--  synapse/metrics/metric.py        |  30
-rw-r--r--  synapse/rest/client/v1/base.py   |   6
-rw-r--r--  synapse/rest/client/v1/pusher.py |   2
-rw-r--r--  synapse/server.py                |  10
-rw-r--r--  synapse/storage/events.py        | 104
5 files changed, 101 insertions, 51 deletions
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 89bd47c3f7..fbba94e633 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -16,6 +16,7 @@
 from itertools import chain
 import logging
+import re
 
 logger = logging.getLogger(__name__)
 
 
@@ -56,8 +57,7 @@ class BaseMetric(object):
         return not len(self.labels)
 
     def _render_labelvalue(self, value):
-        # TODO: escape backslashes, quotes and newlines
-        return '"%s"' % (value)
+        return '"%s"' % (_escape_label_value(value),)
 
     def _render_key(self, values):
         if self.is_scalar():
@@ -299,3 +299,29 @@ class MemoryUsageMetric(object):
             "process_psutil_rss:total %d" % sum_rss,
             "process_psutil_rss:count %d" % len_rss,
         ]
+
+
+def _escape_character(m):
+    """Replaces a single character with its escape sequence.
+
+    Args:
+        m (re.MatchObject): A match object whose first group is the single
+            character to replace
+
+    Returns:
+        str
+    """
+    c = m.group(1)
+    if c == "\\":
+        return "\\\\"
+    elif c == "\"":
+        return "\\\""
+    elif c == "\n":
+        return "\\n"
+    return c
+
+
+def _escape_label_value(value):
+    """Takes a label value and escapes quotes, newlines and backslashes
+    """
+    return re.sub(r"([\n\"\\])", _escape_character, value)
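For illustration only, not part of the commit: a minimal, self-contained sketch of how the new label escaping behaves, copying the two helpers from the hunk above so it can be run standalone. The sample label value is made up.

import re

def _escape_character(m):
    # Map one matched character to its escaped form.
    c = m.group(1)
    if c == "\\":
        return "\\\\"
    elif c == "\"":
        return "\\\""
    elif c == "\n":
        return "\\n"
    return c

def _escape_label_value(value):
    # Escape backslashes, quotes and newlines in a Prometheus label value.
    return re.sub(r"([\n\"\\])", _escape_character, value)

# A hypothetical label value containing all three special characters:
value = 'say "hi"\\\n'
print('"%s"' % (_escape_label_value(value),))
# prints: "say \"hi\"\\\n"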
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..197335d7aa 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
     """A base Synapse REST Servlet for the client version 1 API.
     """
 
+    # This subclass was presumably created to allow the auth for the v1
+    # protocol version to be different, however this behaviour was removed.
+    # it may no longer be necessary
+
     def __init__(self, hs):
         """
         Args:
@@ -59,5 +63,5 @@
         """
         self.hs = hs
         self.builder_factory = hs.get_event_builder_factory()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 1819a560cb..0206e664c1 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet):
         super(RestServlet, self).__init__()
         self.hs = hs
         self.notifier = hs.get_notifier()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
         self.pusher_pool = self.hs.get_pusherpool()
 
     @defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index cd0c1a51be..ebdea6b0c4 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -105,7 +105,6 @@ class HomeServer(object):
        'federation_client',
        'federation_server',
        'handlers',
-       'v1auth',
        'auth',
        'state_handler',
        'state_resolution_handler',
@@ -225,15 +224,6 @@ class HomeServer(object):
     def build_simple_http_client(self):
         return SimpleHttpClient(self)
 
-    def build_v1auth(self):
-        orf = Auth(self)
-        # Matrix spec makes no reference to what HTTP status code is returned,
-        # but the V1 API uses 403 where it means 401, and the webclient
-        # relies on this behaviour, so V1 gets its own copy of the auth
-        # with backwards compat behaviour.
-        orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
-        return orf
-
     def build_state_handler(self):
         return StateHandler(self)
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5fe4a0e56c..05cde96afc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -22,7 +22,6 @@ import logging
 
 import simplejson as json
 from twisted.internet import defer
-
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.util.async import ObservableDeferred
 from synapse.util.frozenutils import frozendict_json_encoder
@@ -425,7 +424,9 @@
                     )
                     current_state = yield self._get_new_state_after_events(
                         room_id,
-                        ev_ctx_rm, new_latest_event_ids,
+                        ev_ctx_rm,
+                        latest_event_ids,
+                        new_latest_event_ids,
                     )
                     if current_state is not None:
                         current_state_for_room[room_id] = current_state
@@ -513,7 +514,8 @@ class EventsStore(EventsWorkerStore):
        defer.returnValue(new_latest_event_ids)
 
    @defer.inlineCallbacks
-    def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+                                    new_latest_event_ids):
        """Calculate the current state dict after adding some new events to
        a room
 
@@ -524,6 +526,9 @@
            events_context (list[(EventBase, EventContext)]):
                events and contexts which are being added to the room
 
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
            new_latest_event_ids (iterable[str]):
                the new forward extremities for the room.
 
@@ -534,64 +539,89 @@
        """
 
        if not new_latest_event_ids:
-            defer.returnValue({})
+            return
 
        # map from state_group to ((type, key) -> event_id) state map
-        state_groups = {}
-        missing_event_ids = []
-        was_updated = False
+        state_groups_map = {}
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # I don't think this can happen, but let's double-check
+                raise Exception(
+                    "Context for new extremity event %s has no state "
+                    "group" % (ev.event_id, ),
+                )
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
        for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding,
-            # and then use the current state from that
+            # First search in the list of new events we're adding.
            for ev, ctx in events_context:
                if event_id == ev.event_id:
-                    if ctx.current_state_ids is None:
-                        raise Exception("Unknown current state")
-
-                    if ctx.state_group is None:
-                        # I don't think this can happen, but let's double-check
-                        raise Exception(
-                            "Context for new extremity event %s has no state "
-                            "group" % (event_id, ),
-                        )
-
-                    # If we've already seen the state group don't bother adding
-                    # it to the state sets again
-                    if ctx.state_group not in state_groups:
-                        state_groups[ctx.state_group] = ctx.current_state_ids
-                        if ctx.delta_ids or hasattr(ev, "state_key"):
-                            was_updated = True
+                    event_id_to_state_group[event_id] = ctx.state_group
                    break
            else:
                # If we couldn't find it, then we'll need to pull
                # the state from the database
-                was_updated = True
-                missing_event_ids.append(event_id)
-
-        if not was_updated:
-            return
+                missing_event_ids.add(event_id)
 
        if missing_event_ids:
-            # Now pull out the state for any missing events from DB
+            # Now pull out the state groups for any missing events from DB
            event_to_groups = yield self._get_state_group_for_events(
                missing_event_ids,
            )
+            event_id_to_state_group.update(event_to_groups)
+
+        # State groups of old_latest_event_ids
+        old_state_groups = set(
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        )
+
+        # State groups of new_latest_event_ids
+        new_state_groups = set(
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        )
 
-            groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+        # If they old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
+            return
 
-            if groups:
-                group_to_state = yield self._get_state_for_groups(groups)
-                state_groups.update(group_to_state)
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = yield self._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
 
-        if len(state_groups) == 1:
+        if len(new_state_groups) == 1:
            # If there is only one state group, then we know what the current
            # state is.
-            defer.returnValue(state_groups.values()[0])
+            defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
 
        def get_events(ev_ids):
            return self.get_events(
                ev_ids, get_prev_content=False, check_redacted=False,
            )
+
+        state_groups = {
+            sg: state_groups_map[sg] for sg in new_state_groups
+        }
+
        events_map = {ev.event_id: ev for ev, _ in events_context}
        logger.debug("calling resolve_state_groups from preserve_events")
        res = yield self._state_resolution_handler.resolve_state_groups(
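The events.py change, in short: instead of tracking a was_updated flag, _get_new_state_after_events now maps both the old and the new forward extremities to their state groups (from the in-flight event contexts or the database) and only recalculates current state when the two sets of groups differ. Below is a rough standalone sketch of that comparison; the helper name, event IDs and group numbers are made up for illustration, and the real code goes on to resolve the new groups' state via the state resolution handler.

def state_recalculation_needed(event_id_to_state_group,
                               old_latest_event_ids, new_latest_event_ids):
    # State groups of the old and new forward extremities.
    old_state_groups = {event_id_to_state_group[e] for e in old_latest_event_ids}
    new_state_groups = {event_id_to_state_group[e] for e in new_latest_event_ids}
    # If the groups are identical, the room's current state is unchanged.
    return old_state_groups != new_state_groups

# Hypothetical data: the new extremity "$c" shares state group 2 with the
# dropped extremity "$b", so the current state does not need recalculating.
groups = {"$a": 1, "$b": 2, "$c": 2}
print(state_recalculation_needed(groups, {"$a", "$b"}, {"$a", "$c"}))  # False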