diff options
Diffstat (limited to 'synapse')
24 files changed, 1227 insertions, 820 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f93e45a744..03a215ab1b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -16,18 +16,14 @@ import logging import pymacaroons -from canonicaljson import encode_canonical_json -from signedjson.key import decode_verify_key_bytes -from signedjson.sign import verify_signed_json, SignatureVerifyException from twisted.internet import defer -from unpaddedbase64 import decode_base64 import synapse.types +from synapse import event_auth from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import UserID, get_domain_from_id +from synapse.api.errors import AuthError, Codes +from synapse.types import UserID from synapse.util.logcontext import preserve_context_over_fn -from synapse.util.logutils import log_function from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -78,147 +74,7 @@ class Auth(object): True if the auth checks pass. """ with Measure(self.clock, "auth.check"): - self.check_size_limits(event) - - if not hasattr(event, "room_id"): - raise AuthError(500, "Event has no room_id: %s" % event) - - if do_sig_check: - sender_domain = get_domain_from_id(event.sender) - event_id_domain = get_domain_from_id(event.event_id) - - is_invite_via_3pid = ( - event.type == EventTypes.Member - and event.membership == Membership.INVITE - and "third_party_invite" in event.content - ) - - # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): - # We allow invites via 3pid to have a sender from a different - # HS, as the sender must match the sender of the original - # 3pid invite. This is checked further down with the - # other dedicated membership checks. - if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - # Check the event_id's domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return True - - if event.type == EventTypes.Create: - room_id_domain = get_domain_from_id(event.room_id) - if room_id_domain != sender_domain: - raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" - ) - # FIXME - return True - - creation_event = auth_events.get((EventTypes.Create, ""), None) - - if not creation_event: - raise SynapseError( - 403, - "Room %r does not exist" % (event.room_id,) - ) - - creating_domain = get_domain_from_id(event.room_id) - originating_domain = get_domain_from_id(event.sender) - if creating_domain != originating_domain: - if not self.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # FIXME: Temp hack - if event.type == EventTypes.Aliases: - if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) - if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) - sender_domain = get_domain_from_id(event.sender) - if event.state_key != sender_domain: - raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" - ) - return True - - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) - - if event.type == EventTypes.Member: - allowed = self.is_membership_change_allowed( - event, auth_events - ) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! %s", event) - return allowed - - self.check_event_sender_in_room(event, auth_events) - - # Special case to allow m.room.third_party_invite events wherever - # a user is allowed to issue invites. Fixes - # https://github.com/vector-im/vector-web/issues/1208 hopefully - if event.type == EventTypes.ThirdPartyInvite: - user_level = self._get_user_power_level(event.user_id, auth_events) - invite_level = self._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) - ) - else: - return True - - self._can_send_event(event, auth_events) - - if event.type == EventTypes.PowerLevels: - self._check_power_levels(event, auth_events) - - if event.type == EventTypes.Redaction: - self.check_redaction(event, auth_events) - - logger.debug("Allowing! %s", event) - - def check_size_limits(self, event): - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - - if len(event.user_id) > 255: - too_big("user_id") - if len(event.room_id) > 255: - too_big("room_id") - if event.is_state() and len(event.state_key) > 255: - too_big("state_key") - if len(event.type) > 255: - too_big("type") - if len(event.event_id) > 255: - too_big("event_id") - if len(encode_canonical_json(event.get_pdu_json())) > 65536: - too_big("event") + event_auth.check(event, auth_events, do_sig_check=do_sig_check) @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): @@ -290,7 +146,7 @@ class Auth(object): with Measure(self.clock, "check_host_in_room"): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from check_host_in_room") + logger.debug("calling resolve_state_groups from check_host_in_room") entry = yield self.state.resolve_state_groups( room_id, latest_event_ids ) @@ -300,16 +156,6 @@ class Auth(object): ) defer.returnValue(ret) - def check_event_sender_in_room(self, event, auth_events): - key = (EventTypes.Member, event.user_id, ) - member_event = auth_events.get(key) - - return self._check_joined_room( - member_event, - event.user_id, - event.room_id - ) - def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: raise AuthError(403, "User %s not in room %s (%s)" % ( @@ -321,267 +167,8 @@ class Auth(object): return creation_event.content.get("m.federate", True) is True - @log_function - def is_membership_change_allowed(self, event, auth_events): - membership = event.content["membership"] - - # Check if this is the room creator joining: - if len(event.prev_events) == 1 and Membership.JOIN == membership: - # Get room creation event: - key = (EventTypes.Create, "", ) - create = auth_events.get(key) - if create and event.prev_events[0][0] == create.event_id: - if create.content["creator"] == event.state_key: - return True - - target_user_id = event.state_key - - creating_domain = get_domain_from_id(event.room_id) - target_domain = get_domain_from_id(target_user_id) - if creating_domain != target_domain: - if not self.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # get info about the caller - key = (EventTypes.Member, event.user_id, ) - caller = auth_events.get(key) - - caller_in_room = caller and caller.membership == Membership.JOIN - caller_invited = caller and caller.membership == Membership.INVITE - - # get info about the target - key = (EventTypes.Member, target_user_id, ) - target = auth_events.get(key) - - target_in_room = target and target.membership == Membership.JOIN - target_banned = target and target.membership == Membership.BAN - - key = (EventTypes.JoinRules, "", ) - join_rule_event = auth_events.get(key) - if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) - else: - join_rule = JoinRules.INVITE - - user_level = self._get_user_power_level(event.user_id, auth_events) - target_level = self._get_user_power_level( - target_user_id, auth_events - ) - - # FIXME (erikj): What should we do here as the default? - ban_level = self._get_named_level(auth_events, "ban", 50) - - logger.debug( - "is_membership_change_allowed: %s", - { - "caller_in_room": caller_in_room, - "caller_invited": caller_invited, - "target_banned": target_banned, - "target_in_room": target_in_room, - "membership": membership, - "join_rule": join_rule, - "target_user_id": target_user_id, - "event.user_id": event.user_id, - } - ) - - if Membership.INVITE == membership and "third_party_invite" in event.content: - if not self._verify_third_party_invite(event, auth_events): - raise AuthError(403, "You are not invited to this room.") - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - return True - - if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): - return True - - if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) - - if Membership.INVITE == membership: - # TODO (erikj): We should probably handle this more intelligently - # PRIVATE join rules. - - # Invites are valid iff caller is in the room and target isn't. - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - elif target_in_room: # the target is already in the room. - raise AuthError(403, "%s is already in the room." % - target_user_id) - else: - invite_level = self._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, "You cannot invite user %s." % target_user_id - ) - elif Membership.JOIN == membership: - # Joins are valid iff caller == target and they were: - # invited: They are accepting the invitation - # joined: It's a NOOP - if event.user_id != target_user_id: - raise AuthError(403, "Cannot force another user to join.") - elif target_banned: - raise AuthError(403, "You are banned from this room") - elif join_rule == JoinRules.PUBLIC: - pass - elif join_rule == JoinRules.INVITE: - if not caller_in_room and not caller_invited: - raise AuthError(403, "You are not invited to this room.") - else: - # TODO (erikj): may_join list - # TODO (erikj): private rooms - raise AuthError(403, "You are not allowed to join this room") - elif Membership.LEAVE == membership: - # TODO (erikj): Implement kicks. - if target_banned and user_level < ban_level: - raise AuthError( - 403, "You cannot unban user &s." % (target_user_id,) - ) - elif target_user_id != event.user_id: - kick_level = self._get_named_level(auth_events, "kick", 50) - - if user_level < kick_level or user_level <= target_level: - raise AuthError( - 403, "You cannot kick user %s." % target_user_id - ) - elif Membership.BAN == membership: - if user_level < ban_level or user_level <= target_level: - raise AuthError(403, "You don't have permission to ban") - else: - raise AuthError(500, "Unknown membership %s" % membership) - - return True - - def _verify_third_party_invite(self, event, auth_events): - """ - Validates that the invite event is authorized by a previous third-party invite. - - Checks that the public key, and keyserver, match those in the third party invite, - and that the invite event has a signature issued using that public key. - - Args: - event: The m.room.member join event being validated. - auth_events: All relevant previous context events which may be used - for authorization decisions. - - Return: - True if the event fulfills the expectations of a previous third party - invite event. - """ - if "third_party_invite" not in event.content: - return False - if "signed" not in event.content["third_party_invite"]: - return False - signed = event.content["third_party_invite"]["signed"] - for key in {"mxid", "token"}: - if key not in signed: - return False - - token = signed["token"] - - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) - if not invite_event: - return False - - if invite_event.sender != event.sender: - return False - - if event.user_id != invite_event.user_id: - return False - - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - - for public_key_object in self.get_public_keys(invite_event): - public_key = public_key_object["public_key"] - try: - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - continue - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. - return True - except (KeyError, SignatureVerifyException,): - continue - return False - def get_public_keys(self, invite_event): - public_keys = [] - if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } - if "key_validity_url" in invite_event.content: - o["key_validity_url"] = invite_event.content["key_validity_url"] - public_keys.append(o) - public_keys.extend(invite_event.content.get("public_keys", [])) - return public_keys - - def _get_power_level_event(self, auth_events): - key = (EventTypes.PowerLevels, "", ) - return auth_events.get(key) - - def _get_user_power_level(self, user_id, auth_events): - power_level_event = self._get_power_level_event(auth_events) - - if power_level_event: - level = power_level_event.content.get("users", {}).get(user_id) - if not level: - level = power_level_event.content.get("users_default", 0) - - if level is None: - return 0 - else: - return int(level) - else: - key = (EventTypes.Create, "", ) - create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): - return 100 - else: - return 0 - - def _get_named_level(self, auth_events, name, default): - power_level_event = self._get_power_level_event(auth_events) - - if not power_level_event: - return default - - level = power_level_event.content.get(name, None) - if level is not None: - return int(level) - else: - return default + return event_auth.get_public_keys(invite_event) @defer.inlineCallbacks def get_user_by_req(self, request, allow_guest=False, rights="access"): @@ -974,56 +561,6 @@ class Auth(object): defer.returnValue(auth_ids) - def _get_send_level(self, etype, state_key, auth_events): - key = (EventTypes.PowerLevels, "", ) - send_level_event = auth_events.get(key) - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) - - if send_level: - send_level = int(send_level) - else: - send_level = 0 - - return send_level - - @log_function - def _can_send_event(self, event, auth_events): - send_level = self._get_send_level( - event.type, event.get("state_key", None), auth_events - ) - user_level = self._get_user_power_level(event.user_id, auth_events) - - if user_level < send_level: - raise AuthError( - 403, - "You don't have permission to post that to the room. " + - "user_level (%d) < send_level (%d)" % (user_level, send_level) - ) - - # Check state_key - if hasattr(event, "state_key"): - if event.state_key.startswith("@"): - if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) - - return True - def check_redaction(self, event, auth_events): """Check whether the event sender is allowed to redact the target event. @@ -1037,107 +574,7 @@ class Auth(object): AuthError if the event sender is definitely not allowed to redact the target event. """ - user_level = self._get_user_power_level(event.user_id, auth_events) - - redact_level = self._get_named_level(auth_events, "redact", 50) - - if user_level >= redact_level: - return False - - redacter_domain = get_domain_from_id(event.event_id) - redactee_domain = get_domain_from_id(event.redacts) - if redacter_domain == redactee_domain: - return True - - raise AuthError( - 403, - "You don't have permission to redact events" - ) - - def _check_power_levels(self, event, auth_events): - user_list = event.content.get("users", {}) - # Validate users - for k, v in user_list.items(): - try: - UserID.from_string(k) - except: - raise SynapseError(400, "Not a valid user_id: %s" % (k,)) - - try: - int(v) - except: - raise SynapseError(400, "Not a valid power level: %s" % (v,)) - - key = (event.type, event.state_key, ) - current_state = auth_events.get(key) - - if not current_state: - return - - user_level = self._get_user_power_level(event.user_id, auth_events) - - # Check other levels: - levels_to_check = [ - ("users_default", None), - ("events_default", None), - ("state_default", None), - ("ban", None), - ("redact", None), - ("kick", None), - ("invite", None), - ] - - old_list = current_state.content.get("users") - for user in set(old_list.keys() + user_list.keys()): - levels_to_check.append( - (user, "users") - ) - - old_list = current_state.content.get("events") - new_list = event.content.get("events") - for ev_id in set(old_list.keys() + new_list.keys()): - levels_to_check.append( - (ev_id, "events") - ) - - old_state = current_state.content - new_state = event.content - - for level_to_check, dir in levels_to_check: - old_loc = old_state - new_loc = new_state - if dir: - old_loc = old_loc.get(dir, {}) - new_loc = new_loc.get(dir, {}) - - if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) - else: - old_level = None - - if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) - else: - new_level = None - - if new_level is not None and old_level is not None: - if new_level == old_level: - continue - - if dir == "users" and level_to_check != event.user_id: - if old_level == user_level: - raise AuthError( - 403, - "You don't have permission to remove ops level equal " - "to your own" - ) - - if old_level > user_level or new_level > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) + return event_auth.check_redaction(event, auth_events) @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): @@ -1167,10 +604,10 @@ class Auth(object): if power_level_event: auth_events[(EventTypes.PowerLevels, "")] = power_level_event - send_level = self._get_send_level( + send_level = event_auth.get_send_level( EventTypes.Aliases, "", auth_events ) - user_level = self._get_user_power_level(user_id, auth_events) + user_level = event_auth.get_user_power_level(user_id, auth_events) if user_level < send_level: raise AuthError( diff --git a/synapse/event_auth.py b/synapse/event_auth.py new file mode 100644 index 0000000000..4096c606f1 --- /dev/null +++ b/synapse/event_auth.py @@ -0,0 +1,678 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from canonicaljson import encode_canonical_json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json, SignatureVerifyException +from unpaddedbase64 import decode_base64 + +from synapse.api.constants import EventTypes, Membership, JoinRules +from synapse.api.errors import AuthError, SynapseError, EventSizeError +from synapse.types import UserID, get_domain_from_id + +logger = logging.getLogger(__name__) + + +def check(event, auth_events, do_sig_check=True, do_size_check=True): + """ Checks if this event is correctly authed. + + Args: + event: the event being checked. + auth_events (dict: event-key -> event): the existing room state. + + + Returns: + True if the auth checks pass. + """ + if do_size_check: + _check_size_limits(event) + + if not hasattr(event, "room_id"): + raise AuthError(500, "Event has no room_id: %s" % event) + + if do_sig_check: + sender_domain = get_domain_from_id(event.sender) + event_id_domain = get_domain_from_id(event.event_id) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + # Check the event_id's domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + + if auth_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + logger.warn("Trusting event: %s", event.event_id) + return True + + if event.type == EventTypes.Create: + room_id_domain = get_domain_from_id(event.room_id) + if room_id_domain != sender_domain: + raise AuthError( + 403, + "Creation event's room_id domain does not match sender's" + ) + # FIXME + return True + + creation_event = auth_events.get((EventTypes.Create, ""), None) + + if not creation_event: + raise SynapseError( + 403, + "Room %r does not exist" % (event.room_id,) + ) + + creating_domain = get_domain_from_id(event.room_id) + originating_domain = get_domain_from_id(event.sender) + if creating_domain != originating_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # FIXME: Temp hack + if event.type == EventTypes.Aliases: + if not event.is_state(): + raise AuthError( + 403, + "Alias event must be a state event", + ) + if not event.state_key: + raise AuthError( + 403, + "Alias event must have non-empty state_key" + ) + sender_domain = get_domain_from_id(event.sender) + if event.state_key != sender_domain: + raise AuthError( + 403, + "Alias event's state_key does not match sender's domain" + ) + return True + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) + + if event.type == EventTypes.Member: + allowed = _is_membership_change_allowed( + event, auth_events + ) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) + return allowed + + _check_event_sender_in_room(event, auth_events) + + # Special case to allow m.room.third_party_invite events wherever + # a user is allowed to issue invites. Fixes + # https://github.com/vector-im/vector-web/issues/1208 hopefully + if event.type == EventTypes.ThirdPartyInvite: + user_level = get_user_power_level(event.user_id, auth_events) + invite_level = _get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, ( + "You cannot issue a third party invite for %s." % + (event.content.display_name,) + ) + ) + else: + return True + + _can_send_event(event, auth_events) + + if event.type == EventTypes.PowerLevels: + _check_power_levels(event, auth_events) + + if event.type == EventTypes.Redaction: + check_redaction(event, auth_events) + + logger.debug("Allowing! %s", event) + + +def _check_size_limits(event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + + +def _can_federate(event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + + +def _is_membership_change_allowed(event, auth_events): + membership = event.content["membership"] + + # Check if this is the room creator joining: + if len(event.prev_events) == 1 and Membership.JOIN == membership: + # Get room creation event: + key = (EventTypes.Create, "", ) + create = auth_events.get(key) + if create and event.prev_events[0][0] == create.event_id: + if create.content["creator"] == event.state_key: + return True + + target_user_id = event.state_key + + creating_domain = get_domain_from_id(event.room_id) + target_domain = get_domain_from_id(target_user_id) + if creating_domain != target_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # get info about the caller + key = (EventTypes.Member, event.user_id, ) + caller = auth_events.get(key) + + caller_in_room = caller and caller.membership == Membership.JOIN + caller_invited = caller and caller.membership == Membership.INVITE + + # get info about the target + key = (EventTypes.Member, target_user_id, ) + target = auth_events.get(key) + + target_in_room = target and target.membership == Membership.JOIN + target_banned = target and target.membership == Membership.BAN + + key = (EventTypes.JoinRules, "", ) + join_rule_event = auth_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: + join_rule = JoinRules.INVITE + + user_level = get_user_power_level(event.user_id, auth_events) + target_level = get_user_power_level( + target_user_id, auth_events + ) + + # FIXME (erikj): What should we do here as the default? + ban_level = _get_named_level(auth_events, "ban", 50) + + logger.debug( + "_is_membership_change_allowed: %s", + { + "caller_in_room": caller_in_room, + "caller_invited": caller_invited, + "target_banned": target_banned, + "target_in_room": target_in_room, + "membership": membership, + "join_rule": join_rule, + "target_user_id": target_user_id, + "event.user_id": event.user_id, + } + ) + + if Membership.INVITE == membership and "third_party_invite" in event.content: + if not _verify_third_party_invite(event, auth_events): + raise AuthError(403, "You are not invited to this room.") + if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) + return True + + if Membership.JOIN != membership: + if (caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id): + return True + + if not caller_in_room: # caller isn't joined + raise AuthError( + 403, + "%s not in room %s." % (event.user_id, event.room_id,) + ) + + if Membership.INVITE == membership: + # TODO (erikj): We should probably handle this more intelligently + # PRIVATE join rules. + + # Invites are valid iff caller is in the room and target isn't. + if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) + elif target_in_room: # the target is already in the room. + raise AuthError(403, "%s is already in the room." % + target_user_id) + else: + invite_level = _get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, "You cannot invite user %s." % target_user_id + ) + elif Membership.JOIN == membership: + # Joins are valid iff caller == target and they were: + # invited: They are accepting the invitation + # joined: It's a NOOP + if event.user_id != target_user_id: + raise AuthError(403, "Cannot force another user to join.") + elif target_banned: + raise AuthError(403, "You are banned from this room") + elif join_rule == JoinRules.PUBLIC: + pass + elif join_rule == JoinRules.INVITE: + if not caller_in_room and not caller_invited: + raise AuthError(403, "You are not invited to this room.") + else: + # TODO (erikj): may_join list + # TODO (erikj): private rooms + raise AuthError(403, "You are not allowed to join this room") + elif Membership.LEAVE == membership: + # TODO (erikj): Implement kicks. + if target_banned and user_level < ban_level: + raise AuthError( + 403, "You cannot unban user &s." % (target_user_id,) + ) + elif target_user_id != event.user_id: + kick_level = _get_named_level(auth_events, "kick", 50) + + if user_level < kick_level or user_level <= target_level: + raise AuthError( + 403, "You cannot kick user %s." % target_user_id + ) + elif Membership.BAN == membership: + if user_level < ban_level or user_level <= target_level: + raise AuthError(403, "You don't have permission to ban") + else: + raise AuthError(500, "Unknown membership %s" % membership) + + return True + + +def _check_event_sender_in_room(event, auth_events): + key = (EventTypes.Member, event.user_id, ) + member_event = auth_events.get(key) + + return _check_joined_room( + member_event, + event.user_id, + event.room_id + ) + + +def _check_joined_room(member, user_id, room_id): + if not member or member.membership != Membership.JOIN: + raise AuthError(403, "User %s not in room %s (%s)" % ( + user_id, room_id, repr(member) + )) + + +def get_send_level(etype, state_key, auth_events): + key = (EventTypes.PowerLevels, "", ) + send_level_event = auth_events.get(key) + send_level = None + if send_level_event: + send_level = send_level_event.content.get("events", {}).get( + etype + ) + if send_level is None: + if state_key is not None: + send_level = send_level_event.content.get( + "state_default", 50 + ) + else: + send_level = send_level_event.content.get( + "events_default", 0 + ) + + if send_level: + send_level = int(send_level) + else: + send_level = 0 + + return send_level + + +def _can_send_event(event, auth_events): + send_level = get_send_level( + event.type, event.get("state_key", None), auth_events + ) + user_level = get_user_power_level(event.user_id, auth_events) + + if user_level < send_level: + raise AuthError( + 403, + "You don't have permission to post that to the room. " + + "user_level (%d) < send_level (%d)" % (user_level, send_level) + ) + + # Check state_key + if hasattr(event, "state_key"): + if event.state_key.startswith("@"): + if event.state_key != event.user_id: + raise AuthError( + 403, + "You are not allowed to set others state" + ) + + return True + + +def check_redaction(event, auth_events): + """Check whether the event sender is allowed to redact the target event. + + Returns: + True if the the sender is allowed to redact the target event if the + target event was created by them. + False if the sender is allowed to redact the target event with no + further checks. + + Raises: + AuthError if the event sender is definitely not allowed to redact + the target event. + """ + user_level = get_user_power_level(event.user_id, auth_events) + + redact_level = _get_named_level(auth_events, "redact", 50) + + if user_level >= redact_level: + return False + + redacter_domain = get_domain_from_id(event.event_id) + redactee_domain = get_domain_from_id(event.redacts) + if redacter_domain == redactee_domain: + return True + + raise AuthError( + 403, + "You don't have permission to redact events" + ) + + +def _check_power_levels(event, auth_events): + user_list = event.content.get("users", {}) + # Validate users + for k, v in user_list.items(): + try: + UserID.from_string(k) + except: + raise SynapseError(400, "Not a valid user_id: %s" % (k,)) + + try: + int(v) + except: + raise SynapseError(400, "Not a valid power level: %s" % (v,)) + + key = (event.type, event.state_key, ) + current_state = auth_events.get(key) + + if not current_state: + return + + user_level = get_user_power_level(event.user_id, auth_events) + + # Check other levels: + levels_to_check = [ + ("users_default", None), + ("events_default", None), + ("state_default", None), + ("ban", None), + ("redact", None), + ("kick", None), + ("invite", None), + ] + + old_list = current_state.content.get("users") + for user in set(old_list.keys() + user_list.keys()): + levels_to_check.append( + (user, "users") + ) + + old_list = current_state.content.get("events") + new_list = event.content.get("events") + for ev_id in set(old_list.keys() + new_list.keys()): + levels_to_check.append( + (ev_id, "events") + ) + + old_state = current_state.content + new_state = event.content + + for level_to_check, dir in levels_to_check: + old_loc = old_state + new_loc = new_state + if dir: + old_loc = old_loc.get(dir, {}) + new_loc = new_loc.get(dir, {}) + + if level_to_check in old_loc: + old_level = int(old_loc[level_to_check]) + else: + old_level = None + + if level_to_check in new_loc: + new_level = int(new_loc[level_to_check]) + else: + new_level = None + + if new_level is not None and old_level is not None: + if new_level == old_level: + continue + + if dir == "users" and level_to_check != event.user_id: + if old_level == user_level: + raise AuthError( + 403, + "You don't have permission to remove ops level equal " + "to your own" + ) + + if old_level > user_level or new_level > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + +def _get_power_level_event(auth_events): + key = (EventTypes.PowerLevels, "", ) + return auth_events.get(key) + + +def get_user_power_level(user_id, auth_events): + power_level_event = _get_power_level_event(auth_events) + + if power_level_event: + level = power_level_event.content.get("users", {}).get(user_id) + if not level: + level = power_level_event.content.get("users_default", 0) + + if level is None: + return 0 + else: + return int(level) + else: + key = (EventTypes.Create, "", ) + create_event = auth_events.get(key) + if (create_event is not None and + create_event.content["creator"] == user_id): + return 100 + else: + return 0 + + +def _get_named_level(auth_events, name, default): + power_level_event = _get_power_level_event(auth_events) + + if not power_level_event: + return default + + level = power_level_event.content.get(name, None) + if level is not None: + return int(level) + else: + return default + + +def _verify_third_party_invite(event, auth_events): + """ + Validates that the invite event is authorized by a previous third-party invite. + + Checks that the public key, and keyserver, match those in the third party invite, + and that the invite event has a signature issued using that public key. + + Args: + event: The m.room.member join event being validated. + auth_events: All relevant previous context events which may be used + for authorization decisions. + + Return: + True if the event fulfills the expectations of a previous third party + invite event. + """ + if "third_party_invite" not in event.content: + return False + if "signed" not in event.content["third_party_invite"]: + return False + signed = event.content["third_party_invite"]["signed"] + for key in {"mxid", "token"}: + if key not in signed: + return False + + token = signed["token"] + + invite_event = auth_events.get( + (EventTypes.ThirdPartyInvite, token,) + ) + if not invite_event: + return False + + if invite_event.sender != event.sender: + return False + + if event.user_id != invite_event.user_id: + return False + + if signed["mxid"] != event.state_key: + return False + if signed["token"] != token: + return False + + for public_key_object in get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + +def get_public_keys(invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys + + +def auth_types_for_event(event): + """Given an event, return a list of (EventType, StateKey) that may be + needed to auth the event. The returned list may be a superset of what + would actually be required depending on the full state of the room. + + Used to limit the number of events to fetch from the database to + actually auth the event. + """ + if event.type == EventTypes.Create: + return [] + + auth_types = [] + + auth_types.append((EventTypes.PowerLevels, "", )) + auth_types.append((EventTypes.Member, event.user_id, )) + auth_types.append((EventTypes.Create, "", )) + + if event.type == EventTypes.Member: + membership = event.content["membership"] + if membership in [Membership.JOIN, Membership.INVITE]: + auth_types.append((EventTypes.JoinRules, "", )) + + auth_types.append((EventTypes.Member, event.state_key, )) + + if membership == Membership.INVITE: + if "third_party_invite" in event.content: + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["signed"]["token"] + ) + auth_types.append(key) + + return auth_types diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index da9f3ad436..e673e96cc0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -79,7 +79,6 @@ class EventBase(object): auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") content = _event_dict_property("content") - event_id = _event_dict_property("event_id") hashes = _event_dict_property("hashes") origin = _event_dict_property("origin") origin_server_ts = _event_dict_property("origin_server_ts") @@ -88,8 +87,6 @@ class EventBase(object): redacts = _event_dict_property("redacts") room_id = _event_dict_property("room_id") sender = _event_dict_property("sender") - state_key = _event_dict_property("state_key") - type = _event_dict_property("type") user_id = _event_dict_property("sender") @property @@ -162,6 +159,11 @@ class FrozenEvent(EventBase): else: frozen_dict = event_dict + self.event_id = event_dict["event_id"] + self.type = event_dict["type"] + if "state_key" in event_dict: + self.state_key = event_dict["state_key"] + super(FrozenEvent, self).__init__( frozen_dict, signatures=signatures, diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 7369d70980..365fd96bd2 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import EventBase, FrozenEvent +from . import EventBase, FrozenEvent, _event_dict_property from synapse.types import EventID @@ -34,6 +34,10 @@ class EventBuilder(EventBase): internal_metadata_dict=internal_metadata_dict, ) + event_id = _event_dict_property("event_id") + state_key = _event_dict_property("state_key") + type = _event_dict_property("type") + def build(self): return FrozenEvent.from_event(self) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b4bcec77ed..c9175bb33d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, builder import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -499,8 +499,10 @@ class FederationClient(FederationBase): if "prev_state" not in pdu_dict: pdu_dict["prev_state"] = [] + ev = builder.EventBuilder(pdu_dict) + defer.returnValue( - (destination, self.event_from_pdu_json(pdu_dict)) + (destination, ev) ) break except CodeMessageException as e: diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 7db7b806dc..6b3a7abb9e 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -362,7 +362,7 @@ class TransactionQueue(object): if not success: break except NotRetryingDestination: - logger.info( + logger.debug( "TX [%s] not ready for retry yet - " "dropping transaction for now", destination, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1021bcc405..d3f5892376 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -591,12 +591,12 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) - logger.info("calling resolve_state_groups in _maybe_backfill") + logger.debug("calling resolve_state_groups in _maybe_backfill") states = yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids ])) - states = dict(zip(event_ids, [s[1] for s in states])) + states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( [e_id for ids in states.values() for e_id in ids], @@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state, prev_state = self.state_handler.resolve_events( + new_state = self.state_handler.resolve_events( [local_view.values(), remote_view.values()], event ) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index b47bf1f92b..a27476bbad 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -52,7 +52,7 @@ def get_badge_count(store, user_id): def get_context_for_event(store, state_handler, ev, user_id): ctx = {} - room_state_ids = yield state_handler.get_current_state_ids(ev.room_id) + room_state_ids = yield store.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 093bc072f4..0c9cdff3b8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_password_login(self, login_submission): if 'medium' in login_submission and 'address' in login_submission: + address = login_submission['address'] + if login_submission['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addreses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + address = address.lower() user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - login_submission['medium'], login_submission['address'] + login_submission['medium'], address ) if not user_id: raise LoginError(403, "", errcode=Codes.FORBIDDEN) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index e74e5e0123..398e7f5eb0 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet): threepid = result[LoginType.EMAIL_IDENTITY] if 'medium' not in threepid or 'address' not in threepid: raise SynapseError(500, "Malformed threepid") + if threepid['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addreses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + threepid['address'] = threepid['address'].lower() # if using email, we must know about the email they're authing with! threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( threepid['medium'], threepid['address'] diff --git a/synapse/state.py b/synapse/state.py index b9d5627a82..20aaacf40f 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,12 +16,12 @@ from twisted.internet import defer +from synapse import event_auth from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer @@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 _NEXT_STATE_ID = 1 +POWER_KEY = (EventTypes.PowerLevels, "") + def _gen_state_id(): global _NEXT_STATE_ID @@ -77,6 +79,9 @@ class _StateCacheEntry(object): else: self.state_id = _gen_state_id() + def __len__(self): + return len(self.state) + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -99,6 +104,7 @@ class StateHandler(object): clock=self.clock, max_len=SIZE_OF_CACHE, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, reset_expiry_on_get=True, ) @@ -123,7 +129,7 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_state") + logger.debug("calling resolve_state_groups from get_current_state") ret = yield self.resolve_state_groups(room_id, latest_event_ids) state = ret.state @@ -148,7 +154,7 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_state_ids") + logger.debug("calling resolve_state_groups from get_current_state_ids") ret = yield self.resolve_state_groups(room_id, latest_event_ids) state = ret.state @@ -162,7 +168,7 @@ class StateHandler(object): def get_current_user_in_room(self, room_id, latest_event_ids=None): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_user_in_room") + logger.debug("calling resolve_state_groups from get_current_user_in_room") entry = yield self.resolve_state_groups(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state( room_id, entry.state_id, entry.state @@ -226,7 +232,7 @@ class StateHandler(object): context.prev_state_events = [] defer.returnValue(context) - logger.info("calling resolve_state_groups from compute_event_context") + logger.debug("calling resolve_state_groups from compute_event_context") if event.is_state(): entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], @@ -327,20 +333,13 @@ class StateHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) - state_map = yield self.store.get_events( - [e_id for st in state_groups_ids.values() for e_id in st.values()], - get_prev_content=False - ) - state_sets = [ - [state_map[e_id] for key, e_id in st.items() if e_id in state_map] - for st in state_groups_ids.values() - ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) - new_state = { - key: e.event_id for key, e in new_state.items() - } + with Measure(self.clock, "state._resolve_events"): + new_state = yield resolve_events( + state_groups_ids.values(), + state_map_factory=lambda ev_ids: self.store.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) else: new_state = { key: e_ids.pop() for key, e_ids in state.items() @@ -388,152 +387,264 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - if event.is_state(): - return self._resolve_events( - state_sets, event.type, event.state_key - ) - else: - return self._resolve_events(state_sets) + state_set_ids = [{ + (ev.type, ev.state_key): ev.event_id + for ev in st + } for st in state_sets] + + state_map = { + ev.event_id: ev + for st in state_sets + for ev in st + } - def _resolve_events(self, state_sets, event_type=None, state_key=""): - """ - Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. - """ with Measure(self.clock, "state._resolve_events"): - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + new_state = resolve_events(state_set_ids, state_map) - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + new_state = { + key: state_map[ev_id] for key, ev_id in new_state.items() + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + return new_state - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - new_state = unconflicted_state - new_state.update(resolved_state) + return sorted(events, key=key_func) - return new_state, prev_states - @log_function - def _resolve_state_events(self, conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. - - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = self._resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = self._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = self._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = self._resolve_normal_events( - events, auth_events - ) - - return resolved_state - - def _resolve_auth_events(self, events, auth_events): - reverse = [i for i in reversed(self._ordered_events(events))] - - auth_events = dict(auth_events) - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. - # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - def _resolve_normal_events(self, events, auth_events): - for event in self._ordered_events(events): - try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. - # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event - - def _ordered_events(self, events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - - return sorted(events, key=key_func) +def resolve_events(state_sets, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + state_map_factory(dict|callable): If callable, then will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. Otherwise, should be + a dict from event_id to event of all events in state_sets. + + Returns + dict[(str, str), synapse.events.FrozenEvent] is a map from + (type, state_key) to event. + """ + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + + if callable(state_map_factory): + return _resolve_with_state_fac( + unconflicted_state, conflicted_state, state_map_factory + ) + + state_map = state_map_factory + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + return _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't. i.e., which have multiple different event_ids associated + with them in different state sets. + """ + unconflicted_state = dict(state_sets[0]) + conflicted_state = {} + + for state_set in state_sets[1:]: + for key, value in state_set.iteritems(): + # Check if there is an unconflicted entry for the state key. + unconflicted_value = unconflicted_state.get(key) + if unconflicted_value is None: + # There isn't an unconflicted entry so check if there is a + # conflicted entry. + ls = conflicted_state.get(key) + if ls is None: + # There wasn't a conflicted entry so haven't seen this key before. + # Therefore it isn't conflicted yet. + unconflicted_state[key] = value + else: + # This key is already conflicted, add our value to the conflict set. + ls.add(value) + elif unconflicted_value != value: + # If the unconflicted value is not the same as our value then we + # have a new conflict. So move the key from the unconflicted_state + # to the conflicted state. + conflicted_state[key] = {value, unconflicted_value} + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +@defer.inlineCallbacks +def _resolve_with_state_fac(unconflicted_state, conflicted_state, + state_map_factory): + needed_events = set( + event_id + for event_ids in conflicted_state.itervalues() + for event_id in event_ids + ) + + logger.info("Asking for %d conflicted events", len(needed_events)) + + state_map = yield state_map_factory(needed_events) + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(auth_events.itervalues()) + new_needed_events -= needed_events + + logger.info("Asking for %d auth events", len(new_needed_events)) + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in conflicted_state.itervalues(): + for event_id in event_ids: + if event_id in state_map: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids, + state_map): + conflicted_state = {} + for key, event_ids in conflicted_state_ds.iteritems(): + events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] + if len(events) > 1: + conflicted_state[key] = events + elif len(events) == 1: + unconflicted_state_ids[key] = events[0].event_id + + auth_events = { + key: state_map[ev_id] + for key, ev_id in auth_event_ids.items() + if ev_id in state_map + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state_ids + for key, event in resolved_state.iteritems(): + new_state[key] = event.event_id + + return new_state + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[POWER_KEY] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + prev_event = event + except AuthError: + return prev_event + + return event + + +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5620a655eb..963ef999d5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -169,7 +169,7 @@ class SQLBaseStore(object): max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 2821eb89c9..bde3b5cbbc 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -18,13 +18,29 @@ import ujson from twisted.internet import defer -from ._base import SQLBaseStore +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) -class DeviceInboxStore(SQLBaseStore): +class DeviceInboxStore(BackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, hs): + super(DeviceInboxStore, self).__init__(hs) + + self.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, + self._background_drop_index_device_inbox, + ) @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, @@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_inbox_stream_id" + ) + txn.close() + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + defer.returnValue(1) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 04dbdac3f8..ca501932f3 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1084,10 +1084,10 @@ class EventsStore(SQLBaseStore): self._do_fetch ) - logger.info("Loading %d events", len(events)) + logger.debug("Loading %d events", len(events)) with PreserveLoggingContext(): rows = yield events_d - logger.info("Loaded %d events (%d rows)", len(events), len(rows)) + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e46ae6502e..b357f22be7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 39 +SCHEMA_VERSION = 40 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7c..768e0a4451 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql new file mode 100644 index 0000000000..b9fe1f0480 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_inbox.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- turn the pre-fill startup query into a index-only scan on postgresql. +INSERT into background_updates (update_name, progress_json) + VALUES ('device_inbox_stream_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40ac..7d34dd03bf 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index ebd715c5dc..8a7774a88e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -40,8 +40,8 @@ def register_cache(name, cache): ) -_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR)) -caches_by_name["string_cache"] = _string_cache +_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR)) +_stirng_cache_metrics = register_cache("string_cache", _string_cache) KNOWN_KEYS = { @@ -69,7 +69,12 @@ KNOWN_KEYS = { def intern_string(string): """Takes a (potentially) unicode string and interns using custom cache """ - return _string_cache.setdefault(string, string) + new_str = _string_cache.setdefault(string, string) + if new_str is string: + _stirng_cache_metrics.inc_hits() + else: + _stirng_cache_metrics.inc_misses() + return new_str def intern_dict(dictionary): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8dba61d49f..675bfd5feb 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -42,6 +42,25 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) +class CacheEntry(object): + __slots__ = [ + "deferred", "sequence", "callbacks", "invalidated" + ] + + def __init__(self, deferred, sequence, callbacks): + self.deferred = deferred + self.sequence = sequence + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() + + class Cache(object): __slots__ = ( "cache", @@ -51,12 +70,16 @@ class Cache(object): "sequence", "thread", "metrics", + "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False): + def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict + self._pending_deferred_cache = cache_type() + self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type + max_size=max_entries, keylen=keylen, cache_type=cache_type, + size_callback=(lambda d: len(d.result)) if iterable else None, ) self.name = name @@ -76,7 +99,15 @@ class Cache(object): ) def get(self, key, default=_CacheSentinel, callback=None): - val = self.cache.get(key, _CacheSentinel, callback=callback) + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + if val.sequence == self.sequence: + val.callbacks.update(callbacks) + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -88,15 +119,39 @@ class Cache(object): else: return default - def update(self, sequence, key, value, callback=None): + def set(self, key, value, callback=None): + callbacks = [callback] if callback else [] self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value, callback=callback) + entry = CacheEntry( + deferred=value, + sequence=self.sequence, + callbacks=callbacks, + ) + + entry.callbacks.update(callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def shuffle(result): + if self.sequence == entry.sequence: + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, entry.deferred, entry.callbacks) + else: + entry.invalidate() + else: + entry.invalidate() + return result + + entry.deferred.addCallback(shuffle) def prefill(self, key, value, callback=None): - self.cache.set(key, value, callback=callback) + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): self.check_thread() @@ -108,6 +163,10 @@ class Cache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 + entry = self._pending_deferred_cache.pop(key, None) + if entry: + entry.invalidate() + self.cache.pop(key, None) def invalidate_many(self, key): @@ -119,6 +178,11 @@ class Cache(object): self.sequence += 1 self.cache.del_multi(key) + entry_dict = self._pending_deferred_cache.pop(key, None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + def invalidate_all(self): self.check_thread() self.sequence += 1 @@ -155,7 +219,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False, cache_context=False): + inlineCallbacks=False, cache_context=False, iterable=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -169,6 +233,8 @@ class CacheDescriptor(object): self.num_args = num_args self.tree = tree + self.iterable = iterable + all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] @@ -203,6 +269,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, + iterable=self.iterable, ) @functools.wraps(self.orig) @@ -243,11 +310,6 @@ class CacheDescriptor(object): return preserve_context_over_deferred(observer) except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - ret = defer.maybeDeferred( preserve_context_over_fn, self.function_to_call, @@ -261,7 +323,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=invalidate_callback) + cache.set(cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -359,7 +421,6 @@ class CacheListDescriptor(object): missing.append(arg) if missing: - sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -382,8 +443,8 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update( - sequence, tuple(key), observer, + cache.set( + tuple(key), observer, callback=invalidate_callback ) @@ -421,17 +482,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, cache_context=cache_context, + iterable=iterable, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -439,6 +503,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex tree=tree, inlineCallbacks=True, cache_context=cache_context, + iterable=iterable, ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b0ca1bb79d..cb6933c61c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,9 @@ import logging logger = logging.getLogger(__name__) -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): + def __len__(self): + return len(self.value) class DictionaryCache(object): @@ -32,7 +34,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, size_callback=len) self.name = name self.sequence = 0 diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 080388958f..2987c38a2d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -15,6 +15,7 @@ from synapse.util.caches import register_cache +from collections import OrderedDict import logging @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): + reset_expiry_on_get=False, iterable=False): """ Args: cache_name (str): Name of this cache, used for logging. @@ -36,6 +37,8 @@ class ExpiringCache(object): evicted based on time. reset_expiry_on_get (bool): If true, will reset the expiry time for an item on access. Defaults to False. + iterable (bool): If true, the size is calculated by summing the + sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -47,9 +50,13 @@ class ExpiringCache(object): self._reset_expiry_on_get = reset_expiry_on_get - self._cache = {} + self._cache = OrderedDict() - self.metrics = register_cache(cache_name, self._cache) + self.metrics = register_cache(cache_name, self) + + self.iterable = iterable + + self._size_estimate = 0 def start(self): if not self._expiry_ms: @@ -65,15 +72,14 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) - # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: - sorted_entries = sorted( - self._cache.items(), - key=lambda item: item[1].time, - ) + if self.iterable: + self._size_estimate += len(value) - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) + # Evict if there are now too many items + while self._max_len and len(self) > self._max_len: + _key, value = self._cache.popitem(last=False) + if self.iterable: + self._size_estimate -= len(value.value) def __getitem__(self, key): try: @@ -99,7 +105,7 @@ class ExpiringCache(object): # zero expiry time means don't expire. This should never get called # since we have this check in start too. return - begin_length = len(self._cache) + begin_length = len(self) now = self._clock.time_msec() @@ -110,15 +116,20 @@ class ExpiringCache(object): keys_to_delete.add(key) for k in keys_to_delete: - self._cache.pop(k) + value = self._cache.pop(k) + if self.iterable: + self._size_estimate -= len(value.value) logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache) + self._cache_name, begin_length, len(self) ) def __len__(self): - return len(self._cache) + if self.iterable: + return self._size_estimate + else: + return len(self._cache) class _CacheEntry(object): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 9c4c679175..072f9a9d19 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,7 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -58,6 +58,12 @@ class LruCache(object): lock = threading.Lock() + def evict(): + while cache_len() > max_size: + todelete = list_root.prev_node + delete_node(todelete) + cache.pop(todelete.key, None) + def synchronized(f): @wraps(f) def inner(*args, **kwargs): @@ -66,6 +72,16 @@ class LruCache(object): return inner + cached_cache_len = [0] + if size_callback is not None: + def cache_len(): + return cached_cache_len[0] + else: + def cache_len(): + return len(cache) + + self.len = synchronized(cache_len) + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node @@ -74,6 +90,9 @@ class LruCache(object): next_node.prev_node = node cache[key] = node + if size_callback: + cached_cache_len[0] += size_callback(node.value) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -92,23 +111,25 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + for cb in node.callbacks: cb() node.callbacks.clear() @synchronized - def cache_get(key, default=None, callback=None): + def cache_get(key, default=None, callbacks=[]): node = cache.get(key, None) if node is not None: move_node_to_front(node) - if callback: - node.callbacks.add(callback) + node.callbacks.update(callbacks) return node.value else: return default @synchronized - def cache_set(key, value, callback=None): + def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: if value != node.value: @@ -116,21 +137,18 @@ class LruCache(object): cb() node.callbacks.clear() - if callback: - node.callbacks.add(callback) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) + + node.callbacks.update(callbacks) move_node_to_front(node) node.value = value else: - if callback: - callbacks = set([callback]) - else: - callbacks = set() - add_node(key, value, callbacks) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + add_node(key, value, set(callbacks)) + + evict() @synchronized def cache_set_default(key, value): @@ -139,10 +157,7 @@ class LruCache(object): return node.value else: add_node(key, value) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + evict() return value @synchronized @@ -176,10 +191,6 @@ class LruCache(object): cache.clear() @synchronized - def cache_len(): - return len(cache) - - @synchronized def cache_contains(key): return key in cache @@ -190,7 +201,7 @@ class LruCache(object): self.pop = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi - self.len = cache_len + self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c31585aea3..fcc341a6b7 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,12 +65,27 @@ class TreeCache(object): return popped def values(self): - return [e.value for e in self.root.values()] + return list(iterate_tree_cache_entry(self.root)) def __len__(self): return self.size +def iterate_tree_cache_entry(d): + """Helper function to iterate over the leaves of a tree, i.e. a dict of that + can contain dicts. + """ + if isinstance(d, dict): + for value_d in d.itervalues(): + for value in iterate_tree_cache_entry(value_d): + yield value + else: + if isinstance(d, _Entry): + yield d.value + else: + yield d + + class _Entry(object): __slots__ = ["value"] |