Diffstat (limited to 'synapse')
83 files changed, 3314 insertions, 1642 deletions
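The most widespread change below replaces the single "bind_address" string in each listener configuration with a "bind_addresses" list; every app entry point under synapse/app/ then opens one listening socket per address. A minimal sketch of the repeated pattern (condensed, not verbatim from the patch), assuming a Twisted reactor, a prepared SynapseSite named site, and a listener dict already in the new format:

    # One listen call per configured address, instead of a single listenTCP:
    for address in listener_config["bind_addresses"]:
        reactor.listenTCP(
            port,
            site,  # SynapseSite wrapping the resource tree
            interface=address,
        )

Backwards compatibility is handled in synapse/config/server.py: a legacy "bind_address" key is popped and appended to "bind_addresses", and if neither is set the list falls back to [''], i.e. all interfaces.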
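The other large block of the diff extracts the event authorisation rules out of synapse.api.auth.Auth into a new module of standalone functions, synapse/event_auth.py; Auth.check and Auth.check_redaction become thin wrappers that delegate to it. A usage sketch, assuming event and auth_events have already been loaded by the caller:

    from synapse import event_auth

    # Raises AuthError (or SynapseError) if the event is not allowed.
    # Signature and size checks can be skipped independently.
    event_auth.check(event, auth_events, do_sig_check=False, do_size_check=True)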
diff --git a/synapse/__init__.py b/synapse/__init__.py index 498ded38c0..d3f445be9c 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.18.7" +__version__ = "0.19.0" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f93e45a744..03a215ab1b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -16,18 +16,14 @@ import logging import pymacaroons -from canonicaljson import encode_canonical_json -from signedjson.key import decode_verify_key_bytes -from signedjson.sign import verify_signed_json, SignatureVerifyException from twisted.internet import defer -from unpaddedbase64 import decode_base64 import synapse.types +from synapse import event_auth from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import UserID, get_domain_from_id +from synapse.api.errors import AuthError, Codes +from synapse.types import UserID from synapse.util.logcontext import preserve_context_over_fn -from synapse.util.logutils import log_function from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -78,147 +74,7 @@ class Auth(object): True if the auth checks pass. """ with Measure(self.clock, "auth.check"): - self.check_size_limits(event) - - if not hasattr(event, "room_id"): - raise AuthError(500, "Event has no room_id: %s" % event) - - if do_sig_check: - sender_domain = get_domain_from_id(event.sender) - event_id_domain = get_domain_from_id(event.event_id) - - is_invite_via_3pid = ( - event.type == EventTypes.Member - and event.membership == Membership.INVITE - and "third_party_invite" in event.content - ) - - # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): - # We allow invites via 3pid to have a sender from a different - # HS, as the sender must match the sender of the original - # 3pid invite. This is checked further down with the - # other dedicated membership checks. - if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - # Check the event_id's domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return True - - if event.type == EventTypes.Create: - room_id_domain = get_domain_from_id(event.room_id) - if room_id_domain != sender_domain: - raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" - ) - # FIXME - return True - - creation_event = auth_events.get((EventTypes.Create, ""), None) - - if not creation_event: - raise SynapseError( - 403, - "Room %r does not exist" % (event.room_id,) - ) - - creating_domain = get_domain_from_id(event.room_id) - originating_domain = get_domain_from_id(event.sender) - if creating_domain != originating_domain: - if not self.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." 
- ) - - # FIXME: Temp hack - if event.type == EventTypes.Aliases: - if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) - if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) - sender_domain = get_domain_from_id(event.sender) - if event.state_key != sender_domain: - raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" - ) - return True - - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) - - if event.type == EventTypes.Member: - allowed = self.is_membership_change_allowed( - event, auth_events - ) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! %s", event) - return allowed - - self.check_event_sender_in_room(event, auth_events) - - # Special case to allow m.room.third_party_invite events wherever - # a user is allowed to issue invites. Fixes - # https://github.com/vector-im/vector-web/issues/1208 hopefully - if event.type == EventTypes.ThirdPartyInvite: - user_level = self._get_user_power_level(event.user_id, auth_events) - invite_level = self._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) - ) - else: - return True - - self._can_send_event(event, auth_events) - - if event.type == EventTypes.PowerLevels: - self._check_power_levels(event, auth_events) - - if event.type == EventTypes.Redaction: - self.check_redaction(event, auth_events) - - logger.debug("Allowing! %s", event) - - def check_size_limits(self, event): - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - - if len(event.user_id) > 255: - too_big("user_id") - if len(event.room_id) > 255: - too_big("room_id") - if event.is_state() and len(event.state_key) > 255: - too_big("state_key") - if len(event.type) > 255: - too_big("type") - if len(event.event_id) > 255: - too_big("event_id") - if len(encode_canonical_json(event.get_pdu_json())) > 65536: - too_big("event") + event_auth.check(event, auth_events, do_sig_check=do_sig_check) @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): @@ -290,7 +146,7 @@ class Auth(object): with Measure(self.clock, "check_host_in_room"): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from check_host_in_room") + logger.debug("calling resolve_state_groups from check_host_in_room") entry = yield self.state.resolve_state_groups( room_id, latest_event_ids ) @@ -300,16 +156,6 @@ class Auth(object): ) defer.returnValue(ret) - def check_event_sender_in_room(self, event, auth_events): - key = (EventTypes.Member, event.user_id, ) - member_event = auth_events.get(key) - - return self._check_joined_room( - member_event, - event.user_id, - event.room_id - ) - def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: raise AuthError(403, "User %s not in room %s (%s)" % ( @@ -321,267 +167,8 @@ class Auth(object): return creation_event.content.get("m.federate", True) is True - @log_function - def is_membership_change_allowed(self, event, auth_events): - membership = event.content["membership"] - - # Check if this is the room creator joining: - if len(event.prev_events) == 1 and Membership.JOIN == membership: - # Get room creation event: - key = (EventTypes.Create, "", ) - create = auth_events.get(key) - if 
create and event.prev_events[0][0] == create.event_id: - if create.content["creator"] == event.state_key: - return True - - target_user_id = event.state_key - - creating_domain = get_domain_from_id(event.room_id) - target_domain = get_domain_from_id(target_user_id) - if creating_domain != target_domain: - if not self.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # get info about the caller - key = (EventTypes.Member, event.user_id, ) - caller = auth_events.get(key) - - caller_in_room = caller and caller.membership == Membership.JOIN - caller_invited = caller and caller.membership == Membership.INVITE - - # get info about the target - key = (EventTypes.Member, target_user_id, ) - target = auth_events.get(key) - - target_in_room = target and target.membership == Membership.JOIN - target_banned = target and target.membership == Membership.BAN - - key = (EventTypes.JoinRules, "", ) - join_rule_event = auth_events.get(key) - if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) - else: - join_rule = JoinRules.INVITE - - user_level = self._get_user_power_level(event.user_id, auth_events) - target_level = self._get_user_power_level( - target_user_id, auth_events - ) - - # FIXME (erikj): What should we do here as the default? - ban_level = self._get_named_level(auth_events, "ban", 50) - - logger.debug( - "is_membership_change_allowed: %s", - { - "caller_in_room": caller_in_room, - "caller_invited": caller_invited, - "target_banned": target_banned, - "target_in_room": target_in_room, - "membership": membership, - "join_rule": join_rule, - "target_user_id": target_user_id, - "event.user_id": event.user_id, - } - ) - - if Membership.INVITE == membership and "third_party_invite" in event.content: - if not self._verify_third_party_invite(event, auth_events): - raise AuthError(403, "You are not invited to this room.") - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - return True - - if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): - return True - - if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) - - if Membership.INVITE == membership: - # TODO (erikj): We should probably handle this more intelligently - # PRIVATE join rules. - - # Invites are valid iff caller is in the room and target isn't. - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - elif target_in_room: # the target is already in the room. - raise AuthError(403, "%s is already in the room." % - target_user_id) - else: - invite_level = self._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, "You cannot invite user %s." 
% target_user_id - ) - elif Membership.JOIN == membership: - # Joins are valid iff caller == target and they were: - # invited: They are accepting the invitation - # joined: It's a NOOP - if event.user_id != target_user_id: - raise AuthError(403, "Cannot force another user to join.") - elif target_banned: - raise AuthError(403, "You are banned from this room") - elif join_rule == JoinRules.PUBLIC: - pass - elif join_rule == JoinRules.INVITE: - if not caller_in_room and not caller_invited: - raise AuthError(403, "You are not invited to this room.") - else: - # TODO (erikj): may_join list - # TODO (erikj): private rooms - raise AuthError(403, "You are not allowed to join this room") - elif Membership.LEAVE == membership: - # TODO (erikj): Implement kicks. - if target_banned and user_level < ban_level: - raise AuthError( - 403, "You cannot unban user &s." % (target_user_id,) - ) - elif target_user_id != event.user_id: - kick_level = self._get_named_level(auth_events, "kick", 50) - - if user_level < kick_level or user_level <= target_level: - raise AuthError( - 403, "You cannot kick user %s." % target_user_id - ) - elif Membership.BAN == membership: - if user_level < ban_level or user_level <= target_level: - raise AuthError(403, "You don't have permission to ban") - else: - raise AuthError(500, "Unknown membership %s" % membership) - - return True - - def _verify_third_party_invite(self, event, auth_events): - """ - Validates that the invite event is authorized by a previous third-party invite. - - Checks that the public key, and keyserver, match those in the third party invite, - and that the invite event has a signature issued using that public key. - - Args: - event: The m.room.member join event being validated. - auth_events: All relevant previous context events which may be used - for authorization decisions. - - Return: - True if the event fulfills the expectations of a previous third party - invite event. - """ - if "third_party_invite" not in event.content: - return False - if "signed" not in event.content["third_party_invite"]: - return False - signed = event.content["third_party_invite"]["signed"] - for key in {"mxid", "token"}: - if key not in signed: - return False - - token = signed["token"] - - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) - if not invite_event: - return False - - if invite_event.sender != event.sender: - return False - - if event.user_id != invite_event.user_id: - return False - - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - - for public_key_object in self.get_public_keys(invite_event): - public_key = public_key_object["public_key"] - try: - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - continue - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. 
- return True - except (KeyError, SignatureVerifyException,): - continue - return False - def get_public_keys(self, invite_event): - public_keys = [] - if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } - if "key_validity_url" in invite_event.content: - o["key_validity_url"] = invite_event.content["key_validity_url"] - public_keys.append(o) - public_keys.extend(invite_event.content.get("public_keys", [])) - return public_keys - - def _get_power_level_event(self, auth_events): - key = (EventTypes.PowerLevels, "", ) - return auth_events.get(key) - - def _get_user_power_level(self, user_id, auth_events): - power_level_event = self._get_power_level_event(auth_events) - - if power_level_event: - level = power_level_event.content.get("users", {}).get(user_id) - if not level: - level = power_level_event.content.get("users_default", 0) - - if level is None: - return 0 - else: - return int(level) - else: - key = (EventTypes.Create, "", ) - create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): - return 100 - else: - return 0 - - def _get_named_level(self, auth_events, name, default): - power_level_event = self._get_power_level_event(auth_events) - - if not power_level_event: - return default - - level = power_level_event.content.get(name, None) - if level is not None: - return int(level) - else: - return default + return event_auth.get_public_keys(invite_event) @defer.inlineCallbacks def get_user_by_req(self, request, allow_guest=False, rights="access"): @@ -974,56 +561,6 @@ class Auth(object): defer.returnValue(auth_ids) - def _get_send_level(self, etype, state_key, auth_events): - key = (EventTypes.PowerLevels, "", ) - send_level_event = auth_events.get(key) - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) - - if send_level: - send_level = int(send_level) - else: - send_level = 0 - - return send_level - - @log_function - def _can_send_event(self, event, auth_events): - send_level = self._get_send_level( - event.type, event.get("state_key", None), auth_events - ) - user_level = self._get_user_power_level(event.user_id, auth_events) - - if user_level < send_level: - raise AuthError( - 403, - "You don't have permission to post that to the room. " + - "user_level (%d) < send_level (%d)" % (user_level, send_level) - ) - - # Check state_key - if hasattr(event, "state_key"): - if event.state_key.startswith("@"): - if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) - - return True - def check_redaction(self, event, auth_events): """Check whether the event sender is allowed to redact the target event. @@ -1037,107 +574,7 @@ class Auth(object): AuthError if the event sender is definitely not allowed to redact the target event. 
""" - user_level = self._get_user_power_level(event.user_id, auth_events) - - redact_level = self._get_named_level(auth_events, "redact", 50) - - if user_level >= redact_level: - return False - - redacter_domain = get_domain_from_id(event.event_id) - redactee_domain = get_domain_from_id(event.redacts) - if redacter_domain == redactee_domain: - return True - - raise AuthError( - 403, - "You don't have permission to redact events" - ) - - def _check_power_levels(self, event, auth_events): - user_list = event.content.get("users", {}) - # Validate users - for k, v in user_list.items(): - try: - UserID.from_string(k) - except: - raise SynapseError(400, "Not a valid user_id: %s" % (k,)) - - try: - int(v) - except: - raise SynapseError(400, "Not a valid power level: %s" % (v,)) - - key = (event.type, event.state_key, ) - current_state = auth_events.get(key) - - if not current_state: - return - - user_level = self._get_user_power_level(event.user_id, auth_events) - - # Check other levels: - levels_to_check = [ - ("users_default", None), - ("events_default", None), - ("state_default", None), - ("ban", None), - ("redact", None), - ("kick", None), - ("invite", None), - ] - - old_list = current_state.content.get("users") - for user in set(old_list.keys() + user_list.keys()): - levels_to_check.append( - (user, "users") - ) - - old_list = current_state.content.get("events") - new_list = event.content.get("events") - for ev_id in set(old_list.keys() + new_list.keys()): - levels_to_check.append( - (ev_id, "events") - ) - - old_state = current_state.content - new_state = event.content - - for level_to_check, dir in levels_to_check: - old_loc = old_state - new_loc = new_state - if dir: - old_loc = old_loc.get(dir, {}) - new_loc = new_loc.get(dir, {}) - - if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) - else: - old_level = None - - if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) - else: - new_level = None - - if new_level is not None and old_level is not None: - if new_level == old_level: - continue - - if dir == "users" and level_to_check != event.user_id: - if old_level == user_level: - raise AuthError( - 403, - "You don't have permission to remove ops level equal " - "to your own" - ) - - if old_level > user_level or new_level > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) + return event_auth.check_redaction(event, auth_events) @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): @@ -1167,10 +604,10 @@ class Auth(object): if power_level_event: auth_events[(EventTypes.PowerLevels, "")] = power_level_event - send_level = self._get_send_level( + send_level = event_auth.get_send_level( EventTypes.Aliases, "", auth_events ) - user_level = self._get_user_power_level(user_id, auth_events) + user_level = event_auth.get_user_power_level(user_id, auth_events) if user_level < send_level: raise AuthError( diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index dd9ee406a1..1900930053 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -76,7 +76,7 @@ class AppserviceServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -85,16 +85,19 @@ class AppserviceServer(HomeServer): 
resources[METRICS_PREFIX] = MetricsResource(self) root_resource = create_resource_tree(resources, Resource()) - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=address + ) + logger.info("Synapse appservice now listening on port %d", port) def start_listening(self, listeners): @@ -102,15 +105,18 @@ class AppserviceServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') - ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 0086a2977e..4d081eccd1 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -90,7 +90,7 @@ class ClientReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -108,16 +108,19 @@ class ClientReaderServer(HomeServer): }) root_resource = create_resource_tree(resources, Resource()) - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=address + ) + logger.info("Synapse client reader now listening on port %d", port) def start_listening(self, listeners): @@ -125,15 +128,18 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') - ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index b5f59a9931..90a4816753 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -86,7 +86,7 @@ class FederationReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -99,16 +99,19 @@ class 
FederationReaderServer(HomeServer): }) root_resource = create_resource_tree(resources, Resource()) - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=address + ) + logger.info("Synapse federation reader now listening on port %d", port) def start_listening(self, listeners): @@ -116,15 +119,18 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') - ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 80ea4c8062..411e47d98d 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.storage.engines import create_engine from synapse.storage.presence import UserPresenceState from synapse.util.async import sleep @@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice") class FederationSenderSlaveStore( SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, - SlavedRegistrationStore, + SlavedRegistrationStore, SlavedDeviceStore, ): pass @@ -82,7 +83,7 @@ class FederationSenderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -91,16 +92,19 @@ class FederationSenderServer(HomeServer): resources[METRICS_PREFIX] = MetricsResource(self) root_resource = create_resource_tree(resources, Resource()) - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=address + ) + logger.info("Synapse federation_sender now listening on port %d", port) def start_listening(self, listeners): @@ -108,15 +112,18 @@ class FederationSenderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') 
- ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 54f35900f8..e0b87468fe 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -107,7 +107,7 @@ def build_resource_for_web_client(hs): class SynapseHomeServer(HomeServer): def _listener_http(self, config, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] tls = listener_config.get("tls", False) site_tag = listener_config.get("tag", port) @@ -173,29 +173,32 @@ class SynapseHomeServer(HomeServer): root_resource = Resource() root_resource = create_resource_tree(resources, root_resource) + if tls: - reactor.listenSSL( - port, - SynapseSite( - "synapse.access.https.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - self.tls_server_context_factory, - interface=bind_address - ) + for address in bind_addresses: + reactor.listenSSL( + port, + SynapseSite( + "synapse.access.https.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + self.tls_server_context_factory, + interface=address + ) else: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=address + ) logger.info("Synapse now listening on port %d", port) def start_listening(self): @@ -205,15 +208,18 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') - ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 44c19a1bef..ef17b158a5 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -87,7 +87,7 @@ class MediaRepositoryServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -105,16 +105,19 @@ class MediaRepositoryServer(HomeServer): }) root_resource = create_resource_tree(resources, Resource()) - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + 
interface=address + ) + logger.info("Synapse media repository now listening on port %d", port) def start_listening(self, listeners): @@ -122,15 +125,18 @@ class MediaRepositoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') - ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index a0e765c54f..073f2c2489 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -121,7 +121,7 @@ class PusherServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -130,16 +130,19 @@ class PusherServer(HomeServer): resources[METRICS_PREFIX] = MetricsResource(self) root_resource = create_resource_tree(resources, Resource()) - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=address + ) + logger.info("Synapse pusher now listening on port %d", port) def start_listening(self, listeners): @@ -147,15 +150,18 @@ class PusherServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') - ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index bf1b995dc2..b3fb408cfd 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.room import RoomStore from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore @@ -77,6 +78,7 @@ class SynchrotronSlavedStore( SlavedFilteringStore, SlavedPresenceStore, SlavedDeviceInboxStore, + SlavedDeviceStore, RoomStore, BaseSlavedStore, ClientIpStore, # After BaseSlavedStore because the constructor is different @@ -289,7 +291,7 @@ class SynchrotronServer(HomeServer): def 
_listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", "") + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -310,16 +312,19 @@ class SynchrotronServer(HomeServer): }) root_resource = create_resource_tree(resources, Resource()) - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) + + for address in bind_addresses: + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=address + ) + logger.info("Synapse synchrotron now listening on port %d", port) def start_listening(self, listeners): @@ -327,15 +332,18 @@ class SynchrotronServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=listener.get("bind_address", '127.0.0.1') - ) + bind_addresses = listener["bind_addresses"] + + for address in bind_addresses: + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=address + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -374,6 +382,27 @@ class SynchrotronServer(HomeServer): stream_key, position, users=users, rooms=rooms ) + @defer.inlineCallbacks + def notify_device_list_update(result): + stream = result.get("device_lists") + if not stream: + return + + position_index = stream["field_names"].index("position") + user_index = stream["field_names"].index("user_id") + + for row in stream["rows"]: + position = row[position_index] + user_id = row[user_index] + + rooms = yield store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + + notifier.on_new_event( + "device_list_key", position, rooms=room_ids, + ) + + @defer.inlineCallbacks def notify(result): stream = result.get("events") if stream: @@ -411,6 +440,7 @@ class SynchrotronServer(HomeServer): notify_from_stream( result, "to_device", "to_device_key", user="user_id" ) + yield notify_device_list_update(result) while True: try: @@ -421,7 +451,7 @@ class SynchrotronServer(HomeServer): yield store.process_replication(result) typing_handler.process_replication(result) yield presence_handler.process_replication(result) - notify(result) + yield notify(result) except: logger.exception("Error replicating from %r", replication_url) yield sleep(5) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index a187161272..0030b5db1e 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -68,6 +68,9 @@ class EmailConfig(Config): self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) + self.email_riot_base_url = email_config.get( + "riot_base_url", None + ) if "app_name" in email_config: self.email_app_name = email_config["app_name"] else: @@ -85,6 +88,9 @@ class EmailConfig(Config): def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable sending emails for notification events + # Defining a custom URL for Riot is only needed if email notifications + # should contain links to a self-hosted installation of Riot; when set + # the "app_name" setting is ignored. 
#email: # enable_notifs: false # smtp_host: "localhost" @@ -95,4 +101,5 @@ class EmailConfig(Config): # notif_template_html: notif_mail.html # notif_template_text: notif_mail.txt # notif_for_new_users: True + # riot_base_url: "http://localhost/riot" """ diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 63e69a7e0c..77ded0ad25 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -22,7 +22,6 @@ import yaml from string import Template import os import signal -from synapse.util.debug import debug_deferreds DEFAULT_LOG_CONFIG = Template(""" @@ -71,8 +70,6 @@ class LoggingConfig(Config): self.verbosity = config.get("verbose", 0) self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) - if config.get("full_twisted_stacktraces"): - debug_deferreds() def default_config(self, config_dir_path, server_name, **kwargs): log_file = self.abspath("homeserver.log") @@ -88,11 +85,6 @@ class LoggingConfig(Config): # A yaml python logging config file log_config: "%(log_config)s" - - # Stop twisted from discarding the stack traces of exceptions in - # deferreds by waiting a reactor tick before running a deferred's - # callbacks. - # full_twisted_stacktraces: true """ % locals() def read_arguments(self, args): diff --git a/synapse/config/server.py b/synapse/config/server.py index 634d8e6fe5..1f9999d57a 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -42,6 +42,15 @@ class ServerConfig(Config): self.listeners = config.get("listeners", []) + for listener in self.listeners: + bind_address = listener.pop("bind_address", None) + bind_addresses = listener.setdefault("bind_addresses", []) + + if bind_address: + bind_addresses.append(bind_address) + elif not bind_addresses: + bind_addresses.append('') + self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) bind_port = config.get("bind_port") @@ -54,7 +63,7 @@ class ServerConfig(Config): self.listeners.append({ "port": bind_port, - "bind_address": bind_host, + "bind_addresses": [bind_host], "tls": True, "type": "http", "resources": [ @@ -73,7 +82,7 @@ class ServerConfig(Config): if unsecure_port: self.listeners.append({ "port": unsecure_port, - "bind_address": bind_host, + "bind_addresses": [bind_host], "tls": False, "type": "http", "resources": [ @@ -92,7 +101,7 @@ class ServerConfig(Config): if manhole: self.listeners.append({ "port": manhole, - "bind_address": "127.0.0.1", + "bind_addresses": ["127.0.0.1"], "type": "manhole", }) @@ -100,7 +109,7 @@ class ServerConfig(Config): if metrics_port: self.listeners.append({ "port": metrics_port, - "bind_address": config.get("metrics_bind_host", "127.0.0.1"), + "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], "tls": False, "type": "http", "resources": [ @@ -155,9 +164,14 @@ class ServerConfig(Config): # The port to listen for HTTPS requests on. port: %(bind_port)s - # Local interface to listen on. - # The empty string will cause synapse to listen on all interfaces. - bind_address: '' + # Local addresses to listen on. + # This will listen on all IPv4 addresses by default. + bind_addresses: + - '0.0.0.0' + # Uncomment to listen on all IPv6 interfaces + # N.B: On at least Linux this will also listen on all IPv4 + # addresses, so you will need to comment out the line above. + # - '::' # This is a 'http' listener, allows us to specify 'resources'. type: http @@ -188,7 +202,7 @@ class ServerConfig(Config): # For when matrix traffic passes through loadbalancer that unwraps TLS. 
- port: %(unsecure_port)s tls: false - bind_address: '' + bind_addresses: ['0.0.0.0'] type: http x_forwarded: false diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 169980f60d..eeb693027b 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -19,7 +19,9 @@ class VoipConfig(Config): def read_config(self, config): self.turn_uris = config.get("turn_uris", []) - self.turn_shared_secret = config["turn_shared_secret"] + self.turn_shared_secret = config.get("turn_shared_secret") + self.turn_username = config.get("turn_username") + self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) def default_config(self, **kwargs): @@ -32,6 +34,11 @@ class VoipConfig(Config): # The shared secret used to compute passwords for the TURN server turn_shared_secret: "YOUR_SHARED_SECRET" + # The Username and password if the TURN server needs them and + # does not use a token + #turn_username: "TURNSERVER_USERNAME" + #turn_password: "TURNSERVER_PASSWORD" + # How long generated TURN credentials last turn_user_lifetime: "1h" """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 904789d155..b165c67ee7 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -29,3 +29,13 @@ class WorkerConfig(Config): self.worker_log_file = config.get("worker_log_file") self.worker_log_config = config.get("worker_log_config") self.worker_replication_url = config.get("worker_replication_url") + + if self.worker_listeners: + for listener in self.worker_listeners: + bind_address = listener.pop("bind_address", None) + bind_addresses = listener.setdefault("bind_addresses", []) + + if bind_address: + bind_addresses.append(bind_address) + elif not bind_addresses: + bind_addresses.append('') diff --git a/synapse/event_auth.py b/synapse/event_auth.py new file mode 100644 index 0000000000..4096c606f1 --- /dev/null +++ b/synapse/event_auth.py @@ -0,0 +1,678 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from canonicaljson import encode_canonical_json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json, SignatureVerifyException +from unpaddedbase64 import decode_base64 + +from synapse.api.constants import EventTypes, Membership, JoinRules +from synapse.api.errors import AuthError, SynapseError, EventSizeError +from synapse.types import UserID, get_domain_from_id + +logger = logging.getLogger(__name__) + + +def check(event, auth_events, do_sig_check=True, do_size_check=True): + """ Checks if this event is correctly authed. + + Args: + event: the event being checked. + auth_events (dict: event-key -> event): the existing room state. + + + Returns: + True if the auth checks pass. 
+ """ + if do_size_check: + _check_size_limits(event) + + if not hasattr(event, "room_id"): + raise AuthError(500, "Event has no room_id: %s" % event) + + if do_sig_check: + sender_domain = get_domain_from_id(event.sender) + event_id_domain = get_domain_from_id(event.event_id) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + # Check the event_id's domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + + if auth_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + logger.warn("Trusting event: %s", event.event_id) + return True + + if event.type == EventTypes.Create: + room_id_domain = get_domain_from_id(event.room_id) + if room_id_domain != sender_domain: + raise AuthError( + 403, + "Creation event's room_id domain does not match sender's" + ) + # FIXME + return True + + creation_event = auth_events.get((EventTypes.Create, ""), None) + + if not creation_event: + raise SynapseError( + 403, + "Room %r does not exist" % (event.room_id,) + ) + + creating_domain = get_domain_from_id(event.room_id) + originating_domain = get_domain_from_id(event.sender) + if creating_domain != originating_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # FIXME: Temp hack + if event.type == EventTypes.Aliases: + if not event.is_state(): + raise AuthError( + 403, + "Alias event must be a state event", + ) + if not event.state_key: + raise AuthError( + 403, + "Alias event must have non-empty state_key" + ) + sender_domain = get_domain_from_id(event.sender) + if event.state_key != sender_domain: + raise AuthError( + 403, + "Alias event's state_key does not match sender's domain" + ) + return True + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) + + if event.type == EventTypes.Member: + allowed = _is_membership_change_allowed( + event, auth_events + ) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) + return allowed + + _check_event_sender_in_room(event, auth_events) + + # Special case to allow m.room.third_party_invite events wherever + # a user is allowed to issue invites. Fixes + # https://github.com/vector-im/vector-web/issues/1208 hopefully + if event.type == EventTypes.ThirdPartyInvite: + user_level = get_user_power_level(event.user_id, auth_events) + invite_level = _get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, ( + "You cannot issue a third party invite for %s." % + (event.content.display_name,) + ) + ) + else: + return True + + _can_send_event(event, auth_events) + + if event.type == EventTypes.PowerLevels: + _check_power_levels(event, auth_events) + + if event.type == EventTypes.Redaction: + check_redaction(event, auth_events) + + logger.debug("Allowing! 
%s", event) + + +def _check_size_limits(event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + + +def _can_federate(event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + + +def _is_membership_change_allowed(event, auth_events): + membership = event.content["membership"] + + # Check if this is the room creator joining: + if len(event.prev_events) == 1 and Membership.JOIN == membership: + # Get room creation event: + key = (EventTypes.Create, "", ) + create = auth_events.get(key) + if create and event.prev_events[0][0] == create.event_id: + if create.content["creator"] == event.state_key: + return True + + target_user_id = event.state_key + + creating_domain = get_domain_from_id(event.room_id) + target_domain = get_domain_from_id(target_user_id) + if creating_domain != target_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # get info about the caller + key = (EventTypes.Member, event.user_id, ) + caller = auth_events.get(key) + + caller_in_room = caller and caller.membership == Membership.JOIN + caller_invited = caller and caller.membership == Membership.INVITE + + # get info about the target + key = (EventTypes.Member, target_user_id, ) + target = auth_events.get(key) + + target_in_room = target and target.membership == Membership.JOIN + target_banned = target and target.membership == Membership.BAN + + key = (EventTypes.JoinRules, "", ) + join_rule_event = auth_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: + join_rule = JoinRules.INVITE + + user_level = get_user_power_level(event.user_id, auth_events) + target_level = get_user_power_level( + target_user_id, auth_events + ) + + # FIXME (erikj): What should we do here as the default? + ban_level = _get_named_level(auth_events, "ban", 50) + + logger.debug( + "_is_membership_change_allowed: %s", + { + "caller_in_room": caller_in_room, + "caller_invited": caller_invited, + "target_banned": target_banned, + "target_in_room": target_in_room, + "membership": membership, + "join_rule": join_rule, + "target_user_id": target_user_id, + "event.user_id": event.user_id, + } + ) + + if Membership.INVITE == membership and "third_party_invite" in event.content: + if not _verify_third_party_invite(event, auth_events): + raise AuthError(403, "You are not invited to this room.") + if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) + return True + + if Membership.JOIN != membership: + if (caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id): + return True + + if not caller_in_room: # caller isn't joined + raise AuthError( + 403, + "%s not in room %s." % (event.user_id, event.room_id,) + ) + + if Membership.INVITE == membership: + # TODO (erikj): We should probably handle this more intelligently + # PRIVATE join rules. + + # Invites are valid iff caller is in the room and target isn't. 
+ if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) + elif target_in_room: # the target is already in the room. + raise AuthError(403, "%s is already in the room." % + target_user_id) + else: + invite_level = _get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, "You cannot invite user %s." % target_user_id + ) + elif Membership.JOIN == membership: + # Joins are valid iff caller == target and they were: + # invited: They are accepting the invitation + # joined: It's a NOOP + if event.user_id != target_user_id: + raise AuthError(403, "Cannot force another user to join.") + elif target_banned: + raise AuthError(403, "You are banned from this room") + elif join_rule == JoinRules.PUBLIC: + pass + elif join_rule == JoinRules.INVITE: + if not caller_in_room and not caller_invited: + raise AuthError(403, "You are not invited to this room.") + else: + # TODO (erikj): may_join list + # TODO (erikj): private rooms + raise AuthError(403, "You are not allowed to join this room") + elif Membership.LEAVE == membership: + # TODO (erikj): Implement kicks. + if target_banned and user_level < ban_level: + raise AuthError( + 403, "You cannot unban user %s." % (target_user_id,) + ) + elif target_user_id != event.user_id: + kick_level = _get_named_level(auth_events, "kick", 50) + + if user_level < kick_level or user_level <= target_level: + raise AuthError( + 403, "You cannot kick user %s." % target_user_id + ) + elif Membership.BAN == membership: + if user_level < ban_level or user_level <= target_level: + raise AuthError(403, "You don't have permission to ban") + else: + raise AuthError(500, "Unknown membership %s" % membership) + + return True + + +def _check_event_sender_in_room(event, auth_events): + key = (EventTypes.Member, event.user_id, ) + member_event = auth_events.get(key) + + return _check_joined_room( + member_event, + event.user_id, + event.room_id + ) + + +def _check_joined_room(member, user_id, room_id): + if not member or member.membership != Membership.JOIN: + raise AuthError(403, "User %s not in room %s (%s)" % ( + user_id, room_id, repr(member) + )) + + +def get_send_level(etype, state_key, auth_events): + key = (EventTypes.PowerLevels, "", ) + send_level_event = auth_events.get(key) + send_level = None + if send_level_event: + send_level = send_level_event.content.get("events", {}).get( + etype + ) + if send_level is None: + if state_key is not None: + send_level = send_level_event.content.get( + "state_default", 50 + ) + else: + send_level = send_level_event.content.get( + "events_default", 0 + ) + + if send_level: + send_level = int(send_level) + else: + send_level = 0 + + return send_level + + +def _can_send_event(event, auth_events): + send_level = get_send_level( + event.type, event.get("state_key", None), auth_events + ) + user_level = get_user_power_level(event.user_id, auth_events) + + if user_level < send_level: + raise AuthError( + 403, + "You don't have permission to post that to the room. " + + "user_level (%d) < send_level (%d)" % (user_level, send_level) + ) + + # Check state_key + if hasattr(event, "state_key"): + if event.state_key.startswith("@"): + if event.state_key != event.user_id: + raise AuthError( + 403, + "You are not allowed to set others state" + ) + + return True + + +def check_redaction(event, auth_events): + """Check whether the event sender is allowed to redact the target event.
+ + Returns: + True if the sender is allowed to redact the target event if the + target event was created by them. + False if the sender is allowed to redact the target event with no + further checks. + + Raises: + AuthError if the event sender is definitely not allowed to redact + the target event. + """ + user_level = get_user_power_level(event.user_id, auth_events) + + redact_level = _get_named_level(auth_events, "redact", 50) + + if user_level >= redact_level: + return False + + redacter_domain = get_domain_from_id(event.event_id) + redactee_domain = get_domain_from_id(event.redacts) + if redacter_domain == redactee_domain: + return True + + raise AuthError( + 403, + "You don't have permission to redact events" + ) + + +def _check_power_levels(event, auth_events): + user_list = event.content.get("users", {}) + # Validate users + for k, v in user_list.items(): + try: + UserID.from_string(k) + except: + raise SynapseError(400, "Not a valid user_id: %s" % (k,)) + + try: + int(v) + except: + raise SynapseError(400, "Not a valid power level: %s" % (v,)) + + key = (event.type, event.state_key, ) + current_state = auth_events.get(key) + + if not current_state: + return + + user_level = get_user_power_level(event.user_id, auth_events) + + # Check other levels: + levels_to_check = [ + ("users_default", None), + ("events_default", None), + ("state_default", None), + ("ban", None), + ("redact", None), + ("kick", None), + ("invite", None), + ] + + old_list = current_state.content.get("users") + for user in set(old_list.keys() + user_list.keys()): + levels_to_check.append( + (user, "users") + ) + + old_list = current_state.content.get("events") + new_list = event.content.get("events") + for ev_id in set(old_list.keys() + new_list.keys()): + levels_to_check.append( + (ev_id, "events") + ) + + old_state = current_state.content + new_state = event.content + + for level_to_check, dir in levels_to_check: + old_loc = old_state + new_loc = new_state + if dir: + old_loc = old_loc.get(dir, {}) + new_loc = new_loc.get(dir, {}) + + if level_to_check in old_loc: + old_level = int(old_loc[level_to_check]) + else: + old_level = None + + if level_to_check in new_loc: + new_level = int(new_loc[level_to_check]) + else: + new_level = None + + if new_level is not None and old_level is not None: + if new_level == old_level: + continue + + if dir == "users" and level_to_check != event.user_id: + if old_level == user_level: + raise AuthError( + 403, + "You don't have permission to remove ops level equal " + "to your own" + ) + + if old_level > user_level or new_level > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + +def _get_power_level_event(auth_events): + key = (EventTypes.PowerLevels, "", ) + return auth_events.get(key) + + +def get_user_power_level(user_id, auth_events): + power_level_event = _get_power_level_event(auth_events) + + if power_level_event: + level = power_level_event.content.get("users", {}).get(user_id) + if not level: + level = power_level_event.content.get("users_default", 0) + + if level is None: + return 0 + else: + return int(level) + else: + key = (EventTypes.Create, "", ) + create_event = auth_events.get(key) + if (create_event is not None and + create_event.content["creator"] == user_id): + return 100 + else: + return 0 + + +def _get_named_level(auth_events, name, default): + power_level_event = _get_power_level_event(auth_events) + + if not power_level_event: + return default + + level =
power_level_event.content.get(name, None) + if level is not None: + return int(level) + else: + return default + + +def _verify_third_party_invite(event, auth_events): + """ + Validates that the invite event is authorized by a previous third-party invite. + + Checks that the public key, and keyserver, match those in the third party invite, + and that the invite event has a signature issued using that public key. + + Args: + event: The m.room.member join event being validated. + auth_events: All relevant previous context events which may be used + for authorization decisions. + + Return: + True if the event fulfills the expectations of a previous third party + invite event. + """ + if "third_party_invite" not in event.content: + return False + if "signed" not in event.content["third_party_invite"]: + return False + signed = event.content["third_party_invite"]["signed"] + for key in {"mxid", "token"}: + if key not in signed: + return False + + token = signed["token"] + + invite_event = auth_events.get( + (EventTypes.ThirdPartyInvite, token,) + ) + if not invite_event: + return False + + if invite_event.sender != event.sender: + return False + + if event.user_id != invite_event.user_id: + return False + + if signed["mxid"] != event.state_key: + return False + if signed["token"] != token: + return False + + for public_key_object in get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + +def get_public_keys(invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys + + +def auth_types_for_event(event): + """Given an event, return a list of (EventType, StateKey) that may be + needed to auth the event. The returned list may be a superset of what + would actually be required depending on the full state of the room. + + Used to limit the number of events to fetch from the database to + actually auth the event. 
+ """ + if event.type == EventTypes.Create: + return [] + + auth_types = [] + + auth_types.append((EventTypes.PowerLevels, "", )) + auth_types.append((EventTypes.Member, event.user_id, )) + auth_types.append((EventTypes.Create, "", )) + + if event.type == EventTypes.Member: + membership = event.content["membership"] + if membership in [Membership.JOIN, Membership.INVITE]: + auth_types.append((EventTypes.JoinRules, "", )) + + auth_types.append((EventTypes.Member, event.state_key, )) + + if membership == Membership.INVITE: + if "third_party_invite" in event.content: + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["signed"]["token"] + ) + auth_types.append(key) + + return auth_types diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index da9f3ad436..e673e96cc0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -79,7 +79,6 @@ class EventBase(object): auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") content = _event_dict_property("content") - event_id = _event_dict_property("event_id") hashes = _event_dict_property("hashes") origin = _event_dict_property("origin") origin_server_ts = _event_dict_property("origin_server_ts") @@ -88,8 +87,6 @@ class EventBase(object): redacts = _event_dict_property("redacts") room_id = _event_dict_property("room_id") sender = _event_dict_property("sender") - state_key = _event_dict_property("state_key") - type = _event_dict_property("type") user_id = _event_dict_property("sender") @property @@ -162,6 +159,11 @@ class FrozenEvent(EventBase): else: frozen_dict = event_dict + self.event_id = event_dict["event_id"] + self.type = event_dict["type"] + if "state_key" in event_dict: + self.state_key = event_dict["state_key"] + super(FrozenEvent, self).__init__( frozen_dict, signatures=signatures, diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 7369d70980..365fd96bd2 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import EventBase, FrozenEvent +from . import EventBase, FrozenEvent, _event_dict_property from synapse.types import EventID @@ -34,6 +34,10 @@ class EventBuilder(EventBase): internal_metadata_dict=internal_metadata_dict, ) + event_id = _event_dict_property("event_id") + state_key = _event_dict_property("state_key") + type = _event_dict_property("type") + def build(self): return FrozenEvent.from_event(self) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b4bcec77ed..b5bcfd705a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, builder import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -127,6 +127,16 @@ class FederationClient(FederationBase): ) @log_function + def query_user_devices(self, destination, user_id, timeout=30000): + """Query the device keys for a list of user ids hosted on a remote + server. 
+ """ + sent_queries_counter.inc("user_devices") + return self.transport_layer.query_user_devices( + destination, user_id, timeout + ) + + @log_function def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. @@ -499,8 +509,10 @@ class FederationClient(FederationBase): if "prev_state" not in pdu_dict: pdu_dict["prev_state"] = [] + ev = builder.EventBuilder(pdu_dict) + defer.returnValue( - (destination, self.event_from_pdu_json(pdu_dict)) + (destination, ev) ) break except CodeMessageException as e: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 1fee4e83a6..e922b7ff4a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -52,8 +52,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() - self._room_pdu_linearizer = Linearizer() - self._server_linearizer = Linearizer() + self._room_pdu_linearizer = Linearizer("fed_room_pdu") + self._server_linearizer = Linearizer("fed_server") # We cache responses to state queries, as they take a while and often # come in waves. @@ -416,6 +416,9 @@ class FederationServer(FederationBase): def on_query_client_keys(self, origin, content): return self.on_query_request("client_keys", content) + def on_query_user_devices(self, origin, user_id): + return self.on_query_request("user_devices", user_id) + @defer.inlineCallbacks @log_function def on_claim_client_keys(self, origin, content): diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 7db7b806dc..bb3d9258a6 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -100,6 +100,7 @@ class TransactionQueue(object): self.pending_failures_by_dest = {} self.last_device_stream_id_by_dest = {} + self.last_device_list_stream_id_by_dest = {} # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) @@ -305,64 +306,77 @@ class TransactionQueue(object): yield run_on_reactor() while True: - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_failures = self.pending_failures_by_dest.pop(destination, []) - pending_edus.extend( - self.pending_edus_keyed_by_dest.pop(destination, {}).values() - ) + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) - limiter = yield get_retry_limiter( - destination, - self.clock, - self.store, - ) + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, + backoff_on_404=True, # If we get a 404 the other side has gone + ) - device_message_edus, device_stream_id = ( - yield self._get_new_device_messages(destination) - ) + device_message_edus, device_stream_id, dev_list_id = ( + yield self._get_new_device_messages(destination) + ) - pending_edus.extend(device_message_edus) - if pending_presence: - pending_edus.append( - Edu( - origin=self.server_name, - destination=destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self.clock.time_msec() - ) - for presence in 
pending_presence.values() - ] - }, - ) + pending_edus.extend(device_message_edus) + if pending_presence: + pending_edus.append( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.clock.time_msec() + ) + for presence in pending_presence.values() + ] + }, ) + ) + + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + self.last_device_stream_id_by_dest[destination] = ( + device_stream_id + ) + return - if not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - self.last_device_stream_id_by_dest[destination] = ( - device_stream_id + success = yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures, + limiter=limiter, + ) + if success: + # Remove the acknowledged device messages from the database + # Only bother if we actually sent some device messages + if device_message_edus: + yield self.store.delete_device_msgs_for_remote( + destination, device_stream_id + ) + logger.info("Marking as sent %r %r", destination, dev_list_id) + yield self.store.mark_as_sent_devices_by_remote( + destination, dev_list_id ) - return - success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, - device_stream_id, - should_delete_from_device_stream=bool(device_message_edus), - limiter=limiter, - ) - if not success: - break + self.last_device_stream_id_by_dest[destination] = device_stream_id + self.last_device_list_stream_id_by_dest[destination] = dev_list_id + else: + break except NotRetryingDestination: - logger.info( + logger.debug( "TX [%s] not ready for retry yet - " "dropping transaction for now", destination, @@ -387,13 +401,26 @@ class TransactionQueue(object): ) for content in contents ] - defer.returnValue((edus, stream_id)) + + last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0) + now_stream_id, results = yield self.store.get_devices_by_remote( + destination, last_device_list + ) + edus.extend( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.device_list_update", + content=content, + ) + for content in results + ) + defer.returnValue((edus, stream_id, now_stream_id)) @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures, device_stream_id, - should_delete_from_device_stream, limiter): + pending_failures, limiter): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -477,7 +504,7 @@ class TransactionQueue(object): code = e.code response = e.response - if e.code == 429 or 500 <= e.code: + if e.code in (401, 404, 429) or 500 <= e.code: logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code @@ -504,13 +531,6 @@ class TransactionQueue(object): "Failed to send event %s to %s", p.event_id, destination ) success = False - else: - # Remove the acknowledged device messages from the database - if should_delete_from_device_stream: - yield self.store.delete_device_msgs_for_remote( - destination, device_stream_id - ) - self.last_device_stream_id_by_dest[destination] = device_stream_id 
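As a side note on the loop above, the acknowledgement pattern is: stream positions for a destination only advance after a transaction actually succeeds, so unsent or failed batches are retried on the next attempt. A minimal sketch of that pattern, with illustrative names rather than Synapse's real storage API:

    class AckTracker(object):
        """Toy model of per-destination ack tracking (illustrative only)."""
        def __init__(self):
            self.last_acked = {}  # destination -> last acknowledged stream id

        def send_batch(self, destination, messages, transport):
            # `messages` is a list of (stream_id, payload), ordered by stream_id.
            if not messages:
                return
            if transport(destination, [payload for _, payload in messages]):
                # Only advance the pointer on success; a failure leaves it
                # untouched so the same messages are retried later.
                self.last_acked[destination] = messages[-1][0]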
except RuntimeError as e:
 # We capture this here as there is nothing that actually listens
 # for this function's deferred to finish.
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 915af34409..f49e8a2cc4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -348,6 +348,32 @@ class TransportLayerClient(object):
 @defer.inlineCallbacks
 @log_function
+ def query_user_devices(self, destination, user_id, timeout):
+ """Query the devices for a user id hosted on a remote server.
+
+ Response:
+ {
+ "stream_id": "...",
+ "devices": [ { ... } ]
+ }
+
+ Args:
+ destination(str): The server to query.
+ user_id(str): The user id to query.
+ Returns:
+ A dict containing the user's devices.
+ """
+ path = PREFIX + "/user/devices/" + user_id
+
+ content = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ timeout=timeout,
+ )
+ defer.returnValue(content)
+
+ @defer.inlineCallbacks
+ @log_function
 def claim_client_keys(self, destination, query_content, timeout):
 """Claim one-time keys for a list of devices hosted on a remote server.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 159dbd1747..c840da834c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
 return self.handler.on_query_client_keys(origin, content)
+class FederationUserDevicesQueryServlet(BaseFederationServlet):
+ PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+ def on_GET(self, origin, content, query, user_id):
+ return self.handler.on_query_user_devices(origin, user_id)
+
+
 class FederationClientKeysClaimServlet(BaseFederationServlet):
 PATH = "/user/keys/claim"
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
 FederationGetMissingEventsServlet,
 FederationEventAuthServlet,
 FederationClientKeysQueryServlet,
+ FederationUserDevicesQueryServlet,
 FederationClientKeysClaimServlet,
 FederationThirdPartyInviteExchangeServlet,
 On3pidBindServlet,
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 90f96209f8..e83adc8339 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -88,9 +88,13 @@ class BaseHandler(object):
 current_state = yield self.store.get_events(
 context.current_state_ids.values()
 )
- current_state = current_state.values()
 else:
- current_state = yield self.store.get_current_state(event.room_id)
+ current_state = yield self.state_handler.get_current_state(
+ event.room_id
+ )
+
+ current_state = current_state.values()
+
 logger.info("maybe_kick_guest_users %r", current_state)
 yield self.kick_guest_users(current_state)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3b146f09d6..fffba34383 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ class AuthHandler(BaseHandler):
 self.hs = hs # FIXME better possibility to access registrationHandler later?
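The servlet and transport method above together expose device lists over federation. Assuming the standard federation prefix (/_matrix/federation/v1), a response might look roughly like this, per the docstring above (values illustrative):

    # GET {server}/_matrix/federation/v1/user/devices/{user_id} returns,
    # roughly, a snapshot of the user's devices plus a stream position:
    example_response = {
        "user_id": "@alice:remote.example",   # hypothetical remote user
        "stream_id": 12,
        "devices": [
            {"device_id": "ABCDEFGHIJ"},      # per-device key payloads elided
        ],
    }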
self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -529,37 +530,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): - access_token = self.generate_access_token(user_id) + access_token = self.macaroon_gen.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) - def generate_access_token(self, user_id, extra_caveats=None): - extra_caveats = extra_caveats or [] - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = access") - # Include a nonce, to make sure that each login gets a different - # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - for caveat in extra_caveats: - macaroon.add_first_party_caveat(caveat) - return macaroon.serialize() - - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = login") - now = self.hs.get_clock().time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - return macaroon.serialize() - - def generate_delete_pusher_token(self, user_id): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = delete_pusher") - return macaroon.serialize() - def validate_short_term_login_token_and_get_user_id(self, login_token): auth_api = self.hs.get_auth() try: @@ -570,15 +545,6 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - def _generate_base_macaroon(self, user_id): - macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, - identifier="key", - key=self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - return macaroon - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) @@ -607,7 +573,7 @@ class AuthHandler(BaseHandler): # types (mediums) of threepid. For now, we still use the existing # infrastructure, but this is the start of synapse gaining knowledge # of specific types of threepid (and fixes the fact that checking - # for the presenc eof an email address during password reset was + # for the presence of an email address during password reset was # case sensitive). if medium == 'email': address = address.lower() @@ -617,6 +583,17 @@ class AuthHandler(BaseHandler): self.hs.get_clock().time_msec() ) + @defer.inlineCallbacks + def delete_threepid(self, user_id, medium, address): + # 'Canonicalise' email addresses as per above + if medium == 'email': + address = address.lower() + + ret = yield self.store.user_delete_threepid( + user_id, medium, address, + ) + defer.returnValue(ret) + def _save_session(self, session): # TODO: Persistent storage logger.debug("Saving session %s", session) @@ -656,12 +633,54 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash (bool). 
""" if stored_hash: - return bcrypt.hashpw(password + self.hs.config.password_pepper, - stored_hash.encode('utf-8')) == stored_hash + return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, + stored_hash.encode('utf8')) == stored_hash else: return False +class MacaroonGeneartor(object): + def __init__(self, hs): + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + self.macaroon_secret_key = hs.config.macaroon_secret_key + + def generate_access_token(self, user_id, extra_caveats=None): + extra_caveats = extra_caveats or [] + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = access") + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) + for caveat in extra_caveats: + macaroon.add_first_party_caveat(caveat) + return macaroon.serialize() + + def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = login") + now = self.clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location=self.server_name, + identifier="key", + key=self.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + + class _AccountHandler(object): """A proxy object that gets passed to password auth providers so they can register new users etc if necessary. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index aa68755936..8cb47ac417 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,7 +14,11 @@ # limitations under the License. 
from synapse.api import errors
+from synapse.api.constants import EventTypes
 from synapse.util import stringutils
+from synapse.util.async import Linearizer
+from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id, RoomStreamToken
 from twisted.internet import defer
 from ._base import BaseHandler
@@ -27,6 +31,21 @@ class DeviceHandler(BaseHandler):
 def __init__(self, hs):
 super(DeviceHandler, self).__init__(hs)
+ self.hs = hs
+ self.state = hs.get_state_handler()
+ self.federation_sender = hs.get_federation_sender()
+ self.federation = hs.get_replication_layer()
+ self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+
+ self.federation.register_edu_handler(
+ "m.device_list_update", self._incoming_device_list_update,
+ )
+ self.federation.register_query_handler(
+ "user_devices", self.on_federation_query_user_devices,
+ )
+
+ hs.get_distributor().observe("user_left_room", self.user_left_room)
+
 @defer.inlineCallbacks
 def check_device_registered(self, user_id, device_id,
 initial_device_display_name=None):
@@ -45,29 +64,29 @@ class DeviceHandler(BaseHandler):
 str: device id (generated if none was supplied)
 """
 if device_id is not None:
- yield self.store.store_device(
+ new_device = yield self.store.store_device(
 user_id=user_id,
 device_id=device_id,
 initial_device_display_name=initial_device_display_name,
- ignore_if_known=True,
 )
+ if new_device:
+ yield self.notify_device_update(user_id, [device_id])
 defer.returnValue(device_id)
 # if the device id is not specified, we'll autogen one, but loop a few
 # times in case of a clash.
 attempts = 0
 while attempts < 5:
- try:
- device_id = stringutils.random_string(10).upper()
- yield self.store.store_device(
- user_id=user_id,
- device_id=device_id,
- initial_device_display_name=initial_device_display_name,
- ignore_if_known=False,
- )
+ device_id = stringutils.random_string(10).upper()
+ new_device = yield self.store.store_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=initial_device_display_name,
+ )
+ if new_device:
+ yield self.notify_device_update(user_id, [device_id])
 defer.returnValue(device_id)
- except errors.StoreError:
- attempts += 1
+ attempts += 1
 raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -147,6 +166,8 @@ class DeviceHandler(BaseHandler):
 user_id=user_id, device_id=device_id
 )
+ yield self.notify_device_update(user_id, [device_id])
+
 @defer.inlineCallbacks
 def update_device(self, user_id, device_id, content):
 """ Update the given device
@@ -166,12 +187,168 @@ class DeviceHandler(BaseHandler):
 device_id,
 new_display_name=content.get("display_name")
 )
+ yield self.notify_device_update(user_id, [device_id])
 except errors.StoreError, e:
 if e.code == 404:
 raise errors.NotFoundError()
 else:
 raise
+ @measure_func("notify_device_update")
+ @defer.inlineCallbacks
+ def notify_device_update(self, user_id, device_ids):
+ """Notify that a user's device(s) has changed. Pokes the notifier, and
+ remote servers if the user is local.
+ """ + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + + hosts = set() + if self.hs.is_mine_id(user_id): + hosts.update(get_domain_from_id(u) for u in users_who_share_room) + hosts.discard(self.server_name) + + position = yield self.store.add_device_change_to_streams( + user_id, device_ids, list(hosts) + ) + + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + + yield self.notifier.on_new_event( + "device_list_key", position, rooms=room_ids, + ) + + if hosts: + logger.info("Sending device list update notif to: %r", hosts) + for host in hosts: + self.federation_sender.send_device_messages(host) + + @measure_func("device.get_user_ids_changed") + @defer.inlineCallbacks + def get_user_ids_changed(self, user_id, from_token): + """Get list of users that have had the devices updated, or have newly + joined a room, that `user_id` may be interested in. + + Args: + user_id (str) + from_token (StreamToken) + """ + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(r.room_id for r in rooms) + + # First we check if any devices have changed + changed = yield self.store.get_user_whose_devices_changed( + from_token.device_list_key + ) + + # Then work out if any users have since joined + rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + + possibly_changed = set(changed) + for room_id in rooms_changed: + # Fetch the current state at the time. + stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key) + + try: + event_ids = yield self.store.get_forward_extremeties_for_room( + room_id, stream_ordering=stream_ordering + ) + prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + except: + prev_state_ids = {} + + current_state_ids = yield self.state.get_current_state_ids(room_id) + + # If there has been any change in membership, include them in the + # possibly changed list. We'll check if they are joined below, + # and we're not toooo worried about spuriously adding users. + for key, event_id in current_state_ids.iteritems(): + etype, state_key = key + if etype == EventTypes.Member: + prev_event_id = prev_state_ids.get(key, None) + if not prev_event_id or prev_event_id != event_id: + possibly_changed.add(state_key) + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + + # Take the intersection of the users whose devices may have changed + # and those that actually still share a room with the user + defer.returnValue(users_who_share_room & possibly_changed) + + @measure_func("_incoming_device_list_update") + @defer.inlineCallbacks + def _incoming_device_list_update(self, origin, edu_content): + user_id = edu_content["user_id"] + device_id = edu_content["device_id"] + stream_id = edu_content["stream_id"] + prev_ids = edu_content.get("prev_id", []) + + if get_domain_from_id(user_id) != origin: + # TODO: Raise? + logger.warning("Got device list update edu for %r from %r", user_id, origin) + return + + rooms = yield self.store.get_rooms_for_user(user_id) + if not rooms: + # We don't share any rooms with this user. Ignore update, as we + # probably won't get any further updates. + return + + with (yield self._remote_edue_linearizer.queue(user_id)): + # If the prev id matches whats in our cache table, then we don't need + # to resync the users device list, otherwise we do. 
+ resync = True
+ if len(prev_ids) == 1:
+ extremity = yield self.store.get_device_list_last_stream_id_for_remote(
+ user_id
+ )
+ logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
+ if str(extremity) == str(prev_ids[0]):
+ resync = False
+
+ if resync:
+ # Fetch all devices for the user.
+ result = yield self.federation.query_user_devices(origin, user_id)
+ stream_id = result["stream_id"]
+ devices = result["devices"]
+ yield self.store.update_remote_device_list_cache(
+ user_id, devices, stream_id,
+ )
+ device_ids = [device["device_id"] for device in devices]
+ yield self.notify_device_update(user_id, device_ids)
+ else:
+ # Simply update the single device, since we know that is the only
+ # change (because of the single prev_id matching the current cache)
+ content = dict(edu_content)
+ for key in ("user_id", "device_id", "stream_id", "prev_ids"):
+ content.pop(key, None)
+ yield self.store.update_remote_device_list_cache_entry(
+ user_id, device_id, content, stream_id,
+ )
+ yield self.notify_device_update(user_id, [device_id])
+
+ @defer.inlineCallbacks
+ def on_federation_query_user_devices(self, user_id):
+ stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+ defer.returnValue({
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ })
+
+ @defer.inlineCallbacks
+ def user_left_room(self, user, room_id):
+ user_id = user.to_string()
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ if not rooms:
+ # We no longer share rooms with this user, so we'll no longer
+ # receive device updates. Mark this in DB.
+ yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
 def _update_device_from_client_ips(device, client_ips):
 ip = client_ips.get((device["user_id"], device["device_id"]), {})
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index b63a660c06..e40495d1ab 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,10 +73,9 @@ class E2eKeysHandler(object):
 if self.is_mine_id(user_id):
 local_query[user_id] = device_ids
 else:
- domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_ids
+ remote_queries[user_id] = device_ids
- # do the queries
+ # First get local devices.
 failures = {}
 results = {}
 if local_query:
@@ -85,9 +84,42 @@
 if user_id in local_query:
 results[user_id] = keys
+ # Now attempt to get any remote devices from our local cache.
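A conceptual sketch of the remote device-list tracking lifecycle that _incoming_device_list_update and user_left_room implement above; the class below is a stand-in, not Synapse's storage layer:

    class RemoteDeviceCache(object):
        def __init__(self):
            self.tracked = {}  # user_id -> last stream_id seen

        def on_update(self, user_id, stream_id, shares_room_with_us):
            if not shares_room_with_us:
                # No shared rooms: we won't hear about later changes, so
                # tracking this user would only go stale.
                return
            self.tracked[user_id] = stream_id

        def on_left_all_shared_rooms(self, user_id):
            # Mirrors user_left_room above: forget the user so any future
            # update forces a full resync.
            self.tracked.pop(user_id, None)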
+ remote_queries_not_in_cache = {} + if remote_queries: + query_list = [] + for user_id, device_ids in remote_queries.iteritems(): + if device_ids: + query_list.extend((user_id, device_id) for device_id in device_ids) + else: + query_list.append((user_id, None)) + + user_ids_not_in_cache, remote_results = ( + yield self.store.get_user_devices_from_cache( + query_list + ) + ) + for user_id, devices in remote_results.iteritems(): + user_devices = results.setdefault(user_id, {}) + for device_id, device in devices.iteritems(): + keys = device.get("keys", None) + device_display_name = device.get("device_display_name", None) + if keys: + result = dict(keys) + unsigned = result.setdefault("unsigned", {}) + if device_display_name: + unsigned["device_display_name"] = device_display_name + user_devices[device_id] = result + + for user_id in user_ids_not_in_cache: + domain = get_domain_from_id(user_id) + r = remote_queries_not_in_cache.setdefault(domain, {}) + r[user_id] = remote_queries[user_id] + + # Now fetch any devices that we don't have in our cache @defer.inlineCallbacks def do_remote_query(destination): - destination_query = remote_queries[destination] + destination_query = remote_queries_not_in_cache[destination] try: limiter = yield get_retry_limiter( destination, self.clock, self.store @@ -119,7 +151,7 @@ class E2eKeysHandler(object): yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) - for destination in remote_queries + for destination in remote_queries_not_in_cache ])) defer.returnValue({ @@ -162,7 +194,7 @@ class E2eKeysHandler(object): # "unsigned" section for user_id, device_keys in results.items(): for device_id, device_info in device_keys.items(): - r = json.loads(device_info["key_json"]) + r = dict(device_info["keys"]) r["unsigned"] = {} display_name = device_info["device_display_name"] if display_name is not None: @@ -255,10 +287,12 @@ class E2eKeysHandler(object): device_id, user_id, time_now ) # TODO: Sign the JSON with the server key - yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, - encode_canonical_json(device_keys) + changed = yield self.store.set_e2e_device_keys( + user_id, device_id, time_now, device_keys, ) + if changed: + # Only notify about device updates *if* the keys actually changed + yield self.device_handler.notify_device_update(user_id, [device_id]) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1021bcc405..996bfd0e23 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -591,12 +591,12 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) - logger.info("calling resolve_state_groups in _maybe_backfill") + logger.debug("calling resolve_state_groups in _maybe_backfill") states = yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids ])) - states = dict(zip(event_ids, [s[1] for s in states])) + states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( [e_id for ids in states.values() for e_id in ids], @@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - current_state=state, ) defer.returnValue((event_stream_id, max_stream_id)) @@ -1530,7 +1529,7 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d 
in different_events if d
 })
- new_state, prev_state = self.state_handler.resolve_events(
+ new_state = self.state_handler.resolve_events(
 [local_view.values(), remote_view.values()],
 event
 )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a57a69bd3..7a498af5a2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -208,8 +208,10 @@ class MessageHandler(BaseHandler):
 content = builder.content
 try:
- content["displayname"] = yield profile.get_displayname(target)
- content["avatar_url"] = yield profile.get_avatar_url(target)
+ if "displayname" not in content:
+ content["displayname"] = yield profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = yield profile.get_avatar_url(target)
 except Exception as e:
 logger.info(
 "Failed to get profile information for %r: %s",
@@ -279,7 +281,9 @@ class MessageHandler(BaseHandler):
 if event.type == EventTypes.Message:
 presence = self.hs.get_presence_handler()
- yield presence.bump_presence_active_time(user)
+ # We don't want to block sending messages on any presence code. This
+ # matters as sometimes presence code can take a while.
+ preserve_fn(presence.bump_presence_active_time)(user)
 @defer.inlineCallbacks
 def deduplicate_state_event(self, event, context):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 1b89dc6274..fdfce2a88c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -574,7 +574,7 @@ class PresenceHandler(object):
 if not local_states:
 continue
- users = yield self.state.get_current_user_in_room(room_id)
+ users = yield self.store.get_users_in_room(room_id)
 hosts = set(get_domain_from_id(u) for u in users)
 for host in hosts:
@@ -766,7 +766,7 @@ class PresenceHandler(object):
 # don't need to send to local clients here, as that is done as part
 # of the event stream/sync.
 # TODO: Only send to servers not already in the room.
- user_ids = yield self.state.get_current_user_in_room(room_id)
+ user_ids = yield self.store.get_users_in_room(room_id)
 if self.is_mine(user):
 state = yield self.current_state_for_user(user.to_string())
@@ -1011,7 +1011,7 @@ class PresenceEventSource(object):
 @defer.inlineCallbacks
 @log_function
 def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
- **kwargs):
+ explicit_room_id=None, **kwargs):
 # The process for getting presence events is:
 # 1. Get the rooms the user is in.
 # 2. Get the list of users in the rooms.
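The rewritten presence source (continued below) replaces per-room membership walks with one precomputed set of users whose presence the caller is interested in. Schematically, with illustrative inputs:

    def users_interested_in(user_id, presence_list, users_who_share_room):
        # Presence-list contacts, plus everyone we share a room with,
        # plus ourselves (so we see our own presence).
        interested = set(presence_list)
        interested.update(users_who_share_room)
        interested.add(user_id)
        return interested

    # users_interested_in("@me:hs", {"@friend:hs"}, {"@roommate:hs"})
    # -> {"@me:hs", "@friend:hs", "@roommate:hs"}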
@@ -1028,22 +1028,24 @@ class PresenceEventSource(object): user_id = user.to_string() if from_key is not None: from_key = int(from_key) - room_ids = room_ids or [] presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - if not room_ids: - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = set(e.room_id for e in rooms) - else: - room_ids = set(room_ids) - max_token = self.store.get_current_presence_token() plist = yield self.store.get_presence_list_accepted(user.localpart) - friends = set(row["observed_user_id"] for row in plist) - friends.add(user_id) # So that we receive our own presence + users_interested_in = set(row["observed_user_id"] for row in plist) + users_interested_in.add(user_id) # So that we receive our own presence + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + users_interested_in.update(users_who_share_room) + + if explicit_room_id: + user_ids = yield self.store.get_users_in_room(explicit_room_id) + users_interested_in.update(user_ids) user_ids_changed = set() changed = None @@ -1055,35 +1057,19 @@ class PresenceEventSource(object): # work out if we share a room or they're in our presence list get_updates_counter.inc("stream") for other_user_id in changed: - if other_user_id in friends: + if other_user_id in users_interested_in: user_ids_changed.add(other_user_id) - continue - other_rooms = yield self.store.get_rooms_for_user(other_user_id) - if room_ids.intersection(e.room_id for e in other_rooms): - user_ids_changed.add(other_user_id) - continue else: # Too many possible updates. Find all users we can see and check # if any of them have changed. get_updates_counter.inc("full") - user_ids_to_check = set() - for room_id in room_ids: - users = yield self.state.get_current_user_in_room(room_id) - user_ids_to_check.update(users) - - user_ids_to_check.update(friends) - - # Always include yourself. Only really matters for when the user is - # not in any rooms, but still. 
- user_ids_to_check.add(user_id) - if from_key: user_ids_changed = stream_change_cache.get_entities_changed( - user_ids_to_check, from_key, + users_interested_in, from_key, ) else: - user_ids_changed = user_ids_to_check + user_ids_changed = users_interested_in updates = yield presence.current_state_for_users(user_ids_changed) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 286f0cef0a..03c6a85fc6 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id = None + self.macaroon_gen = hs.get_macaroon_generator() + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): @@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler): token = None if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) if need_register: yield self.store.register( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5f18007e90..7e7671c9a2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -437,6 +437,7 @@ class RoomEventSource(object): limit, room_ids, is_guest, + explicit_room_id=None, ): # We just ignore the key for now. diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 667223df0c..19eebbd43f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -62,17 +62,18 @@ class RoomListHandler(BaseHandler): appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. """ - if search_filter or (network_tuple and network_tuple.appservice_id is not None): + if search_filter: # We explicitly don't bother caching searches or requests for # appservice specific lists. 
return self._get_public_room_list( limit, since_token, search_filter, network_tuple=network_tuple, ) - result = self.response_cache.get((limit, since_token)) + key = (limit, since_token, network_tuple) + result = self.response_cache.get(key) if not result: result = self.response_cache.set( - (limit, since_token), + key, self._get_public_room_list( limit, since_token, network_tuple=network_tuple ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2f8782e522..b2806555cf 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -45,7 +45,7 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) - self.member_linearizer = Linearizer() + self.member_linearizer = Linearizer(name="member") self.clock = hs.get_clock() @@ -89,7 +89,7 @@ class RoomMemberHandler(BaseHandler): duplicate = yield msg_handler.deduplicate_state_event(event, context) if duplicate is not None: # Discard the new event since this membership change is a no-op. - return + defer.returnValue(duplicate) yield msg_handler.handle_new_client_event( requester, @@ -120,6 +120,8 @@ class RoomMemberHandler(BaseHandler): if prev_member_event.membership == Membership.JOIN: user_left_room(self.distributor, target, room_id) + defer.returnValue(event) + @defer.inlineCallbacks def remote_join(self, remote_room_hosts, room_id, user, content): if len(remote_room_hosts) == 0: @@ -187,6 +189,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=True, content=None, ): + content_specified = bool(content) if content is None: content = {} @@ -229,6 +232,13 @@ class RoomMemberHandler(BaseHandler): errcode=Codes.BAD_STATE ) + if old_state: + same_content = content == old_state.content + same_membership = old_membership == effective_membership_state + same_sender = requester.user.to_string() == old_state.sender + if same_sender and same_membership and same_content: + defer.returnValue(old_state) + is_host_in_room = yield self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: @@ -247,8 +257,9 @@ class RoomMemberHandler(BaseHandler): content["membership"] = Membership.JOIN profile = self.hs.get_handlers().profile_handler - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + if not content_specified: + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) if requester.is_guest: content["kind"] = "guest" @@ -290,7 +301,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue({}) - yield self._local_membership_update( + res = yield self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -300,6 +311,7 @@ class RoomMemberHandler(BaseHandler): prev_event_ids=latest_event_ids, content=content, ) + defer.returnValue(res) @defer.inlineCallbacks def send_membership_event( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c880f61685..d7dcd1ce5b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,7 +16,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure +from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user from 
synapse.visibility import filter_events_for_client
@@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
 "invited", # InvitedSyncResult for each invited room.
 "archived", # ArchivedSyncResult for each archived room.
 "to_device", # List of direct messages for the device.
+ "device_lists", # List of user_ids whose devices have changed
 ])):
 __slots__ = []
@@ -129,7 +130,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
 self.invited or
 self.archived or
 self.account_data or
- self.to_device
+ self.to_device or
+ self.device_lists
 )
@@ -544,6 +546,10 @@ class SyncHandler(object):
 yield self._generate_sync_entry_for_to_device(sync_result_builder)
+ device_lists = yield self._generate_sync_entry_for_device_list(
+ sync_result_builder
+ )
+
 defer.returnValue(SyncResult(
 presence=sync_result_builder.presence,
 account_data=sync_result_builder.account_data,
@@ -551,9 +557,33 @@ class SyncHandler(object):
 invited=sync_result_builder.invited,
 archived=sync_result_builder.archived,
 to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
 next_batch=sync_result_builder.now_token,
 ))
+ @measure_func("_generate_sync_entry_for_device_list")
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_device_list(self, sync_result_builder):
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+
+ if since_token and since_token.device_list_key:
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ room_ids = set(r.room_id for r in rooms)
+
+ user_ids_changed = set()
+ changed = yield self.store.get_user_whose_devices_changed(
+ since_token.device_list_key
+ )
+ for other_user_id in changed:
+ other_rooms = yield self.store.get_rooms_for_user(other_user_id)
+ if room_ids.intersection(e.room_id for e in other_rooms):
+ user_ids_changed.add(other_user_id)
+
+ defer.returnValue(user_ids_changed)
+ else:
+ defer.returnValue([])
+
 @defer.inlineCallbacks
 def _generate_sync_entry_for_to_device(self, sync_result_builder):
 """Generates the portion of the sync response.
Populates diff --git a/synapse/http/client.py b/synapse/http/client.py index 3ec9bc7faf..ca2f770f5d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -25,7 +25,7 @@ from synapse.http.endpoint import SpiderEndpoint from canonicaljson import encode_canonical_json from twisted.internet import defer, reactor, ssl, protocol, task -from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web.client import ( BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, readBody, PartialDownloadError, @@ -386,26 +386,23 @@ class SpiderEndpointFactory(object): def endpointForURI(self, uri): logger.info("Getting endpoint for %s", uri.toBytes()) + if uri.scheme == "http": - return SpiderEndpoint( - reactor, uri.host, uri.port, self.blacklist, self.whitelist, - endpoint=TCP4ClientEndpoint, - endpoint_kw_args={ - 'timeout': 15 - }, - ) + endpoint_factory = HostnameEndpoint elif uri.scheme == "https": - tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) - return SpiderEndpoint( - reactor, uri.host, uri.port, self.blacklist, self.whitelist, - endpoint=SSL4ClientEndpoint, - endpoint_kw_args={ - 'sslContextFactory': tlsPolicy, - 'timeout': 15 - }, - ) + tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) + + def endpoint_factory(reactor, host, port, **kw): + return wrapClientTLS( + tlsCreator, + HostnameEndpoint(reactor, host, port, **kw)) else: logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) + return None + return SpiderEndpoint( + reactor, uri.host, uri.port, self.blacklist, self.whitelist, + endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15), + ) class SpiderHttpClient(SimpleHttpClient): diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 8c64339a7c..d8923c9abb 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer, reactor from twisted.internet.error import ConnectError from twisted.names import client, dns @@ -58,11 +58,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, endpoint_kw_args.update(timeout=timeout) if ssl_context_factory is None: - transport_endpoint = TCP4ClientEndpoint + transport_endpoint = HostnameEndpoint default_port = 8008 else: - transport_endpoint = SSL4ClientEndpoint - endpoint_kw_args.update(sslContextFactory=ssl_context_factory) + def transport_endpoint(reactor, host, port, timeout): + return wrapClientTLS( + ssl_context_factory, + HostnameEndpoint(reactor, host, port, timeout=timeout)) default_port = 8448 if port is None: @@ -142,7 +144,7 @@ class SpiderEndpoint(object): Implements twisted.internet.interfaces.IStreamClientEndpoint. 
""" def __init__(self, reactor, host, port, blacklist, whitelist, - endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): + endpoint=HostnameEndpoint, endpoint_kw_args={}): self.reactor = reactor self.host = host self.port = port @@ -180,7 +182,7 @@ class SRVClientEndpoint(object): """ def __init__(self, reactor, service, domain, protocol="tcp", - default_port=None, endpoint=TCP4ClientEndpoint, + default_port=None, endpoint=HostnameEndpoint, endpoint_kw_args={}): self.reactor = reactor self.service_name = "_%s._%s.%s" % (service, protocol, domain) diff --git a/synapse/notifier.py b/synapse/notifier.py index acbd4bb5ae..8051a7a842 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -378,6 +378,7 @@ class Notifier(object): limit=limit, is_guest=is_peeking, room_ids=room_ids, + explicit_room_id=explicit_room_id, ) if name == "room": diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 53551632b6..62d794f22b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -81,7 +81,7 @@ class Mailer(object): def __init__(self, hs, app_name): self.hs = hs self.store = self.hs.get_datastore() - self.auth_handler = self.hs.get_auth_handler() + self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = app_name @@ -439,15 +439,23 @@ class Mailer(object): }) def make_room_link(self, room_id): - # need /beta for Universal Links to work on iOS - if self.app_name == "Vector": - return "https://vector.im/beta/#/room/%s" % (room_id,) + if self.hs.config.email_riot_base_url: + base_url = self.hs.config.email_riot_base_url + elif self.app_name == "Vector": + # need /beta for Universal Links to work on iOS + base_url = "https://vector.im/beta/#/room" else: - return "https://matrix.to/#/%s" % (room_id,) + base_url = "https://matrix.to/#" + return "%s/%s" % (base_url, room_id) def make_notif_link(self, notif): - # need /beta for Universal Links to work on iOS - if self.app_name == "Vector": + if self.hs.config.email_riot_base_url: + return "%s/#/room/%s/%s" % ( + self.hs.config.email_riot_base_url, + notif['room_id'], notif['event_id'] + ) + elif self.app_name == "Vector": + # need /beta for Universal Links to work on iOS return "https://vector.im/beta/#/room/%s/%s" % ( notif['room_id'], notif['event_id'] ) @@ -458,7 +466,7 @@ class Mailer(object): def make_unsubscribe_link(self, user_id, app_id, email_address): params = { - "access_token": self.auth_handler.generate_delete_pusher_token(user_id), + "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "app_id": app_id, "pushkey": email_address, } diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index b47bf1f92b..a27476bbad 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -52,7 +52,7 @@ def get_badge_count(store, user_id): def get_context_for_event(store, state_handler, ev, user_id): ctx = {} - room_state_ids = yield state_handler.get_current_state_ids(ev.room_id) + room_state_ids = yield store.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 3742a25b37..7817b0cd91 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -24,7 +24,7 @@ REQUIREMENTS = { "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl==0.3.0": 
["nacl==0.3.0", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], - "Twisted>=15.1.0": ["twisted>=15.1.0"], + "Twisted>=16.0.0": ["twisted>=16.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 4616e9b34a..d8eb14592b 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -46,6 +46,7 @@ STREAM_NAMES = ( ("to_device",), ("public_rooms",), ("federation",), + ("device_lists",), ) @@ -140,6 +141,7 @@ class ReplicationResource(Resource): caches_token = self.store.get_cache_stream_token() public_rooms_token = self.store.get_current_public_room_stream_id() federation_token = self.federation_sender.get_current_token() + device_list_token = self.store.get_device_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -155,6 +157,7 @@ class ReplicationResource(Resource): int(stream_token.to_device_key), int(public_rooms_token), int(federation_token), + int(device_list_token), )) @request_handler() @@ -214,6 +217,7 @@ class ReplicationResource(Resource): yield self.caches(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams) + yield self.device_lists(writer, current_token, limit, request_streams) self.federation(writer, current_token, limit, request_streams, federation_ack) self.streams(writer, current_token, request_streams) @@ -295,9 +299,6 @@ class ReplicationResource(Resource): "backward_ex_outliers", res.backward_ex_outliers, ("position", "event_id", "state_group"), ) - writer.write_header_and_rows( - "state_resets", res.state_resets, ("position",), - ) @defer.inlineCallbacks def presence(self, writer, current_token, request_streams): @@ -495,6 +496,20 @@ class ReplicationResource(Resource): "position", "type", "content", ), position=upto_token) + @defer.inlineCallbacks + def device_lists(self, writer, current_token, limit, request_streams): + current_position = current_token.device_lists + + device_lists = request_streams.get("device_lists") + + if device_lists is not None and device_lists != current_position: + changes = yield self.store.get_all_device_list_changes_for_remotes( + device_lists, + ) + writer.write_header_and_rows("device_lists", changes, ( + "position", "user_id", "destination", + ), position=current_position) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -527,7 +542,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", - "federation", + "federation", "device_lists", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py new file mode 100644 index 0000000000..ca46aa17b6 --- /dev/null +++ b/synapse/replication/slave/storage/devices.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class SlavedDeviceStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedDeviceStore, self).__init__(db_conn, hs) + + self.hs = hs + + self._device_list_id_gen = SlavedIdTracker( + db_conn, "device_lists_stream", "stream_id", + ) + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + + get_device_stream_token = DataStore.get_device_stream_token.__func__ + get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__ + get_devices_by_remote = DataStore.get_devices_by_remote.__func__ + _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__ + _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__ + mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__ + _mark_as_sent_devices_by_remote_txn = ( + DataStore._mark_as_sent_devices_by_remote_txn.__func__ + ) + + def stream_positions(self): + result = super(SlavedDeviceStore, self).stream_positions() + result["device_lists"] = self._device_list_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("device_lists") + if stream: + self._device_list_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + stream_id = row[0] + user_id = row[1] + destination = row[2] + + self._device_list_stream_cache.entity_has_changed( + user_id, stream_id + ) + + if destination: + self._device_list_federation_stream_cache.entity_has_changed( + destination, stream_id + ) + + return super(SlavedDeviceStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 64f18bbb3e..d72ff6055c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -73,12 +73,12 @@ class SlavedEventStore(BaseSlavedStore): # to reach inside the __dict__ to extract them. 
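# [Editor's illustration, not part of the commit: why the slave stores read
# these methods out of the class __dict__ rather than as attributes. Attribute
# lookup on the class fires the descriptor protocol and hands back a plain
# function, whereas __dict__ access returns the caching descriptor itself,
# which can then be re-attached to the slave class. A minimal sketch, with a
# hypothetical `cached_function` standing in for synapse's @cached decorator:]
class cached_function(object):
    """A toy caching descriptor."""
    def __init__(self, func):
        self.func = func
        self.cache = {}

    def __get__(self, obj, objtype=None):
        # Normal attribute access lands here and hides the wrapper object.
        def invoke(*args):
            if args not in self.cache:
                self.cache[args] = self.func(obj, *args)
            return self.cache[args]
        return invoke


class MasterStore(object):
    @cached_function
    def get_users_in_room(self, room_id):
        return ["@alice:example.com"]


# Only __dict__ access yields the raw descriptor (and hence its cache):
assert isinstance(MasterStore.__dict__["get_users_in_room"], cached_function)
assert MasterStore().get_users_in_room("!room:example.com") == ["@alice:example.com"]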
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_users_who_share_room_with_user = ( + RoomMemberStore.__dict__["get_users_who_share_room_with_user"] + ) get_latest_event_ids_in_room = EventFederationStore.__dict__[ "get_latest_event_ids_in_room" ] - _get_current_state_for_key = StateStore.__dict__[ - "_get_current_state_for_key" - ] get_invited_rooms_for_user = RoomMemberStore.__dict__[ "get_invited_rooms_for_user" ] @@ -115,8 +115,6 @@ class SlavedEventStore(BaseSlavedStore): ) get_event = DataStore.get_event.__func__ get_events = DataStore.get_events.__func__ - get_current_state = DataStore.get_current_state.__func__ - get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) @@ -197,10 +195,6 @@ class SlavedEventStore(BaseSlavedStore): return result def process_replication(self, result): - state_resets = set( - r[0] for r in result.get("state_resets", {"rows": []})["rows"] - ) - stream = result.get("events") if stream: self._stream_id_gen.advance(int(stream["position"])) @@ -210,7 +204,7 @@ class SlavedEventStore(BaseSlavedStore): for row in stream["rows"]: self._process_replication_row( - row, backfilled=False, state_resets=state_resets + row, backfilled=False, ) stream = result.get("backfill") @@ -218,7 +212,7 @@ class SlavedEventStore(BaseSlavedStore): self._backfill_id_gen.advance(-int(stream["position"])) for row in stream["rows"]: self._process_replication_row( - row, backfilled=True, state_resets=state_resets + row, backfilled=True, ) stream = result.get("forward_ex_outliers") @@ -237,21 +231,15 @@ class SlavedEventStore(BaseSlavedStore): return super(SlavedEventStore, self).process_replication(result) - def _process_replication_row(self, row, backfilled, state_resets): - position = row[0] + def _process_replication_row(self, row, backfilled): internal = json.loads(row[1]) event_json = json.loads(row[2]) event = FrozenEvent(event_json, internal_metadata_dict=internal) self.invalidate_caches_for_event( - event, backfilled, reset_state=position in state_resets + event, backfilled, ) - def invalidate_caches_for_event(self, event, backfilled, reset_state): - if reset_state: - self._get_current_state_for_key.invalidate_all() - self.get_rooms_for_user.invalidate_all() - self.get_users_in_room.invalidate((event.room_id,)) - + def invalidate_caches_for_event(self, event, backfilled): self._invalidate_get_event_cache(event.event_id) self.get_latest_event_ids_in_room.invalidate((event.room_id,)) @@ -273,8 +261,6 @@ class SlavedEventStore(BaseSlavedStore): self._invalidate_get_event_cache(event.redacts) if event.type == EventTypes.Member: - self.get_rooms_for_user.invalidate((event.state_key,)) - self.get_users_in_room.invalidate((event.room_id,)) self._membership_stream_cache.entity_has_changed( event.state_key, event.internal_metadata.stream_ordering ) @@ -289,7 +275,3 @@ class SlavedEventStore(BaseSlavedStore): if (not event.internal_metadata.is_invite_from_remote() and event.internal_metadata.is_outlier()): return - - self._get_current_state_for_key.invalidate(( - event.room_id, event.type, event.state_key - )) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 351170edbc..efa77b8c51 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -86,7 +86,11 @@ class HttpTransactionCache(object): pass # 
execute the function instead. deferred = fn(*args, **kwargs) - observable = ObservableDeferred(deferred) + + # We don't add an errback to the raw deferred, so we ask ObservableDeferred + # to swallow the error. This is fine as the error will still be reported + # to the observers. + observable = ObservableDeferred(deferred, consumeErrors=True) self.transactions[txn_key] = (observable, self.clock.time_msec()) return observable.observe() diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 093bc072f4..72057f1b0c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_password_login(self, login_submission): if 'medium' in login_submission and 'address' in login_submission: + address = login_submission['address'] + if login_submission['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addresses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + address = address.lower() user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - login_submission['medium'], login_submission['address'] + login_submission['medium'], address ) if not user_id: raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -324,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_required_attributes = hs.config.cas_required_attributes self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_GET(self, request): @@ -362,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet): yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(registered_user_id) + login_token = self.macaroon_gen.generate_short_term_login_token( + registered_user_id + ) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index eead435bfd..2ebf5e59a0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -152,23 +152,29 @@ class RoomStateEventRestServlet(ClientV1RestServlet): if state_key is not None: event_dict["state_key"] = state_key - msg_handler = self.handlers.message_handler - event, context = yield msg_handler.create_event( - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id, - ) - if event_type == EventTypes.Member: - yield self.handlers.room_member_handler.send_membership_event( + membership = content.get("membership", None) + event = yield self.handlers.room_member_handler.update_membership( requester, - event, - context, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, ) else: + msg_handler = self.handlers.message_handler + event, context = yield msg_handler.create_event( + event_dict, + token_id=requester.access_token_id, + txn_id=txn_id, + ) + yield msg_handler.send_nonmember_event(requester, event, context) - defer.returnValue((200, {"event_id": event.event_id})) + ret = {} + if event: + ret = {"event_id": event.event_id} + defer.returnValue((200, ret)) # TODO: Needs unit testing for generic events + feedback diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index c40442f958..03141c623c 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@
-32,18 +32,26 @@ class VoipRestServlet(ClientV1RestServlet): turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret + turnUsername = self.hs.config.turn_username + turnPassword = self.hs.config.turn_password userLifetime = self.hs.config.turn_user_lifetime - if not turnUris or not turnSecret or not userLifetime: - defer.returnValue((200, {})) - expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 - username = "%d:%s" % (expiry, requester.user.to_string()) + if turnUris and turnSecret and userLifetime: + expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 + username = "%d:%s" % (expiry, requester.user.to_string()) + + mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) + # We need to use standard padded base64 encoding here instead of + # encode_base64 because we need to add the standard padding to get the + # same result as the TURN server. + password = base64.b64encode(mac.digest()) - mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) - # We need to use standard padded base64 encoding here - # encode_base64 because we need to add the standard padding to get the - # same result as the TURN server. - password = base64.b64encode(mac.digest()) + elif turnUris and turnUsername and turnPassword and userLifetime: + username = turnUsername + password = turnPassword + + else: + defer.returnValue((200, {})) defer.returnValue((200, { 'username': username, diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index eb49ad62e9..398e7f5eb0 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet): threepid = result[LoginType.EMAIL_IDENTITY] if 'medium' not in threepid or 'address' not in threepid: raise SynapseError(500, "Malformed threepid") + if threepid['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addresses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + threepid['address'] = threepid['address'].lower() # if using email, we must know about the email they're authing with!
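# [Editor's sketch, not part of the commit: the lowercasing above only works if
# registration (add_threepid) and every later lookup apply the identical
# normalisation, so a hypothetical shared helper makes the invariant explicit.
# Only email is treated as case-insensitive; other media pass through as-is.]
def normalise_threepid_address(medium, address):
    # Mirrors the inline transformation used at registration and at login.
    if medium == 'email':
        return address.lower()
    return address


assert normalise_threepid_address('email', 'Bob@Example.ORG') == 'bob@example.org'
assert normalise_threepid_address('msisdn', '447700900123') == '447700900123'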
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( threepid['medium'], threepid['address'] @@ -241,7 +246,7 @@ class ThreepidRestServlet(RestServlet): for reqd in ['medium', 'address', 'validated_at']: if reqd not in threepid: - logger.warn("Couldn't add 3pid: invalid response from ID sevrer") + logger.warn("Couldn't add 3pid: invalid response from ID server") raise SynapseError(500, "Invalid response from ID Server") yield self.auth_handler.add_threepid( @@ -263,9 +268,43 @@ class ThreepidRestServlet(RestServlet): defer.returnValue((200, {})) +class ThreepidDeleteRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=()) + + def __init__(self, hs): + super(ThreepidDeleteRestServlet, self).__init__() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + yield run_on_reactor() + + body = parse_json_object_from_request(request) + + required = ['medium', 'address'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + yield self.auth_handler.delete_threepid( + user_id, body['medium'], body['address'] + ) + + defer.returnValue((200, {})) + + def register_servlets(hs, http_server): PasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) ThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) + ThreepidDeleteRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 46789775b9..6a3cfe84f8 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -21,6 +21,8 @@ from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, parse_integer ) +from synapse.http.servlet import parse_string +from synapse.types import StreamToken from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -149,6 +151,52 @@ class KeyQueryServlet(RestServlet): defer.returnValue((200, result)) +class KeyChangesServlet(RestServlet): + """Returns the list of changes of keys between two stream tokens (may return + spurious extra results, since we currently ignore the `to` param). + + GET /keys/changes?from=...&to=... + + 200 OK + { "changed": ["@foo:example.com"] } + """ + PATTERNS = client_v2_patterns( + "/keys/changes$", + releases=() + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + super(KeyChangesServlet, self).__init__() + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + from_token_string = parse_string(request, "from") + + # We want to enforce they do pass us one, but we ignore it and return + # changes after the "to" as well as before. 
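# [Editor's note: the parse_string() call below reads the "to" parameter and
# immediately discards the value, so the response is an over-approximation
# that may include user_ids whose devices changed only after "to"; as the
# docstring above warns, clients must tolerate such spurious entries.]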
+ parse_string(request, "to") + + from_token = StreamToken.from_string(from_token_string) + + user_id = requester.user.to_string() + + changed = yield self.device_handler.get_user_ids_changed( + user_id, from_token, + ) + + defer.returnValue((200, { + "changed": list(changed), + })) + + class OneTimeKeyServlet(RestServlet): """ POST /keys/claim HTTP/1.1 @@ -192,4 +240,5 @@ class OneTimeKeyServlet(RestServlet): def register_servlets(hs, http_server): KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) + KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3e7a285e10..ccca5a12d5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet): self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_POST(self, request): @@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet): user_id, device_id, initial_display_name ) - access_token = self.auth_handler.generate_access_token( + access_token = self.macaroon_gen.generate_access_token( user_id, ["guest = true"] ) defer.returnValue((200, { diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 7199ec883a..b3d8001638 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id, filter.event_fields + sync_result.archived, time_now, requester.access_token_id, + filter.event_fields, ) response_content = { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists), + }, "presence": self.encode_presence( sync_result.presence, time_now ), diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 692e078419..3cbeca503c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -61,7 +61,7 @@ class MediaRepository(object): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements - self.remote_media_linearizer = Linearizer() + self.remote_media_linearizer = Linearizer(name="media_remote") self.recently_accessed_remotes = set() @@ -98,6 +98,8 @@ class MediaRepository(object): with open(fname, "wb") as f: f.write(content) + logger.info("Stored local media in file %r", fname) + yield self.store.store_local_media( media_id=media_id, media_type=media_type, @@ -190,6 +192,8 @@ class MediaRepository(object): else: upload_name = None + logger.info("Stored remote media in file %r", fname) + yield self.store.store_cached_remote_media( origin=server_name, media_id=media_id, diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 0bb3676844..3868d4f65f 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -16,6 +16,10 @@ import PIL.Image as Image from io import BytesIO +import logging + +logger = logging.getLogger(__name__) + class Thumbnailer(object): @@ 
-86,4 +90,5 @@ class Thumbnailer(object): output_bytes = output_bytes_io.getvalue() with open(output_path, "wb") as output_file: output_file.write(output_bytes) + logger.info("Stored thumbnail in file %r", output_path) return len(output_bytes) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index b716d1d892..4ab33f73bf 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -97,6 +97,8 @@ class UploadResource(Resource): content_length, requester.user ) + logger.info("Uploaded content with URI %r", content_uri) + respond_with_json( request, 200, {"content_uri": content_uri}, send_cors=True ) diff --git a/synapse/server.py b/synapse/server.py index 0bfb411269..c577032041 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler +from synapse.handlers.auth import AuthHandler, MacaroonGeneartor from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler @@ -131,6 +131,7 @@ class HomeServer(object): 'federation_transport_client', 'federation_sender', 'receipts_handler', + 'macaroon_generator', ] def __init__(self, hostname, **kwargs): @@ -213,6 +214,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_macaroon_generator(self): + return MacaroonGeneartor(self) + def build_device_handler(self): return DeviceHandler(self) diff --git a/synapse/state.py b/synapse/state.py index 8003099c88..383d32b163 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,12 +16,12 @@ from twisted.internet import defer +from synapse import event_auth from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer @@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 _NEXT_STATE_ID = 1 +POWER_KEY = (EventTypes.PowerLevels, "") + def _gen_state_id(): global _NEXT_STATE_ID @@ -77,6 +79,9 @@ class _StateCacheEntry(object): else: self.state_id = _gen_state_id() + def __len__(self): + return len(self.state) + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -89,7 +94,7 @@ class StateHandler(object): # dict of set of event_ids -> _StateCacheEntry. 
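# [Editor's note, not part of the commit: the new _StateCacheEntry.__len__
# pairs with the iterable=True flag passed to the ExpiringCache below. The
# assumption is that an iterable cache charges each entry len(value) against
# max_len instead of 1, which is why SIZE_OF_CACHE above grows from
# 1000 * CACHE_SIZE_FACTOR entries to 100000 * CACHE_SIZE_FACTOR state keys.
# A rough sketch of that accounting, under that assumption:]
def cache_weight(values, iterable):
    # `values` are the currently cached values; the result is compared
    # against max_len when deciding whether to evict.
    if iterable:
        return sum(len(v) for v in values)
    return len(values)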
self._state_cache = None - self.resolve_linearizer = Linearizer() + self.resolve_linearizer = Linearizer(name="state_resolve_lock") def start_caching(self): logger.debug("start_caching") @@ -99,6 +104,7 @@ class StateHandler(object): clock=self.clock, max_len=SIZE_OF_CACHE, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, reset_expiry_on_get=True, ) @@ -123,7 +129,7 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_state") + logger.debug("calling resolve_state_groups from get_current_state") ret = yield self.resolve_state_groups(room_id, latest_event_ids) state = ret.state @@ -148,7 +154,7 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_state_ids") + logger.debug("calling resolve_state_groups from get_current_state_ids") ret = yield self.resolve_state_groups(room_id, latest_event_ids) state = ret.state @@ -162,7 +168,7 @@ class StateHandler(object): def get_current_user_in_room(self, room_id, latest_event_ids=None): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_user_in_room") + logger.debug("calling resolve_state_groups from get_current_user_in_room") entry = yield self.resolve_state_groups(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state( room_id, entry.state_id, entry.state @@ -226,7 +232,7 @@ class StateHandler(object): context.prev_state_events = [] defer.returnValue(context) - logger.info("calling resolve_state_groups from compute_event_context") + logger.debug("calling resolve_state_groups from compute_event_context") if event.is_state(): entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], @@ -327,20 +333,13 @@ class StateHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) - state_map = yield self.store.get_events( - [e_id for st in state_groups_ids.values() for e_id in st.values()], - get_prev_content=False - ) - state_sets = [ - [state_map[e_id] for key, e_id in st.items() if e_id in state_map] - for st in state_groups_ids.values() - ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) - new_state = { - key: e.event_id for key, e in new_state.items() - } + with Measure(self.clock, "state._resolve_events"): + new_state = yield resolve_events( + state_groups_ids.values(), + state_map_factory=lambda ev_ids: self.store.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) else: new_state = { key: e_ids.pop() for key, e_ids in state.items() @@ -388,152 +387,267 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - if event.is_state(): - return self._resolve_events( - state_sets, event.type, event.state_key - ) - else: - return self._resolve_events(state_sets) + state_set_ids = [{ + (ev.type, ev.state_key): ev.event_id + for ev in st + } for st in state_sets] + + state_map = { + ev.event_id: ev + for st in state_sets + for ev in st + } - def _resolve_events(self, state_sets, event_type=None, state_key=""): - """ - Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). 
new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. - """ with Measure(self.clock, "state._resolve_events"): - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + new_state = resolve_events(state_set_ids, state_map) - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + new_state = { + key: state_map[ev_id] for key, ev_id in new_state.items() + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + return new_state - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - new_state = unconflicted_state - new_state.update(resolved_state) + return sorted(events, key=key_func) - return new_state, prev_states - @log_function - def _resolve_state_events(self, conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. - - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = self._resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = self._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = self._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = self._resolve_normal_events( - events, auth_events - ) - - return resolved_state - - def _resolve_auth_events(self, events, auth_events): - reverse = [i for i in reversed(self._ordered_events(events))] - - auth_events = dict(auth_events) - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. - # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - def _resolve_normal_events(self, events, auth_events): - for event in self._ordered_events(events): - try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. 
- # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event - - def _ordered_events(self, events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - - return sorted(events, key=key_func) +def resolve_events(state_sets, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + state_map_factory(dict|callable): If callable, then will be called + with a list of event_ids that are needed, and should return + a Deferred of a dict of event_id to event. Otherwise, should be + a dict from event_id to event of all events in state_sets. + + Returns + dict[(str, str), str]: a map from (type, state_key) to event_id + (or a Deferred of one, if state_map_factory is a callable). + """ + if len(state_sets) == 1: + return state_sets[0] + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + + if callable(state_map_factory): + return _resolve_with_state_fac( + unconflicted_state, conflicted_state, state_map_factory + ) + + state_map = state_map_factory + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + return _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't, i.e. which have multiple different event_ids associated + with them in different state sets. + """ + unconflicted_state = dict(state_sets[0]) + conflicted_state = {} + + for state_set in state_sets[1:]: + for key, value in state_set.iteritems(): + # Check if there is an unconflicted entry for the state key. + unconflicted_value = unconflicted_state.get(key) + if unconflicted_value is None: + # There isn't an unconflicted entry so check if there is a + # conflicted entry. + ls = conflicted_state.get(key) + if ls is None: + # There wasn't a conflicted entry so haven't seen this key before. + # Therefore it isn't conflicted yet. + unconflicted_state[key] = value + else: + # This key is already conflicted, add our value to the conflict set. + ls.add(value) + elif unconflicted_value != value: + # If the unconflicted value is not the same as our value then we + # have a new conflict. So move the key from the unconflicted_state + # to the conflicted state.
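# [Editor's worked example, not part of the commit: given
#     state_sets = [{("m.room.name", ""): "$A"},
#                   {("m.room.name", ""): "$B"},
#                   {("m.room.name", ""): "$B"}]
# the first set seeds unconflicted_state with "$A"; the second set's "$B"
# disagrees, so the key moves into conflicted_state as {"$A", "$B"} via the
# assignment below; the third set then finds the existing conflict set and
# ls.add("$B") is a no-op. The function returns
#     ({}, {("m.room.name", ""): {"$A", "$B"}}).]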
+ conflicted_state[key] = {value, unconflicted_value} + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +@defer.inlineCallbacks +def _resolve_with_state_fac(unconflicted_state, conflicted_state, + state_map_factory): + needed_events = set( + event_id + for event_ids in conflicted_state.itervalues() + for event_id in event_ids + ) + + logger.info("Asking for %d conflicted events", len(needed_events)) + + state_map = yield state_map_factory(needed_events) + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(auth_events.itervalues()) + new_needed_events -= needed_events + + logger.info("Asking for %d auth events", len(new_needed_events)) + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in conflicted_state.itervalues(): + for event_id in event_ids: + if event_id in state_map: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids, + state_map): + conflicted_state = {} + for key, event_ids in conflicted_state_ds.iteritems(): + events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] + if len(events) > 1: + conflicted_state[key] = events + elif len(events) == 1: + unconflicted_state_ids[key] = events[0].event_id + + auth_events = { + key: state_map[ev_id] + for key, ev_id in auth_event_ids.items() + if ev_id in state_map + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state_ids + for key, event in resolved_state.iteritems(): + new_state[key] = event.event_id + + return new_state + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. 
+ """ + resolved_state = {} + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[POWER_KEY] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + prev_event = event + except AuthError: + return prev_event + + return event + + +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. 
+ return event diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index fe936b3e62..b9968debe5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, "device_lists_stream", "stream_id", + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") @@ -189,7 +192,8 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", - max_value=max_device_inbox_id + max_value=max_device_inbox_id, + limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, @@ -202,12 +206,21 @@ class DataStore(RoomMemberStore, RoomStore, entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, + limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b62c459d8b..05374682fd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -169,7 +169,7 @@ class SQLBaseStore(object): max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) self._event_fetch_lock = threading.Condition() @@ -387,6 +387,10 @@ class SQLBaseStore(object): Args: table : string giving the table name values : dict of new column names and values for them + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True """ try: yield self.runInteraction( @@ -398,6 +402,8 @@ class SQLBaseStore(object): # a cursor after we receive an error from the db. if not or_ignore: raise + defer.returnValue(False) + defer.returnValue(True) @staticmethod def _simple_insert_txn(txn, table, values): @@ -838,18 +844,19 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) def _get_cache_dict(self, db_conn, table, entity_column, stream_column, - max_value): + max_value, limit=100000): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will # do the right thing to ensure it respects the max size of cache. sql = ( "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - 100000" + " WHERE %(stream)s > ? 
- %(limit)s" " GROUP BY %(entity)s" ) % { "table": table, "entity": entity_column, "stream": stream_column, + "limit": limit, } sql = self.database_engine.convert_param_style(sql) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 2821eb89c9..bde3b5cbbc 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -18,13 +18,29 @@ import ujson from twisted.internet import defer -from ._base import SQLBaseStore +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) -class DeviceInboxStore(SQLBaseStore): +class DeviceInboxStore(BackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, hs): + super(DeviceInboxStore, self).__init__(hs) + + self.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, + self._background_drop_index_device_inbox, + ) @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, @@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_inbox_stream_id" + ) + txn.close() + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + defer.returnValue(1) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 17920d4480..8e17800364 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import ujson as json from twisted.internet import defer @@ -23,27 +24,29 @@ logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): + def __init__(self, hs): + super(DeviceStore, self).__init__(hs) + + self._clock.looping_call( + self._prune_old_outbound_device_pokes, 60 * 60 * 1000 + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, - initial_device_display_name, - ignore_if_known=True): + initial_device_display_name): """Ensure the given device is known; add it to the store if not Args: user_id (str): id of user associated with the device device_id (str): id of device initial_device_display_name (str): initial displayname of the - device - ignore_if_known (bool): ignore integrity errors which mean the - device is already known + device. Ignored if device exists. Returns: - defer.Deferred - Raises: - StoreError: if ignore_if_known is False and the device was already - known + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. 
""" try: - yield self._simple_insert( + inserted = yield self._simple_insert( "devices", values={ "user_id": user_id, @@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore): "display_name": initial_device_display_name }, desc="store_device", - or_ignore=ignore_if_known, + or_ignore=True, ) + defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" " display_name=%s(%r) failed: %s", @@ -139,3 +143,451 @@ class DeviceStore(SQLBaseStore): ) defer.returnValue({d["device_id"]: d for d in devices}) + + def get_device_list_last_stream_id_for_remote(self, user_id): + """Get the last stream_id we got for a user. May be None if we haven't + got any information for them. + """ + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_remote_extremity", + allow_none=True, + ) + + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + return self._simple_delete( + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + desc="mark_remote_user_device_list_as_unsubscribed", + ) + + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + """Updates a single user's device in the cache. + """ + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, + ) + + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + """Replace the cache of the remote user's devices. + """ + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def get_devices_by_remote(self, destination, from_stream_id): + """Get stream of updates to send to remote servers + + Returns: + (now_stream_id, [ { updates }, .. 
]) + """ + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + return (now_stream_id, []) + + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) + + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + GROUP BY user_id, device_id + """ + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) + ) + rows = txn.fetchall() + + if not rows: + return (now_stream_id, []) + + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in rows} + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True + ) + + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + + results = [] + for user_id, user_devices in devices.iteritems(): + # The prev_id for the first row is always the last row before + # `from_stream_id` + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, device in user_devices.iteritems(): + stream_id = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_id] if prev_id else [], + "stream_id": stream_id, + } + + prev_id = stream_id + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return (now_stream_id, results) + + def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. + + Args: + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. 
+ + Returns: + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info + """ + return self.runInteraction( + "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, + query_list, + ) + + def _get_user_devices_from_cache_txn(self, txn, query_list): + user_ids = {user_id for user_id, _ in query_list} + + user_ids_in_cache = set() + for user_id in user_ids: + stream_ids = self._simple_select_onecol_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + retcol="stream_id", + ) + if stream_ids: + user_ids_in_cache.add(user_id) + + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + content = self._simple_select_one_onecol_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + ) + results.setdefault(user_id, {})[device_id] = json.loads(content) + else: + devices = self._simple_select_list_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + ) + results[user_id] = { + device["device_id"]: json.loads(device["content"]) + for device in devices + } + user_ids_in_cache.discard(user_id) + + return user_ids_not_in_cache, results + + def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.iteritems(): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, + ) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # First we DELETE all rows such that only the latest row for each + # (destination, user_id) pair is left. We do this by selecting first and + # deleting. + sql = """ + SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + GROUP BY user_id + HAVING count(*) > 1 + """ + txn.execute(sql, (destination, stream_id,)) + rows = txn.fetchall() + + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND stream_id < ? + """ + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows) + ) + + # Mark everything that is left as sent + sql = """ + UPDATE device_lists_outbound_pokes SET sent = ?
+ WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (True, destination, stream_id,)) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + """Get set of users whose devices have changed since `from_key`. + """ + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row[0] for row in rows)) + + def get_all_device_list_changes_for_remotes(self, from_key): + """Return a list of `(stream_id, user_id, destination)` which is the + combined list of changes to devices, and which destinations need to be + poked. `destination` may be None if no destinations need to be poked. + """ + sql = """ + SELECT stream_id, user_id, destination FROM device_lists_stream + LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + WHERE stream_id > ? + """ + return self._execute( + "get_users_and_hosts_device_list", None, + sql, from_key, + ) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_ids, hosts): + """Persist that a user's devices have been updated, and which hosts + (if any) should be poked. + """ + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_device_change_to_streams", self._add_device_change_txn, + user_id, device_ids, hosts, stream_id, + ) + defer.returnValue(stream_id) + + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): + now = self._clock.time_msec() + + txn.call_after( + self._device_list_stream_cache.entity_has_changed, + user_id, stream_id, + ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, stream_id, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_stream", + values=[ + { + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + for device_id in device_ids + ] + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + "sent": False, + "ts": now, + } + for destination in hosts + for device_id in device_ids + ] + ) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() + + def _prune_old_outbound_device_pokes(self): + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. We keep one entry per + (destination, user_id) tuple to ensure that the prev_ids remain correct + if the server does come back. + """ + yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + + def _prune_txn(txn): + select_sql = """ + SELECT destination, user_id, max(stream_id) as stream_id + FROM device_lists_outbound_pokes + GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 + """ + + txn.execute(select_sql, (yesterday,)) + rows = txn.fetchall() + + if not rows: + return + + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? 
+ """ + + txn.executemany( + delete_sql, + ( + (yesterday, row[0], row[1], row[2]) + for row in rows + ) + ) + + logger.info("Pruned %d device list outbound pokes", txn.rowcount) + + return self.runInteraction( + "_prune_old_outbound_device_pokes", _prune_txn + ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 385d607056..2040e022fa 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,74 +12,111 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections +from twisted.internet import defer -import twisted.internet.defer +from canonicaljson import encode_canonical_json +import ujson as json from ._base import SQLBaseStore class EndToEndKeyStore(SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): - return self._simple_upsert( - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": json_bytes, - } + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + def _set_e2e_device_keys_txn(txn): + old_key_json = self._simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="key_json", + allow_none=True, + ) + + new_key_json = encode_canonical_json(device_keys) + if old_key_json == new_key_json: + return False + + self._simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": new_key_json, + } + ) + + return True + + return self.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn ) - def get_e2e_device_keys(self, query_list): + @defer.inlineCallbacks + def get_e2e_device_keys(self, query_list, include_all_devices=False): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. + include_all_devices (bool): whether to include entries for devices + that don't have device keys Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". """ if not query_list: - return {} + defer.returnValue({}) - return self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + results = yield self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, + query_list, include_all_devices, ) - def _get_e2e_device_keys_txn(self, txn, query_list): + for user_id, device_keys in results.iteritems(): + for device_id, device_info in device_keys.iteritems(): + device_info["keys"] = json.loads(device_info.pop("key_json")) + + defer.returnValue(results) + + def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): query_clauses = [] query_params = [] for (user_id, device_id) in query_list: - query_clause = "k.user_id = ?" + query_clause = "user_id = ?" query_params.append(user_id) if device_id: - query_clause += " AND k.device_id = ?" + query_clause += " AND device_id = ?" 
query_params.append(device_id) query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " + "SELECT user_id, device_id, " " d.display_name AS device_display_name, " " k.key_json" - " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" " WHERE %s" ) % ( + "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses) ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) - result = collections.defaultdict(dict) + result = {} for row in rows: - result[row["user_id"]][row["device_id"]] = row + result.setdefault(row["user_id"], {})[row["device_id"]] = row return result @@ -152,7 +189,7 @@ class EndToEndKeyStore(SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @twisted.internet.defer.inlineCallbacks + @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): yield self._simple_delete( table="e2e_device_keys_json", diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 53feaa1960..ee88c61954 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore): room_id, ) - @cached() + @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): return self._simple_select_onecol( table="event_forward_extremities", @@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore): ], ) - self._update_extremeties(txn, events) + self._update_backward_extremeties(txn, events) - def _update_extremeties(self, txn, events): - """Updates the event_*_extremities tables based on the new/updated + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated events being persisted. This is called for new events *and* for events that were outliers, but - are are now being persisted as non-outliers. + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. """ events_by_room = {} for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) - for room_id, room_events in events_by_room.items(): - prevs = [ - e_id for ev in room_events for e_id, _ in ev.prev_events - if not ev.internal_metadata.is_outlier() - ] - if prevs: - txn.execute( - "DELETE FROM event_forward_extremities" - " WHERE room_id = ?" - " AND event_id in (%s)" % ( - ",".join(["?"] * len(prevs)), - ), - [room_id] + prevs, - ) - - query = ( - "INSERT INTO event_forward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_edges WHERE prev_event_id = ?" - " )" - ) - - txn.executemany( - query, - [ - (ev.event_id, ev.room_id, ev.event_id) for ev in events - if not ev.internal_metadata.is_outlier() - ] - ) - - # We now insert into stream_ordering_to_exterm a mapping from room_id, - # new stream_ordering to new forward extremeties in the room. 
- # This allows us to later efficiently look up the forward extremeties - # for a room before a given stream_ordering - max_stream_ord = max( - ev.internal_metadata.stream_ordering for ev in events - ) - new_extrem = {} - for room_id in events_by_room: - event_ids = self._simple_select_onecol_txn( - txn, - table="event_forward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - new_extrem[room_id] = event_ids - - self._simple_insert_many_txn( - txn, - table="stream_ordering_to_exterm", - values=[ - { - "room_id": room_id, - "event_id": event_id, - "stream_ordering": max_stream_ord, - } - for room_id, extrem_evs in new_extrem.items() - for event_id in extrem_evs - ] - ) - query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" @@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore): ] ) - for room_id in events_by_room: - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) - ) - def get_forward_extremeties_for_room(self, room_id, stream_ordering): # We want to make the cache more effective, so we clamp to the last # change before the given ordering. diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 04dbdac3f8..8659f605a5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, _RollbackButIsFineException +from ._base import SQLBaseStore from twisted.internet import defer, reactor @@ -27,6 +27,8 @@ from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError +from synapse.state import resolve_events +from synapse.util.caches.descriptors import cached from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict @@ -71,22 +73,19 @@ class _EventPeristenceQueue(object): """ _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( - "events_and_contexts", "current_state", "backfilled", "deferred", + "events_and_contexts", "backfilled", "deferred", )) def __init__(self): self._event_persist_queues = {} self._currently_persisting_rooms = set() - def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): + def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. 
""" queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: end_item = queue[-1] - if end_item.current_state or current_state: - # We perist events with current_state set to True one at a time - pass if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) return end_item.deferred.observe() @@ -96,7 +95,6 @@ class _EventPeristenceQueue(object): queue.append(self._EventPersistQueueItem( events_and_contexts=events_and_contexts, backfilled=backfilled, - current_state=current_state, deferred=deferred, )) @@ -216,7 +214,6 @@ class EventsStore(SQLBaseStore): d = preserve_fn(self._event_persist_queue.add_to_queue)( room_id, evs_ctxs, backfilled=backfilled, - current_state=None, ) deferreds.append(d) @@ -229,11 +226,10 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function - def persist_event(self, event, context, current_state=None, backfilled=False): + def persist_event(self, event, context, backfilled=False): deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled, - current_state=current_state, ) self._maybe_start_persisting(event.room_id) @@ -246,21 +242,10 @@ class EventsStore(SQLBaseStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - if item.current_state: - for event, context in item.events_and_contexts: - # There should only ever be one item in - # events_and_contexts when current_state is - # not None - yield self._persist_event( - event, context, - current_state=item.current_state, - backfilled=item.backfilled, - ) - else: - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @@ -294,35 +279,183 @@ class EventsStore(SQLBaseStore): for chunk in chunks: # We can't easily parallelize these since different chunks # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + new_forward_extremeties = {} + current_state_for_room = {} + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in events_by_room.items(): + # Work out new extremities by recursively adding and removing + # the new events. 
+ latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremeties( + room_id, [ev for ev, _ in ev_ctx_rm] + ) + + if new_latest_event_ids == set(latest_event_ids): + # No change in extremities, so no change in state + continue + + new_forward_extremeties[room_id] = new_latest_event_ids + + state = yield self._calculate_state_delta( + room_id, ev_ctx_rm, new_latest_event_ids + ) + if state: + current_state_for_room[room_id] = state + yield self.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, + current_state_for_room=current_state_for_room, + new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) - @_retry_on_integrity_error @defer.inlineCallbacks - @log_function - def _persist_event(self, event, context, current_state=None, backfilled=False, - delete_existing=False): - try: - with self._stream_id_gen.get_next() as stream_ordering: - event.internal_metadata.stream_ordering = stream_ordering - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - current_state=current_state, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc() - except _RollbackButIsFineException: - pass + def _calculate_new_extremeties(self, room_id, events): + """Calculates the new forward extremeties for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = set(latest_event_ids) + # First, add all the new events to the list + new_latest_event_ids.update( + event.event_id for event in events + if not event.internal_metadata.is_outlier() + ) + # Now remove all events that are referenced by the to-be-added events + new_latest_event_ids.difference_update( + e_id + for event in events + for e_id, _ in event.prev_events + if not event.internal_metadata.is_outlier() + ) + + # And finally remove any events that are referenced by previously added + # events. + rows = yield self._simple_select_many_batch( + table="event_edges", + column="prev_event_id", + iterable=list(new_latest_event_ids), + retcols=["prev_event_id"], + keyvalues={ + "room_id": room_id, + "is_state": False, + }, + desc="_calculate_new_extremeties", + ) + + new_latest_event_ids.difference_update( + row["prev_event_id"] for row in rows + ) + + defer.returnValue(new_latest_event_ids) + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, i.e. + (type, state_key) -> event_id. `to_delete` are the entries to + first be deleted from current_state_events, `to_insert` are entries + to insert. + May return None if there are no changes to be applied. 
+ """ + # Now we need to work out the different state sets for + # each state extremities + state_sets = [] + missing_event_ids = [] + was_updated = False + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding, + # and then use the current state from that + for ev, ctx in events_context: + if event_id == ev.event_id: + if ctx.current_state_ids is None: + raise Exception("Unknown current state") + state_sets.append(ctx.current_state_ids) + if ctx.delta_ids or hasattr(ev, "state_key"): + was_updated = True + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + was_updated = True + missing_event_ids.append(event_id) + + if missing_event_ids: + # Now pull out the state for any missing events from DB + event_to_groups = yield self._get_state_group_for_events( + missing_event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) + + state_sets.extend(group_to_state.values()) + + if not new_latest_event_ids: + current_state = {} + elif was_updated: + current_state = yield resolve_events( + state_sets, + state_map_factory=lambda ev_ids: self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) + else: + return + + existing_state_rows = yield self._simple_select_list( + table="current_state_events", + keyvalues={"room_id": room_id}, + retcols=["event_id", "type", "state_key"], + desc="_calculate_state_delta", + ) + + existing_events = set(row["event_id"] for row in existing_state_rows) + new_events = set(ev_id for ev_id in current_state.itervalues()) + changed_events = existing_events ^ new_events + + if not changed_events: + return + + to_delete = { + (row["type"], row["state_key"]): row["event_id"] + for row in existing_state_rows + if row["event_id"] in changed_events + } + events_to_insert = (new_events - existing_events) + to_insert = { + key: ev_id for key, ev_id in current_state.iteritems() + if ev_id in events_to_insert + } + + defer.returnValue((to_delete, to_insert)) @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, @@ -381,52 +514,9 @@ class EventsStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @log_function - def _persist_event_txn(self, txn, event, context, current_state, backfilled=False, - delete_existing=False): - # We purposefully do this first since if we include a `current_state` - # key, we *want* to update the `current_state_events` table - if current_state: - txn.call_after(self._get_current_state_for_key.invalidate_all) - txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - stream_order = event.internal_metadata.stream_ordering - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": stream_order} - ) - - self._simple_delete_txn( - txn, - table="current_state_events", - keyvalues={"room_id": event.room_id}, - ) - - for s in current_state: - self._simple_insert_txn( - txn, - "current_state_events", - { - "event_id": s.event_id, - "room_id": s.room_id, - "type": s.type, - "state_key": s.state_key, - } - ) - - return self._persist_events_txn( - txn, - [(event, context)], - backfilled=backfilled, - delete_existing=delete_existing, - ) - - @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - 
delete_existing=False): + delete_existing=False, current_state_for_room={}, + new_forward_extremeties={}): """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -436,6 +526,93 @@ class EventsStore(SQLBaseStore): If delete_existing is True then existing events will be purged from the database before insertion. This is useful when retrying due to IntegrityError. """ + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + for room_id, current_state_tuple in current_state_for_room.iteritems(): + to_delete, to_insert = current_state_tuple + txn.executemany( + "DELETE FROM current_state_events WHERE event_id = ?", + [(ev_id,) for ev_id in to_delete.itervalues()], + ) + + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in to_insert.iteritems() + ], + ) + + # Invalidate the various caches + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. + members_changed = set( + state_key for ev_type, state_key in to_delete.iterkeys() + if ev_type == EventTypes.Member + ) + members_changed.update( + state_key for ev_type, state_key in to_insert.iterkeys() + if ev_type == EventTypes.Member + ) + + for member in members_changed: + self._invalidate_cache_and_stream( + txn, self.get_rooms_for_user, (member,) + ) + + self._invalidate_cache_and_stream( + txn, self.get_users_in_room, (room_id,) + ) + + for room_id, new_extrem in new_forward_extremeties.items(): + self._simple_delete_txn( + txn, + table="event_forward_extremities", + keyvalues={"room_id": room_id}, + ) + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, (room_id,) + ) + + self._simple_insert_many_txn( + txn, + table="event_forward_extremities", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + } + for room_id, new_extrem in new_forward_extremeties.items() + for ev_id in new_extrem + ], + ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + self._simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_order, + } + for room_id, new_extrem in new_forward_extremeties.items() + for event_id in new_extrem + ] + ) + # Ensure that we don't have the same event twice. # Pick the earliest non-outlier if there is one, else the earliest one. new_events_and_contexts = OrderedDict() @@ -550,7 +727,7 @@ class EventsStore(SQLBaseStore): # Update the event_backward_extremities table now that this # event isn't an outlier any more. - self._update_extremeties(txn, [event]) + self._update_backward_extremeties(txn, [event]) events_and_contexts = [ ec for ec in events_and_contexts if ec[0] not in to_remove @@ -798,29 +975,6 @@ class EventsStore(SQLBaseStore): # to update the current state table return - for event, _ in state_events_and_contexts: - if event.internal_metadata.is_outlier(): - # Outlier events shouldn't clobber the current state. 
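
# For contrast with the per-event upserts being deleted in this hunk:
# the new batched path computes one (to_delete, to_insert) pair per room
# and applies it in a single pass. In miniature (a sketch only;
# `current_state` maps (type, state_key) -> event_id, per
# _calculate_state_delta's docstring):
def apply_state_delta(current_state, to_delete, to_insert):
    for key, ev_id in to_delete.items():
        # only drop entries that still point at the superseded event
        if current_state.get(key) == ev_id:
            del current_state[key]
    current_state.update(to_insert)
    return current_state
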
- continue - - txn.call_after( - self._get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) - - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) - return def _add_to_cache(self, txn, events_and_contexts): @@ -1084,10 +1238,10 @@ class EventsStore(SQLBaseStore): self._do_fetch ) - logger.info("Loading %d events", len(events)) + logger.debug("Loading %d events", len(events)) with PreserveLoggingContext(): rows = yield events_d - logger.info("Loaded %d events (%d rows)", len(events), len(rows)) + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] @@ -1418,6 +1572,7 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): """Get all the new events that have arrived at the server either as @@ -1450,15 +1605,6 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT event_stream_ordering FROM current_state_resets" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering ASC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - state_resets = txn.fetchall() - - sql = ( "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" @@ -1469,7 +1615,6 @@ class EventsStore(SQLBaseStore): forward_ex_outliers = txn.fetchall() else: new_forward_events = [] - state_resets = [] forward_ex_outliers = [] sql = ( @@ -1509,7 +1654,6 @@ class EventsStore(SQLBaseStore): return AllNewEventsResult( new_forward_events, new_backfill_events, forward_ex_outliers, backward_ex_outliers, - state_resets, ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) @@ -1735,5 +1879,4 @@ class EventsStore(SQLBaseStore): AllNewEventsResult = namedtuple("AllNewEventsResult", [ "new_forward_events", "new_backfill_events", "forward_ex_outliers", "backward_ex_outliers", - "state_resets" ]) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e46ae6502e..b357f22be7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. 
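
# In outline, the version bump below is consumed like this on startup
# (a hedged sketch, not the real prepare_database logic; `conn` is
# assumed to be a sqlite3-style connection, and the single-row
# schema_version table is an assumption for illustration):
import os

def apply_deltas(conn, db_version, target_version, delta_root):
    # apply every delta/<version>/*.sql between the database's recorded
    # version and the new SCHEMA_VERSION, in order, then advance it
    for version in range(db_version + 1, target_version + 1):
        delta_dir = os.path.join(delta_root, str(version))
        for name in sorted(os.listdir(delta_dir)):
            if name.endswith(".sql"):
                with open(os.path.join(delta_dir, name)) as f:
                    conn.executescript(f.read())
        conn.execute("UPDATE schema_version SET version = ?", (version,))
    conn.commit()
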
-SCHEMA_VERSION = 39 +SCHEMA_VERSION = 40 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 983a8ec52b..26be6060c3 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -413,6 +413,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): desc="user_delete_threepids", ) + def user_delete_threepid(self, user_id, medium, address): + return self._simple_delete( + "user_threepids", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + }, + desc="user_delete_threepids", + ) + @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7c..545d3d3a99 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering @@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - @cached(max_entries=5000) + @cached(max_entries=500000, iterable=True) def get_users_in_room(self, room_id): def f(txn): @@ -220,7 +218,7 @@ class RoomMemberStore(SQLBaseStore): " ON e.event_id = c.event_id" " AND m.room_id = c.room_id" " AND m.user_id = c.state_key" - " WHERE %s" + " WHERE c.type = 'm.room.member' AND %s" ) % (where_clause,) txn.execute(sql, args) @@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore): " ON m.event_id = c.event_id " " AND m.room_id = c.room_id " " AND m.user_id = c.state_key" - " WHERE %(where)s" + " WHERE c.type = 'm.room.member' AND %(where)s" ) % { "where": where_clause, } @@ -276,12 +274,29 @@ class RoomMemberStore(SQLBaseStore): return rows - @cached(max_entries=5000) + @cached(max_entries=500000, iterable=True) def get_rooms_for_user(self, user_id): return self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + rooms = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate, + ) + + user_who_share_room = set() + for room in rooms: + user_ids = yield self.get_users_in_room( + room.room_id, on_invalidate=cache_context.invalidate, + ) + user_who_share_room.update(user_ids) + + defer.returnValue(user_who_share_room) + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -390,7 +405,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql 
b/synapse/storage/schema/delta/40/current_state_idx.sql new file mode 100644 index 0000000000..7ffa189f39 --- /dev/null +++ b/synapse/storage/schema/delta/40/current_state_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('current_state_members_idx', '{}'); diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql new file mode 100644 index 0000000000..b9fe1f0480 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_inbox.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- turn the pre-fill startup query into a index-only scan on postgresql. +INSERT into background_updates (update_name, progress_json) + VALUES ('device_inbox_stream_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..54841b3843 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -0,0 +1,59 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Cache of remote devices. +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); + + +-- The last update we got for a user. Empty if we're not receiving updates for +-- that user. 
+CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + +-- Stream of device lists updates. Includes both local and remotes +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +-- The stream of updates to send to other servers. We keep at least one row +-- per user that was sent so that the prev_id for any new updates can be +-- calculated +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL, + ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40ac..1b3800eb6a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -49,6 +49,7 @@ class StateStore(SQLBaseStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" def __init__(self, hs): super(StateStore, self).__init__(hs) @@ -60,6 +61,13 @@ class StateStore(SQLBaseStore): self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state, ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): @@ -232,59 +240,7 @@ class StateStore(SQLBaseStore): return count - @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type and state_key is not None: - result = yield self.get_current_state_for_key( - room_id, event_type, state_key - ) - defer.returnValue(result) - - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? " - ) - - if event_type and state_key is not None: - sql += " AND type = ? AND state_key = ? " - args = (room_id, event_type, state_key) - elif event_type: - sql += " AND type = ?" - args = (room_id, event_type) - else: - args = (room_id, ) - - txn.execute(sql, args) - results = txn.fetchall() - - return [r[0] for r in results] - - event_ids = yield self.runInteraction("get_current_state", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @defer.inlineCallbacks - def get_current_state_for_key(self, room_id, event_type, state_key): - event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @cached(num_args=3) - def _get_current_state_for_key(self, room_id, event_type, state_key): - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?" 
- ) - - args = (room_id, event_type, state_key) - txn.execute(sql, args) - results = txn.fetchall() - return [r[0] for r in results] - return self.runInteraction("get_current_state_for_key", f) - - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 2dc24951c4..200d124632 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) + def get_rooms_that_changed(self, room_ids, from_key): + """Given a list of rooms and a token, return rooms where there may have + been changes. + + Args: + room_ids (list) + from_key (str): The room_key portion of a StreamToken + """ + from_key = RoomStreamToken.parse_stream_token(from_key).stream + return set( + room_id for room_id in room_ids + if self._events_stream_cache.has_entity_changed(room_id, from_key) + ) + @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4d44c3d4ca..91a59b0bae 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -44,6 +44,7 @@ class EventSources(object): def get_current_token(self): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -63,6 +64,7 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) @@ -70,6 +72,7 @@ class EventSources(object): def get_current_token_for_room(self, room_id): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -89,5 +92,6 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 3a3ab21d17..9666f9d73f 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -158,6 +158,7 @@ class StreamToken( "account_data_key", "push_rules_key", "to_device_key", + "device_list_key", )) ): _SEPARATOR = "_" @@ -195,6 +196,7 @@ class StreamToken( or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.to_device_key) < int(self.to_device_key)) + or (int(other.device_list_key) < int(self.device_list_key)) ) def copy_and_advance(self, key, new_value): diff --git a/synapse/util/async.py b/synapse/util/async.py index 83875edc85..35380bf8ed 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -192,8 +192,11 @@ class Linearizer(object): logger.info( "Waiting to acquire linearizer lock %r for key %r", self.name, key ) - with PreserveLoggingContext(): - yield current_defer + try: + with PreserveLoggingContext(): + yield current_defer + except: + logger.exception("Unexpected exception in Linearizer") logger.info("Acquired linearizer lock %r for key %r", self.name, key) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index ebd715c5dc..8a7774a88e 100644 --- a/synapse/util/caches/__init__.py +++ 
b/synapse/util/caches/__init__.py @@ -40,8 +40,8 @@ def register_cache(name, cache): ) -_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR)) -caches_by_name["string_cache"] = _string_cache +_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR)) +_stirng_cache_metrics = register_cache("string_cache", _string_cache) KNOWN_KEYS = { @@ -69,7 +69,12 @@ KNOWN_KEYS = { def intern_string(string): """Takes a (potentially) unicode string and interns using custom cache """ - return _string_cache.setdefault(string, string) + new_str = _string_cache.setdefault(string, string) + if new_str is string: + _stirng_cache_metrics.inc_hits() + else: + _stirng_cache_metrics.inc_misses() + return new_str def intern_dict(dictionary): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8dba61d49f..998de70d29 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -42,6 +42,25 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) +class CacheEntry(object): + __slots__ = [ + "deferred", "sequence", "callbacks", "invalidated" + ] + + def __init__(self, deferred, sequence, callbacks): + self.deferred = deferred + self.sequence = sequence + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() + + class Cache(object): __slots__ = ( "cache", @@ -51,12 +70,16 @@ class Cache(object): "sequence", "thread", "metrics", + "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False): + def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict + self._pending_deferred_cache = cache_type() + self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type + max_size=max_entries, keylen=keylen, cache_type=cache_type, + size_callback=(lambda d: len(d.result)) if iterable else None, ) self.name = name @@ -76,7 +99,15 @@ class Cache(object): ) def get(self, key, default=_CacheSentinel, callback=None): - val = self.cache.get(key, _CacheSentinel, callback=callback) + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + if val.sequence == self.sequence: + val.callbacks.update(callbacks) + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -88,15 +119,39 @@ class Cache(object): else: return default - def update(self, sequence, key, value, callback=None): + def set(self, key, value, callback=None): + callbacks = [callback] if callback else [] self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value, callback=callback) + entry = CacheEntry( + 
deferred=value, + sequence=self.sequence, + callbacks=callbacks, + ) + + entry.callbacks.update(callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def shuffle(result): + if self.sequence == entry.sequence: + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, entry.deferred, entry.callbacks) + else: + entry.invalidate() + else: + entry.invalidate() + return result + + entry.deferred.addCallback(shuffle) def prefill(self, key, value, callback=None): - self.cache.set(key, value, callback=callback) + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): self.check_thread() @@ -108,6 +163,10 @@ class Cache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 + entry = self._pending_deferred_cache.pop(key, None) + if entry: + entry.invalidate() + self.cache.pop(key, None) def invalidate_many(self, key): @@ -119,6 +178,11 @@ class Cache(object): self.sequence += 1 self.cache.del_multi(key) + entry_dict = self._pending_deferred_cache.pop(key, None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + def invalidate_all(self): self.check_thread() self.sequence += 1 @@ -155,7 +219,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False, cache_context=False): + inlineCallbacks=False, cache_context=False, iterable=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -169,6 +233,8 @@ class CacheDescriptor(object): self.num_args = num_args self.tree = tree + self.iterable = iterable + all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] @@ -203,6 +269,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, + iterable=self.iterable, ) @functools.wraps(self.orig) @@ -243,11 +310,6 @@ class CacheDescriptor(object): return preserve_context_over_deferred(observer) except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - ret = defer.maybeDeferred( preserve_context_over_fn, self.function_to_call, @@ -261,7 +323,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=invalidate_callback) + cache.set(cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -359,7 +421,6 @@ class CacheListDescriptor(object): missing.append(arg) if missing: - sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -382,8 +443,8 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update( - sequence, tuple(key), observer, + cache.set( + tuple(key), observer, callback=invalidate_callback ) @@ -417,21 +478,29 @@ class CacheListDescriptor(object): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): + # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, + # which namedtuple does for us (i.e. 
two _CacheContext are the same if + # their caches and keys match). This is important in particular to + # dedupe when we add callbacks to lru cache nodes, otherwise the number + # of callbacks would grow. def invalidate(self): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, cache_context=cache_context, + iterable=iterable, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -439,6 +508,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex tree=tree, inlineCallbacks=True, cache_context=cache_context, + iterable=iterable, ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b0ca1bb79d..cb6933c61c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,9 @@ import logging logger = logging.getLogger(__name__) -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): + def __len__(self): + return len(self.value) class DictionaryCache(object): @@ -32,7 +34,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, size_callback=len) self.name = name self.sequence = 0 diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 080388958f..2987c38a2d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -15,6 +15,7 @@ from synapse.util.caches import register_cache +from collections import OrderedDict import logging @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): + reset_expiry_on_get=False, iterable=False): """ Args: cache_name (str): Name of this cache, used for logging. @@ -36,6 +37,8 @@ class ExpiringCache(object): evicted based on time. reset_expiry_on_get (bool): If true, will reset the expiry time for an item on access. Defaults to False. + iterable (bool): If true, the size is calculated by summing the + sizes of all entries, rather than the number of entries. 
""" self._cache_name = cache_name @@ -47,9 +50,13 @@ class ExpiringCache(object): self._reset_expiry_on_get = reset_expiry_on_get - self._cache = {} + self._cache = OrderedDict() - self.metrics = register_cache(cache_name, self._cache) + self.metrics = register_cache(cache_name, self) + + self.iterable = iterable + + self._size_estimate = 0 def start(self): if not self._expiry_ms: @@ -65,15 +72,14 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) - # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: - sorted_entries = sorted( - self._cache.items(), - key=lambda item: item[1].time, - ) + if self.iterable: + self._size_estimate += len(value) - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) + # Evict if there are now too many items + while self._max_len and len(self) > self._max_len: + _key, value = self._cache.popitem(last=False) + if self.iterable: + self._size_estimate -= len(value.value) def __getitem__(self, key): try: @@ -99,7 +105,7 @@ class ExpiringCache(object): # zero expiry time means don't expire. This should never get called # since we have this check in start too. return - begin_length = len(self._cache) + begin_length = len(self) now = self._clock.time_msec() @@ -110,15 +116,20 @@ class ExpiringCache(object): keys_to_delete.add(key) for k in keys_to_delete: - self._cache.pop(k) + value = self._cache.pop(k) + if self.iterable: + self._size_estimate -= len(value.value) logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache) + self._cache_name, begin_length, len(self) ) def __len__(self): - return len(self._cache) + if self.iterable: + return self._size_estimate + else: + return len(self._cache) class _CacheEntry(object): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 9c4c679175..cf5fbb679c 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,7 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): cache = cache_type() self.cache = cache # Used for introspection. 
list_root = _Node(None, None, None, None) @@ -58,6 +58,12 @@ class LruCache(object): lock = threading.Lock() + def evict(): + while cache_len() > max_size: + todelete = list_root.prev_node + delete_node(todelete) + cache.pop(todelete.key, None) + def synchronized(f): @wraps(f) def inner(*args, **kwargs): @@ -66,6 +72,16 @@ class LruCache(object): return inner + cached_cache_len = [0] + if size_callback is not None: + def cache_len(): + return cached_cache_len[0] + else: + def cache_len(): + return len(cache) + + self.len = synchronized(cache_len) + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node @@ -74,6 +90,9 @@ class LruCache(object): next_node.prev_node = node cache[key] = node + if size_callback: + cached_cache_len[0] += size_callback(node.value) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -92,23 +111,25 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + for cb in node.callbacks: cb() node.callbacks.clear() @synchronized - def cache_get(key, default=None, callback=None): + def cache_get(key, default=None, callbacks=[]): node = cache.get(key, None) if node is not None: move_node_to_front(node) - if callback: - node.callbacks.add(callback) + node.callbacks.update(callbacks) return node.value else: return default @synchronized - def cache_set(key, value, callback=None): + def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: if value != node.value: @@ -116,21 +137,18 @@ class LruCache(object): cb() node.callbacks.clear() - if callback: - node.callbacks.add(callback) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) + + node.callbacks.update(callbacks) move_node_to_front(node) node.value = value else: - if callback: - callbacks = set([callback]) - else: - callbacks = set() - add_node(key, value, callbacks) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + add_node(key, value, set(callbacks)) + + evict() @synchronized def cache_set_default(key, value): @@ -139,10 +157,7 @@ class LruCache(object): return node.value else: add_node(key, value) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + evict() return value @synchronized @@ -174,10 +189,8 @@ class LruCache(object): for cb in node.callbacks: cb() cache.clear() - - @synchronized - def cache_len(): - return len(cache) + if size_callback: + cached_cache_len[0] = 0 @synchronized def cache_contains(key): @@ -190,7 +203,7 @@ class LruCache(object): self.pop = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi - self.len = cache_len + self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c31585aea3..fcc341a6b7 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,12 +65,27 @@ class TreeCache(object): return popped def values(self): - return [e.value for e in self.root.values()] + return list(iterate_tree_cache_entry(self.root)) def __len__(self): return self.size +def iterate_tree_cache_entry(d): + """Helper function to iterate over the leaves of a tree, i.e. a dict of that + can contain dicts. 
+ """ + if isinstance(d, dict): + for value_d in d.itervalues(): + for value in iterate_tree_cache_entry(value_d): + yield value + else: + if isinstance(d, _Entry): + yield d.value + else: + yield d + + class _Entry(object): __slots__ = ["value"] diff --git a/synapse/util/debug.py b/synapse/util/debug.py deleted file mode 100644 index dc49162e6a..0000000000 --- a/synapse/util/debug.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer, reactor -from functools import wraps -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext - - -def debug_deferreds(): - """Cause all deferreds to wait for a reactor tick before running their - callbacks. This increases the chance of getting a stack trace out of - a defer.inlineCallback since the code waiting on the deferred will get - a chance to add an errback before the deferred runs.""" - - # Helper method for retrieving and restoring the current logging context - # around a callback. - def with_logging_context(fn): - context = LoggingContext.current_context() - - def restore_context_callback(x): - with PreserveLoggingContext(context): - return fn(x) - - return restore_context_callback - - # We are going to modify the __init__ method of defer.Deferred so we - # need to get a copy of the old method so we can still call it. - old__init__ = defer.Deferred.__init__ - - # We need to create a deferred to bounce the callbacks through the reactor - # but we don't want to add a callback when we create that deferred so we - # we create a new type of deferred that uses the old __init__ method. - # This is safe as long as the old __init__ method doesn't invoke an - # __init__ using super. - class Bouncer(defer.Deferred): - __init__ = old__init__ - - # We'll add this as a callback to all Deferreds. Twisted will wait until - # the bouncer deferred resolves before calling the callbacks of the - # original deferred. - def bounce_callback(x): - bouncer = Bouncer() - reactor.callLater(0, with_logging_context(bouncer.callback), x) - return bouncer - - # We'll add this as an errback to all Deferreds. Twisted will wait until - # the bouncer deferred resolves before calling the errbacks of the - # original deferred. 
- def bounce_errback(x): - bouncer = Bouncer() - reactor.callLater(0, with_logging_context(bouncer.errback), x) - return bouncer - - @wraps(old__init__) - def new__init__(self, *args, **kargs): - old__init__(self, *args, **kargs) - self.addCallbacks(bounce_callback, bounce_errback) - - defer.Deferred.__init__ = new__init__ diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index e2de7fce91..153ef001ad 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -88,7 +88,7 @@ class RetryDestinationLimiter(object): def __init__(self, destination, clock, store, retry_interval, min_retry_interval=10 * 60 * 1000, max_retry_interval=24 * 60 * 60 * 1000, - multiplier_retry_interval=5,): + multiplier_retry_interval=5, backoff_on_404=False): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -107,6 +107,7 @@ class RetryDestinationLimiter(object): a failed request, in milliseconds. multiplier_retry_interval (int): The multiplier to use to increase the retry interval after a failed request. + backoff_on_404 (bool): Back off if we get a 404 """ self.clock = clock self.store = store @@ -116,6 +117,7 @@ class RetryDestinationLimiter(object): self.min_retry_interval = min_retry_interval self.max_retry_interval = max_retry_interval self.multiplier_retry_interval = multiplier_retry_interval + self.backoff_on_404 = backoff_on_404 def __enter__(self): pass @@ -123,7 +125,22 @@ class RetryDestinationLimiter(object): def __exit__(self, exc_type, exc_val, exc_tb): valid_err_code = False if exc_type is not None and issubclass(exc_type, CodeMessageException): - valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 + # Some error codes are perfectly fine for some APIs, whereas other + # APIs may expect to never received e.g. a 404. It's important to + # handle 404 as some remote servers will return a 404 when the HS + # has been decommissioned. + # If we get a 401, then we should probably back off since they + # won't accept our requests for at least a while. + # 429 is us being aggresively rate limited, so lets rate limit + # ourselves. + if exc_val.code == 404 and self.backoff_on_404: + valid_err_code = False + elif exc_val.code in (401, 429): + valid_err_code = False + elif exc_val.code < 500: + valid_err_code = True + else: + valid_err_code = False if exc_type is None or valid_err_code: # We connected successfully. |
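
Taken together, the classification above feeds an exponential per-destination backoff. A compact model of the policy (a sketch assuming the multiply-and-clamp behaviour implied by the constructor arguments, not the real RetryDestinationLimiter):

def next_retry_interval(current_ms, code, backoff_on_404=False,
                        min_ms=10 * 60 * 1000,       # min_retry_interval
                        max_ms=24 * 60 * 60 * 1000,  # max_retry_interval
                        multiplier=5):
    """Return 0 if the response counts as success (mark the destination
    as up), else the next, exponentially grown retry interval in ms."""
    if code is None:
        return 0  # no exception was raised: the request succeeded
    if code in (401, 429):
        ok = False  # unauthorized / rate limited: back off
    elif code == 404:
        ok = not backoff_on_404  # 404 may mean a decommissioned server
    elif code < 500:
        ok = True  # other 4xx responses still prove the server is alive
    else:
        ok = False  # 5xx: the remote server is unhealthy
    if ok:
        return 0
    # exponential backoff, clamped to [min_ms, max_ms]
    return min(max(current_ms * multiplier, min_ms), max_ms)
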