diff options
Diffstat (limited to 'synapse')
276 files changed, 25208 insertions, 8652 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index bc50bec9db..f32c28be02 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.13.3" +__version__ = "0.18.5-rc2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3038df4ab8..ddab210718 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains classes for authenticating the user.""" +import logging + +import pymacaroons from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json, SignatureVerifyException - from twisted.internet import defer +from unpaddedbase64 import decode_base64 +import synapse.types from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import Requester, RoomID, UserID, EventID -from synapse.util.logutils import log_function +from synapse.types import UserID, get_domain_from_id from synapse.util.logcontext import preserve_context_over_fn -from unpaddedbase64 import decode_base64 - -import logging -import pymacaroons +from synapse.util.logutils import log_function +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -39,23 +39,34 @@ AuthEventTypes = ( EventTypes.ThirdPartyInvite, ) +# guests always get this device id. +GUEST_DEVICE_ID = "guest_device" -class Auth(object): +class Auth(object): + """ + FIXME: This class contains a mix of functions for authenticating users + of our client-server API and authenticating events added to room graphs. + """ def __init__(self, hs): self.hs = hs + self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 - self._KNOWN_CAVEAT_PREFIXES = set([ - "gen = ", - "guest = ", - "type = ", - "time < ", - "user_id = ", - ]) - - def check(self, event, auth_events): + + @defer.inlineCallbacks + def check_from_context(self, event, context, do_sig_check=True): + auth_events_ids = yield self.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + self.check(event, auth_events=auth_events, do_sig_check=do_sig_check) + + def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. Args: @@ -66,11 +77,35 @@ class Auth(object): Returns: True if the auth checks pass. """ - self.check_size_limits(event) + with Measure(self.clock, "auth.check"): + self.check_size_limits(event) - try: if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) + + if do_sig_check: + sender_domain = get_domain_from_id(event.sender) + event_id_domain = get_domain_from_id(event.event_id) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + # Check the event_id's domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + if auth_events is None: # Oh, we don't know what the state of the room was, so we # are trusting that this is allowed (at least for now) @@ -78,6 +113,12 @@ class Auth(object): return True if event.type == EventTypes.Create: + room_id_domain = get_domain_from_id(event.room_id) + if room_id_domain != sender_domain: + raise AuthError( + 403, + "Creation event's room_id domain does not match sender's" + ) # FIXME return True @@ -89,8 +130,8 @@ class Auth(object): "Room %r does not exist" % (event.room_id,) ) - creating_domain = RoomID.from_string(event.room_id).domain - originating_domain = UserID.from_string(event.sender).domain + creating_domain = get_domain_from_id(event.room_id) + originating_domain = get_domain_from_id(event.sender) if creating_domain != originating_domain: if not self.can_federate(event, auth_events): raise AuthError( @@ -100,6 +141,22 @@ class Auth(object): # FIXME: Temp hack if event.type == EventTypes.Aliases: + if not event.is_state(): + raise AuthError( + 403, + "Alias event must be a state event", + ) + if not event.state_key: + raise AuthError( + 403, + "Alias event must have non-empty state_key" + ) + sender_domain = get_domain_from_id(event.sender) + if event.state_key != sender_domain: + raise AuthError( + 403, + "Alias event's state_key does not match sender's domain" + ) return True logger.debug( @@ -118,6 +175,24 @@ class Auth(object): return allowed self.check_event_sender_in_room(event, auth_events) + + # Special case to allow m.room.third_party_invite events wherever + # a user is allowed to issue invites. Fixes + # https://github.com/vector-im/vector-web/issues/1208 hopefully + if event.type == EventTypes.ThirdPartyInvite: + user_level = self._get_user_power_level(event.user_id, auth_events) + invite_level = self._get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, ( + "You cannot issue a third party invite for %s." % + (event.content.display_name,) + ) + ) + else: + return True + self._can_send_event(event, auth_events) if event.type == EventTypes.PowerLevels: @@ -127,13 +202,6 @@ class Auth(object): self.check_redaction(event, auth_events) logger.debug("Allowing! %s", event) - except AuthError as e: - logger.info( - "Event auth check failed on event %s with msg: %s", - event, e.msg - ) - logger.info("Denying! %s", event) - raise def check_size_limits(self, event): def too_big(field): @@ -219,21 +287,17 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): - curr_state = yield self.state.get_current_state(room_id) + with Measure(self.clock, "check_host_in_room"): + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - for event in curr_state.values(): - if event.type == EventTypes.Member: - try: - if UserID.from_string(event.state_key).domain != host: - continue - except: - logger.warn("state_key not user_id: %s", event.state_key) - continue - - if event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + entry = yield self.state.resolve_state_groups( + room_id, latest_event_ids + ) - defer.returnValue(False) + ret = yield self.store.is_host_joined( + room_id, host, entry.state_group, entry.state + ) + defer.returnValue(ret) def check_event_sender_in_room(self, event, auth_events): key = (EventTypes.Member, event.user_id, ) @@ -271,8 +335,8 @@ class Auth(object): target_user_id = event.state_key - creating_domain = RoomID.from_string(event.room_id).domain - target_domain = UserID.from_string(target_user_id).domain + creating_domain = get_domain_from_id(event.room_id) + target_domain = get_domain_from_id(target_user_id) if creating_domain != target_domain: if not self.can_federate(event, auth_events): raise AuthError( @@ -328,6 +392,10 @@ class Auth(object): if Membership.INVITE == membership and "third_party_invite" in event.content: if not self._verify_third_party_invite(event, auth_events): raise AuthError(403, "You are not invited to this room.") + if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) return True if Membership.JOIN != membership: @@ -432,6 +500,9 @@ class Auth(object): if not invite_event: return False + if invite_event.sender != event.sender: + return False + if event.user_id != invite_event.user_id: return False @@ -512,33 +583,38 @@ class Auth(object): return default @defer.inlineCallbacks - def get_user_by_req(self, request, allow_guest=False): + def get_user_by_req(self, request, allow_guest=False, rights="access"): """ Get a registered user's ID. Args: request - An HTTP request with an access_token query parameter. Returns: - tuple of: - UserID (str) - Access token ID (str) + defer.Deferred: resolves to a ``synapse.types.Requester`` object Raises: AuthError if no user by that token exists or the token is invalid. """ # Can optionally look elsewhere in the request (e.g. headers) try: - user_id = yield self._get_appservice_user_id(request.args) + user_id, app_service = yield self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id defer.returnValue( - Requester(UserID.from_string(user_id), "", False) + synapse.types.create_requester(user_id, app_service=app_service) ) - access_token = request.args["access_token"][0] - user_info = yield self.get_user_by_access_token(access_token) + access_token = get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) + + user_info = yield self.get_user_by_access_token(access_token, rights) user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] + # device_id may not be present if get_user_by_access_token has been + # stubbed out. + device_id = user_info.get("device_id") + ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( "User-Agent", @@ -550,7 +626,8 @@ class Auth(object): user=user, access_token=access_token, ip=ip_addr, - user_agent=user_agent + user_agent=user_agent, + device_id=device_id, ) if is_guest and not allow_guest: @@ -560,7 +637,9 @@ class Auth(object): request.authenticated_entity = user.to_string() - defer.returnValue(Requester(user, token_id, is_guest)) + defer.returnValue(synapse.types.create_requester( + user, token_id, is_guest, device_id, app_service=app_service) + ) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -568,19 +647,21 @@ class Auth(object): ) @defer.inlineCallbacks - def _get_appservice_user_id(self, request_args): - app_service = yield self.store.get_app_service_by_token( - request_args["access_token"][0] + def _get_appservice_user_id(self, request): + app_service = self.store.get_app_service_by_token( + get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) ) if app_service is None: - defer.returnValue(None) + defer.returnValue((None, None)) - if "user_id" not in request_args: - defer.returnValue(app_service.sender) + if "user_id" not in request.args: + defer.returnValue((app_service.sender, app_service)) - user_id = request_args["user_id"][0] + user_id = request.args["user_id"][0] if app_service.sender == user_id: - defer.returnValue(app_service.sender) + defer.returnValue((app_service.sender, app_service)) if not app_service.is_interested_in_user(user_id): raise AuthError( @@ -592,10 +673,10 @@ class Auth(object): 403, "Application service has not registered this user" ) - defer.returnValue(user_id) + defer.returnValue((user_id, app_service)) @defer.inlineCallbacks - def get_user_by_access_token(self, token): + def get_user_by_access_token(self, token, rights="access"): """ Get a registered user's ID. Args: @@ -606,46 +687,62 @@ class Auth(object): AuthError if no user by that token exists or the token is invalid. """ try: - ret = yield self.get_user_from_macaroon(token) + ret = yield self.get_user_from_macaroon(token, rights) except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. + if self.hs.config.expire_access_token: + raise ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) @defer.inlineCallbacks - def get_user_from_macaroon(self, macaroon_str): + def get_user_from_macaroon(self, macaroon_str, rights="access"): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - self.validate_macaroon(macaroon, "access", False) - user_prefix = "user_id = " - user = None + user_id = self.get_user_id_from_macaroon(macaroon) + user = UserID.from_string(user_id) + + self.validate_macaroon( + macaroon, rights, self.hs.config.expire_access_token, + user_id=user_id, + ) + guest = False for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) - elif caveat.caveat_id == "guest = true": + if caveat.caveat_id == "guest = true": guest = True - if user is None: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) - if guest: ret = { "user": user, "is_guest": True, "token_id": None, + # all guests get the same device id + "device_id": GUEST_DEVICE_ID, + } + elif rights == "delete_pusher": + # We don't store these tokens in the database + ret = { + "user": user, + "is_guest": False, + "token_id": None, + "device_id": None, } else: - # This codepath exists so that we can actually return a - # token ID, because we use token IDs in place of device - # identifiers throughout the codebase. - # TODO(daniel): Remove this fallback when device IDs are - # properly implemented. + # This codepath exists for several reasons: + # * so that we can actually return a token ID, which is used + # in some parts of the schema (where we probably ought to + # use device IDs instead) + # * the only way we currently have to invalidate an + # access_token is by removing it from the database, so we + # have to check here that it is still in the db + # * some attributes (notably device_id) aren't stored in the + # macaroon. They probably should be. + # TODO: build the dictionary from the macaroon once the + # above are fixed ret = yield self._look_up_user_by_access_token(macaroon_str) if ret["user"] != user: logger.error( @@ -665,31 +762,67 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) - def validate_macaroon(self, macaroon, type_string, verify_expiry): + def get_user_id_from_macaroon(self, macaroon): + """Retrieve the user_id given by the caveats on the macaroon. + + Does *not* validate the macaroon. + + Args: + macaroon (pymacaroons.Macaroon): The macaroon to validate + + Returns: + (str) user id + + Raises: + AuthError if there is no user_id caveat in the macaroon + """ + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + return caveat.caveat_id[len(user_prefix):] + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + + def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): """ validate that a Macaroon is understood by and was signed by this server. Args: macaroon(pymacaroons.Macaroon): The macaroon to validate - type_string(str): The kind of token this is (e.g. "access", "refresh") + type_string(str): The kind of token required (e.g. "access", + "delete_pusher") verify_expiry(bool): Whether to verify whether the macaroon has expired. - This should really always be True, but no clients currently implement - token refresh, so we can't enforce expiry yet. + user_id (str): The user_id required """ v = pymacaroons.Verifier() + + # the verifier runs a test for every caveat on the macaroon, to check + # that it is met for the current request. Each caveat must match at + # least one of the predicates specified by satisfy_exact or + # specify_general. v.satisfy_exact("gen = 1") v.satisfy_exact("type = " + type_string) - v.satisfy_general(lambda c: c.startswith("user_id = ")) + v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("guest = true") + + # verify_expiry should really always be True, but there exist access + # tokens in the wild which expire when they should not, so we can't + # enforce expiry yet (so we have to allow any caveat starting with + # 'time < ' in access tokens). + # + # On the other hand, short-term login tokens (as used by CAS login, for + # example) have an expiry time which we do want to enforce. + if verify_expiry: v.satisfy_general(self._verify_expiry) else: v.satisfy_general(lambda c: c.startswith("time < ")) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + # access_tokens include a nonce for uniqueness: any value is acceptable + v.satisfy_general(lambda c: c.startswith("nonce = ")) - v = pymacaroons.Verifier() - v.satisfy_general(self._verify_recognizes_caveats) v.verify(macaroon, self.hs.config.macaroon_secret_key) def _verify_expiry(self, caveat): @@ -700,15 +833,6 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - def _verify_recognizes_caveats(self, caveat): - first_space = caveat.find(" ") - if first_space < 0: - return False - second_space = caveat.find(" ", first_space + 1) - if second_space < 0: - return False - return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES - @defer.inlineCallbacks def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) @@ -718,18 +842,23 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", errcode=Codes.UNKNOWN_TOKEN ) + # we use ret.get() below because *lots* of unit tests stub out + # get_user_by_access_token in a way where it only returns a couple of + # the fields. user_info = { "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "device_id": ret.get("device_id"), } defer.returnValue(user_info) - @defer.inlineCallbacks def get_appservice_by_req(self, request): try: - token = request.args["access_token"][0] - service = yield self.store.get_app_service_by_token(token) + token = get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) + service = self.store.get_app_service_by_token(token) if not service: logger.warn("Unrecognised appservice access token: %s" % (token,)) raise AuthError( @@ -738,7 +867,7 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) request.authenticated_entity = service.sender - defer.returnValue(service) + return defer.succeed(service) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." @@ -749,7 +878,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = self.compute_auth_events(builder, context.current_state) + auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -757,30 +886,32 @@ class Auth(object): builder.auth_events = auth_events_entries - def compute_auth_events(self, event, current_state): + @defer.inlineCallbacks + def compute_auth_events(self, event, current_state_ids, for_verification=False): if event.type == EventTypes.Create: - return [] + defer.returnValue([]) auth_ids = [] key = (EventTypes.PowerLevels, "", ) - power_level_event = current_state.get(key) + power_level_event_id = current_state_ids.get(key) - if power_level_event: - auth_ids.append(power_level_event.event_id) + if power_level_event_id: + auth_ids.append(power_level_event_id) key = (EventTypes.JoinRules, "", ) - join_rule_event = current_state.get(key) + join_rule_event_id = current_state_ids.get(key) key = (EventTypes.Member, event.user_id, ) - member_event = current_state.get(key) + member_event_id = current_state_ids.get(key) key = (EventTypes.Create, "", ) - create_event = current_state.get(key) - if create_event: - auth_ids.append(create_event.event_id) + create_event_id = current_state_ids.get(key) + if create_event_id: + auth_ids.append(create_event_id) - if join_rule_event: + if join_rule_event_id: + join_rule_event = yield self.store.get_event(join_rule_event_id) join_rule = join_rule_event.content.get("join_rule") is_public = join_rule == JoinRules.PUBLIC if join_rule else False else: @@ -789,15 +920,21 @@ class Auth(object): if event.type == EventTypes.Member: e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event: - auth_ids.append(join_rule_event.event_id) + if join_rule_event_id: + auth_ids.append(join_rule_event_id) if e_type == Membership.JOIN: - if member_event and not is_public: - auth_ids.append(member_event.event_id) + if member_event_id and not is_public: + auth_ids.append(member_event_id) else: - if member_event: - auth_ids.append(member_event.event_id) + if member_event_id: + auth_ids.append(member_event_id) + + if for_verification: + key = (EventTypes.Member, event.state_key, ) + existing_event_id = current_state_ids.get(key) + if existing_event_id: + auth_ids.append(existing_event_id) if e_type == Membership.INVITE: if "third_party_invite" in event.content: @@ -805,26 +942,26 @@ class Auth(object): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - third_party_invite = current_state.get(key) - if third_party_invite: - auth_ids.append(third_party_invite.event_id) - elif member_event: + third_party_invite_id = current_state_ids.get(key) + if third_party_invite_id: + auth_ids.append(third_party_invite_id) + elif member_event_id: + member_event = yield self.store.get_event(member_event_id) if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - return auth_ids + defer.returnValue(auth_ids) - @log_function - def _can_send_event(self, event, auth_events): + def _get_send_level(self, etype, state_key, auth_events): key = (EventTypes.PowerLevels, "", ) send_level_event = auth_events.get(key) send_level = None if send_level_event: send_level = send_level_event.content.get("events", {}).get( - event.type + etype ) if send_level is None: - if hasattr(event, "state_key"): + if state_key is not None: send_level = send_level_event.content.get( "state_default", 50 ) @@ -838,6 +975,13 @@ class Auth(object): else: send_level = 0 + return send_level + + @log_function + def _can_send_event(self, event, auth_events): + send_level = self._get_send_level( + event.type, event.get("state_key", None), auth_events + ) user_level = self._get_user_power_level(event.user_id, auth_events) if user_level < send_level: @@ -855,16 +999,6 @@ class Auth(object): 403, "You are not allowed to set others state" ) - else: - sender_domain = UserID.from_string( - event.user_id - ).domain - - if sender_domain != event.state_key: - raise AuthError( - 403, - "You are not allowed to set others state" - ) return True @@ -888,8 +1022,8 @@ class Auth(object): if user_level >= redact_level: return False - redacter_domain = EventID.from_string(event.event_id).domain - redactee_domain = EventID.from_string(event.redacts).domain + redacter_domain = get_domain_from_id(event.event_id) + redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: return True @@ -982,3 +1116,108 @@ class Auth(object): "You don't have permission to add ops level greater " "than your own" ) + + @defer.inlineCallbacks + def check_can_change_room_list(self, room_id, user): + """Check if the user is allowed to edit the room's entry in the + published room list. + + Args: + room_id (str) + user (UserID) + """ + + is_admin = yield self.is_server_admin(user) + if is_admin: + defer.returnValue(True) + + user_id = user.to_string() + yield self.check_joined_room(room_id, user_id) + + # We currently require the user is a "moderator" in the room. We do this + # by checking if they would (theoretically) be able to change the + # m.room.aliases events + power_level_event = yield self.state.get_current_state( + room_id, EventTypes.PowerLevels, "" + ) + + auth_events = {} + if power_level_event: + auth_events[(EventTypes.PowerLevels, "")] = power_level_event + + send_level = self._get_send_level( + EventTypes.Aliases, "", auth_events + ) + user_level = self._get_user_power_level(user_id, auth_events) + + if user_level < send_level: + raise AuthError( + 403, + "This server requires you to be a moderator in the room to" + " edit its room list entry" + ) + + +def has_access_token(request): + """Checks if the request has an access_token. + + Returns: + bool: False if no access_token was given, True otherwise. + """ + query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders("Authorization") + return bool(query_params) or bool(auth_headers) + + +def get_access_token_from_request(request, token_not_found_http_status=401): + """Extracts the access_token from the request. + + Args: + request: The http request. + token_not_found_http_status(int): The HTTP status code to set in the + AuthError if the token isn't found. This is used in some of the + legacy APIs to change the status code to 403 from the default of + 401 since some of the old clients depended on auth errors returning + 403. + Returns: + str: The access_token + Raises: + AuthError: If there isn't an access_token in the request. + """ + + auth_headers = request.requestHeaders.getRawHeaders("Authorization") + query_params = request.args.get("access_token") + if auth_headers: + # Try the get the access_token from a "Authorization: Bearer" + # header + if query_params is not None: + raise AuthError( + token_not_found_http_status, + "Mixing Authorization headers and access_token query parameters.", + errcode=Codes.MISSING_TOKEN, + ) + if len(auth_headers) > 1: + raise AuthError( + token_not_found_http_status, + "Too many Authorization headers.", + errcode=Codes.MISSING_TOKEN, + ) + parts = auth_headers[0].split(" ") + if parts[0] == "Bearer" and len(parts) == 2: + return parts[1] + else: + raise AuthError( + token_not_found_http_status, + "Invalid Authorization header.", + errcode=Codes.MISSING_TOKEN, + ) + else: + # Try to get the access_token from the query params. + if not query_params: + raise AuthError( + token_not_found_http_status, + "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + + return query_params[0] diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8cf4d6169c..a8123cddcb 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -85,3 +85,8 @@ class RoomCreationPreset(object): PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + + +class ThirdPartyEntityKind(object): + USER = "user" + LOCATION = "location" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b106fbed6d..921c457738 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -39,11 +39,14 @@ class Codes(object): CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" MISSING_PARAM = "M_MISSING_PARAM" + INVALID_PARAM = "M_INVALID_PARAM" TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" - THREEPID_IN_USE = "THREEPID_IN_USE" + THREEPID_IN_USE = "M_THREEPID_IN_USE" + THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" INVALID_USERNAME = "M_INVALID_USERNAME" + SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" class CodeMessageException(RuntimeError): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cd699ef27f..fb291d7fb9 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -15,6 +15,8 @@ from synapse.api.errors import SynapseError from synapse.types import UserID, RoomID +from twisted.internet import defer + import ujson as json @@ -24,10 +26,10 @@ class Filtering(object): super(Filtering, self).__init__() self.store = hs.get_datastore() + @defer.inlineCallbacks def get_user_filter(self, user_localpart, filter_id): - result = self.store.get_user_filter(user_localpart, filter_id) - result.addCallback(FilterCollection) - return result + result = yield self.store.get_user_filter(user_localpart, filter_id) + defer.returnValue(FilterCollection(result)) def add_user_filter(self, user_localpart, user_filter): self.check_valid_filter(user_filter) @@ -69,6 +71,21 @@ class Filtering(object): if key in user_filter_json["room"]: self._check_definition(user_filter_json["room"][key]) + if "event_fields" in user_filter_json: + if type(user_filter_json["event_fields"]) != list: + raise SynapseError(400, "event_fields must be a list of strings") + for field in user_filter_json["event_fields"]: + if not isinstance(field, basestring): + raise SynapseError(400, "Event field must be a string") + # Don't allow '\\' in event field filters. This makes matching + # events a lot easier as we can then use a negative lookbehind + # assertion to split '\.' If we allowed \\ then it would + # incorrectly split '\\.' See synapse.events.utils.serialize_event + if r'\\' in field: + raise SynapseError( + 400, r'The escape character \ cannot itself be escaped' + ) + def _check_definition_room_lists(self, definition): """Check that "rooms" and "not_rooms" are lists of room ids if they are present @@ -150,6 +167,7 @@ class FilterCollection(object): self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) + self.event_fields = filter_json.get("event_fields", []) def __repr__(self): return "<FilterCollection %s>" % (json.dumps(self._filter_json),) @@ -184,11 +202,51 @@ class FilterCollection(object): def filter_room_account_data(self, events): return self._room_account_data.filter(self._room_filter.filter(events)) + def blocks_all_presence(self): + return ( + self._presence_filter.filters_all_types() or + self._presence_filter.filters_all_senders() + ) + + def blocks_all_room_ephemeral(self): + return ( + self._room_ephemeral_filter.filters_all_types() or + self._room_ephemeral_filter.filters_all_senders() or + self._room_ephemeral_filter.filters_all_rooms() + ) + + def blocks_all_room_timeline(self): + return ( + self._room_timeline_filter.filters_all_types() or + self._room_timeline_filter.filters_all_senders() or + self._room_timeline_filter.filters_all_rooms() + ) + class Filter(object): def __init__(self, filter_json): self.filter_json = filter_json + self.types = self.filter_json.get("types", None) + self.not_types = self.filter_json.get("not_types", []) + + self.rooms = self.filter_json.get("rooms", None) + self.not_rooms = self.filter_json.get("not_rooms", []) + + self.senders = self.filter_json.get("senders", None) + self.not_senders = self.filter_json.get("not_senders", []) + + self.contains_url = self.filter_json.get("contains_url", None) + + def filters_all_types(self): + return "*" in self.not_types + + def filters_all_senders(self): + return "*" in self.not_senders + + def filters_all_rooms(self): + return "*" in self.not_rooms + def check(self, event): """Checks whether the filter matches the given event. @@ -207,9 +265,10 @@ class Filter(object): event.get("room_id", None), sender, event.get("type", None), + "url" in event.get("content", {}) ) - def check_fields(self, room_id, sender, event_type): + def check_fields(self, room_id, sender, event_type, contains_url): """Checks whether the filter matches the given event fields. Returns: @@ -223,15 +282,20 @@ class Filter(object): for name, match_func in literal_keys.items(): not_name = "not_%s" % (name,) - disallowed_values = self.filter_json.get(not_name, []) + disallowed_values = getattr(self, not_name) if any(map(match_func, disallowed_values)): return False - allowed_values = self.filter_json.get(name, None) + allowed_values = getattr(self, name) if allowed_values is not None: if not any(map(match_func, allowed_values)): return False + contains_url_filter = self.filter_json.get("contains_url") + if contains_url_filter is not None: + if contains_url_filter != contains_url: + return False + return True def filter_rooms(self, room_ids): diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 660dfb56e5..06cc8d90b8 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -23,7 +23,7 @@ class Ratelimiter(object): def __init__(self): self.message_counts = collections.OrderedDict() - def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count): + def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True): """Can the user send a message? Args: user_id: The user sending a message. @@ -32,12 +32,15 @@ class Ratelimiter(object): second. burst_count: How many messages the user can send before being limited. + update (bool): Whether to update the message rates or not. This is + useful to check if a message would be allowed to be sent before + its ready to be actually sent. Returns: A pair of a bool indicating if they can send a message now and a time in seconds of when they can next send a message. """ self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.pop( + message_count, time_start, _ignored = self.message_counts.get( user_id, (0., time_now_s, None), ) time_delta = time_now_s - time_start @@ -52,9 +55,10 @@ class Ratelimiter(object): allowed = True message_count += 1 - self.message_counts[user_id] = ( - message_count, time_start, msg_rate_hz - ) + if update: + self.message_counts[user_id] = ( + message_count, time_start, msg_rate_hz + ) if msg_rate_hz > 0: time_allowed = ( diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 0fd9b7f244..91a33a3402 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" -APP_SERVICE_PREFIX = "/_matrix/appservice/v1" diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 1bc4279807..9c2b627590 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -16,13 +16,11 @@ import sys sys.dont_write_bytecode = True -from synapse.python_dependencies import ( - check_requirements, MissingRequirementError -) # NOQA +from synapse import python_dependencies # noqa: E402 try: - check_requirements() -except MissingRequirementError as e: + python_dependencies.check_requirements() +except python_dependencies.MissingRequirementError as e: message = "\n".join([ "Missing Requirement: %s" % (e.message,), "To install run:", diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py new file mode 100644 index 0000000000..dd9ee406a1 --- /dev/null +++ b/synapse/app/appservice.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.server import HomeServer +from synapse.config._base import ConfigError +from synapse.config.logger import setup_logging +from synapse.config.homeserver import HomeServerConfig +from synapse.http.site import SynapseSite +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.storage.engines import create_engine +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string + +from synapse import events + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.appservice") + + +class AppserviceSlaveStore( + DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore, + SlavedRegistrationStore, +): + pass + + +class AppserviceServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse appservice now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + appservice_handler = self.get_application_service_handler() + + @defer.inlineCallbacks + def replicate(results): + stream = results.get("events") + if stream: + max_stream_id = stream["position"] + yield appservice_handler.notify_interested_services(max_stream_id) + + while True: + try: + args = store.stream_positions() + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + replicate(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(30) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse appservice", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.appservice" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + if config.notify_appservices: + sys.stderr.write( + "\nThe appservices must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``notify_appservices: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.notify_appservices = True + + ps = AppserviceServer( + config.server_name, + db_config=config.database_config, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ps.setup() + ps.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ps.replicate() + ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-appservice", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py new file mode 100644 index 0000000000..0086a2977e --- /dev/null +++ b/synapse/app/client_reader.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.http.site import SynapseSite +from synapse.http.server import JsonResource +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.rest.client.v1.room import PublicRoomListRestServlet +from synapse.server import HomeServer +from synapse.storage.client_ips import ClientIpStore +from synapse.storage.engines import create_engine +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string +from synapse.crypto import context_factory + +from synapse import events + + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.client_reader") + + +class ClientReaderSlavedStore( + SlavedEventStore, + SlavedKeyStore, + RoomStore, + DirectoryStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + BaseSlavedStore, + ClientIpStore, # After BaseSlavedStore because the constructor is different +): + pass + + +class ClientReaderServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + PublicRoomListRestServlet(self).register(resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse client reader now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + + while True: + try: + args = store.stream_positions() + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(5) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse client reader", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.client_reader" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = ClientReaderServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.get_handlers() + ss.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + ss.replicate() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-client-reader", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py new file mode 100644 index 0000000000..b5f59a9931 --- /dev/null +++ b/synapse/app/federation_reader.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.http.site import SynapseSite +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string +from synapse.api.urls import FEDERATION_PREFIX +from synapse.federation.transport.server import TransportLayerServer +from synapse.crypto import context_factory + +from synapse import events + + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.federation_reader") + + +class FederationReaderSlavedStore( + SlavedEventStore, + SlavedKeyStore, + RoomStore, + DirectoryStore, + TransactionStore, + BaseSlavedStore, +): + pass + + +class FederationReaderServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "federation": + resources.update({ + FEDERATION_PREFIX: TransportLayerServer(self), + }) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse federation reader now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + + while True: + try: + args = store.stream_positions() + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(5) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse federation reader", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.federation_reader" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = FederationReaderServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.get_handlers() + ss.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + ss.replicate() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-federation-reader", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py new file mode 100644 index 0000000000..80ea4c8062 --- /dev/null +++ b/synapse/app/federation_sender.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.server import HomeServer +from synapse.config._base import ConfigError +from synapse.config.logger import setup_logging +from synapse.config.homeserver import HomeServerConfig +from synapse.crypto import context_factory +from synapse.http.site import SynapseSite +from synapse.federation import send_queue +from synapse.federation.units import Edu +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.storage.engines import create_engine +from synapse.storage.presence import UserPresenceState +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string + +from synapse import events + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc +import ujson as json + +logger = logging.getLogger("synapse.app.appservice") + + +class FederationSenderSlaveStore( + SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, + SlavedRegistrationStore, +): + pass + + +class FederationSenderServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse federation_sender now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + send_handler = FederationSenderHandler(self) + + send_handler.on_start() + + while True: + try: + args = store.stream_positions() + args.update((yield send_handler.stream_positions())) + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + yield send_handler.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(30) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse federation sender", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.federation_sender" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + if config.send_federation: + sys.stderr.write( + "\nThe send_federation must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``send_federation: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.send_federation = True + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ps = FederationSenderServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ps.setup() + ps.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ps.replicate() + ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-federation-sender", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +class FederationSenderHandler(object): + """Processes the replication stream and forwards the appropriate entries + to the federation sender. + """ + def __init__(self, hs): + self.store = hs.get_datastore() + self.federation_sender = hs.get_federation_sender() + + self._room_serials = {} + self._room_typing = {} + + def on_start(self): + # There may be some events that are persisted but haven't been sent, + # so send them now. + self.federation_sender.notify_new_events( + self.store.get_room_max_stream_ordering() + ) + + @defer.inlineCallbacks + def stream_positions(self): + stream_id = yield self.store.get_federation_out_pos("federation") + defer.returnValue({ + "federation": stream_id, + + # Ack stuff we've "processed", this should only be called from + # one process. + "federation_ack": stream_id, + }) + + @defer.inlineCallbacks + def process_replication(self, result): + # The federation stream contains things that we want to send out, e.g. + # presence, typing, etc. + fed_stream = result.get("federation") + if fed_stream: + latest_id = int(fed_stream["position"]) + + # The federation stream containis a bunch of different types of + # rows that need to be handled differently. We parse the rows, put + # them into the appropriate collection and then send them off. + presence_to_send = {} + keyed_edus = {} + edus = {} + failures = {} + device_destinations = set() + + # Parse the rows in the stream + for row in fed_stream["rows"]: + position, typ, content_js = row + content = json.loads(content_js) + + if typ == send_queue.PRESENCE_TYPE: + destination = content["destination"] + state = UserPresenceState.from_dict(content["state"]) + + presence_to_send.setdefault(destination, []).append(state) + elif typ == send_queue.KEYED_EDU_TYPE: + key = content["key"] + edu = Edu(**content["edu"]) + + keyed_edus.setdefault( + edu.destination, {} + )[(edu.destination, tuple(key))] = edu + elif typ == send_queue.EDU_TYPE: + edu = Edu(**content) + + edus.setdefault(edu.destination, []).append(edu) + elif typ == send_queue.FAILURE_TYPE: + destination = content["destination"] + failure = content["failure"] + + failures.setdefault(destination, []).append(failure) + elif typ == send_queue.DEVICE_MESSAGE_TYPE: + device_destinations.add(content["destination"]) + else: + raise Exception("Unrecognised federation type: %r", typ) + + # We've finished collecting, send everything off + for destination, states in presence_to_send.items(): + self.federation_sender.send_presence(destination, states) + + for destination, edu_map in keyed_edus.items(): + for key, edu in edu_map.items(): + self.federation_sender.send_edu( + edu.destination, edu.edu_type, edu.content, key=key, + ) + + for destination, edu_list in edus.items(): + for edu in edu_list: + self.federation_sender.send_edu( + edu.destination, edu.edu_type, edu.content, key=None, + ) + + for destination, failure_list in failures.items(): + for failure in failure_list: + self.federation_sender.send_failure(destination, failure) + + for destination in device_destinations: + self.federation_sender.send_device_messages(destination) + + # Record where we are in the stream. + yield self.store.update_federation_out_pos( + "federation", latest_id + ) + + # We also need to poke the federation sender when new events happen + event_stream = result.get("events") + if event_stream: + latest_pos = event_stream["position"] + self.federation_sender.notify_new_events(latest_pos) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index fcdc8e6e10..54f35900f8 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -16,14 +16,10 @@ import synapse -import contextlib +import gc import logging import os -import re -import resource -import subprocess import sys -import time from synapse.config._base import ConfigError from synapse.python_dependencies import ( @@ -33,22 +29,15 @@ from synapse.python_dependencies import ( from synapse.rest import ClientRestResource from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.prepare_database import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.server import HomeServer - -from twisted.conch.manhole import ColoredManhole -from twisted.conch.insults import insults -from twisted.conch import manhole_ssh -from twisted.cred import checkers, portal - - from twisted.internet import reactor, task, defer from twisted.application import service from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.static import File -from twisted.web.server import Site, GzipEncoderFactory, Request +from twisted.web.server import GzipEncoderFactory from synapse.http.server import RootRedirect from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource @@ -62,10 +51,18 @@ from synapse.api.urls import ( from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext +from synapse.metrics import register_memory_metrics, get_metrics_for from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.federation.transport.server import TransportLayerServer +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.manhole import manhole + +from synapse.http.site import SynapseSite + from synapse import events from daemonize import Daemonize @@ -73,9 +70,6 @@ from daemonize import Daemonize logger = logging.getLogger("synapse.app.homeserver") -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') - - def gz_wrap(r): return EncodingResourceWrapper(r, [GzipEncoderFactory()]) @@ -154,7 +148,7 @@ class SynapseHomeServer(HomeServer): MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path, self.auth, self.content_addr + self, self.config.uploads_path ), }) @@ -173,7 +167,12 @@ class SynapseHomeServer(HomeServer): if name == "replication": resources[REPLICATION_PREFIX] = ReplicationResource(self) - root_resource = create_resource_tree(resources) + if WEB_CLIENT_PREFIX in resources: + root_resource = RootRedirect(WEB_CLIENT_PREFIX) + else: + root_resource = Resource() + + root_resource = create_resource_tree(resources, root_resource) if tls: reactor.listenSSL( port, @@ -206,24 +205,13 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( - matrix="rabbithole" - ) - - rlm = manhole_ssh.TerminalRealm() - rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( - ColoredManhole, - { - "__name__": "__console__", - "hs": self, - } - ) - - f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - reactor.listenTCP( listener["port"], - f, + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), interface=listener.get("bind_address", '127.0.0.1') ) else: @@ -245,7 +233,7 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) - def get_db_conn(self): + def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { @@ -254,7 +242,8 @@ class SynapseHomeServer(HomeServer): } db_conn = self.database_engine.module.connect(**db_params) - self.database_engine.on_new_connection(db_conn) + if run_new_connection: + self.database_engine.on_new_connection(db_conn) return db_conn @@ -268,86 +257,6 @@ def quit_with_error(error_string): sys.exit(1) -def get_version_string(): - try: - null = open(os.devnull, 'w') - cwd = os.path.dirname(os.path.abspath(__file__)) - try: - git_branch = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip() - git_branch = "b=" + git_branch - except subprocess.CalledProcessError: - git_branch = "" - - try: - git_tag = subprocess.check_output( - ['git', 'describe', '--exact-match'], - stderr=null, - cwd=cwd, - ).strip() - git_tag = "t=" + git_tag - except subprocess.CalledProcessError: - git_tag = "" - - try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip() - except subprocess.CalledProcessError: - git_commit = "" - - try: - dirty_string = "-this_is_a_dirty_checkout" - is_dirty = subprocess.check_output( - ['git', 'describe', '--dirty=' + dirty_string], - stderr=null, - cwd=cwd, - ).strip().endswith(dirty_string) - - git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: - git_dirty = "" - - if git_branch or git_tag or git_commit or git_dirty: - git_version = ",".join( - s for s in - (git_branch, git_tag, git_commit, git_dirty,) - if s - ) - - return ( - "Synapse/%s (%s)" % ( - synapse.__version__, git_version, - ) - ).encode("ascii") - except Exception as e: - logger.info("Failed to check for git repository: %s", e) - - return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") - - -def change_resource_limit(soft_file_no): - try: - soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - - if not soft_file_no: - soft_file_no = hard - - resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard)) - logger.info("Set file limit to: %d", soft_file_no) - - resource.setrlimit( - resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) - ) - except (ValueError, resource.error) as e: - logger.warn("Failed to set file or core limit: %s", e) - - def setup(config_options): """ Args: @@ -358,10 +267,9 @@ def setup(config_options): HomeServer """ try: - config = HomeServerConfig.load_config( + config = HomeServerConfig.load_or_generate_config( "Synapse Homeserver", config_options, - generate_section="Homeserver" ) except ConfigError as e: sys.stderr.write("\n" + e.message + "\n") @@ -377,7 +285,7 @@ def setup(config_options): # check any extra requirements we have now we have a config check_requirements(config) - version_string = get_version_string() + version_string = "Synapse/" + get_version_string(synapse) logger.info("Server hostname: %s", config.server_name) logger.info("Server version: %s", version_string) @@ -386,7 +294,7 @@ def setup(config_options): tls_server_context_factory = context_factory.ServerContextFactory(config) - database_engine = create_engine(config) + database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( @@ -394,7 +302,6 @@ def setup(config_options): db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, config=config, - content_addr=config.content_addr, version_string=version_string, database_engine=database_engine, ) @@ -402,8 +309,10 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn() - database_engine.prepare_database(db_conn) + db_conn = hs.get_db_conn(run_new_connection=False) + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) + hs.run_startup_checks(db_conn, database_engine) db_conn.commit() @@ -427,6 +336,8 @@ def setup(config_options): hs.get_datastore().start_doing_background_updates() hs.get_replication_layer().start_get_pdu_cache() + register_memory_metrics(hs) + reactor.callWhenRunning(start) return hs @@ -442,215 +353,13 @@ class SynapseService(service.Service): def startService(self): hs = setup(self.config) change_resource_limit(hs.config.soft_file_limit) + if hs.config.gc_thresholds: + gc.set_threshold(*hs.config.gc_thresholds) def stopService(self): return self._port.stopListening() -class SynapseRequest(Request): - def __init__(self, site, *args, **kw): - Request.__init__(self, *args, **kw) - self.site = site - self.authenticated_entity = None - self.start_time = 0 - - def __repr__(self): - # We overwrite this so that we don't log ``access_token`` - return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % ( - self.__class__.__name__, - id(self), - self.method, - self.get_redacted_uri(), - self.clientproto, - self.site.site_tag, - ) - - def get_redacted_uri(self): - return ACCESS_TOKEN_RE.sub( - r'\1<redacted>\3', - self.uri - ) - - def get_user_agent(self): - return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1] - - def started_processing(self): - self.site.access_logger.info( - "%s - %s - Received request: %s %s", - self.getClientIP(), - self.site.site_tag, - self.method, - self.get_redacted_uri() - ) - self.start_time = int(time.time() * 1000) - - def finished_processing(self): - - try: - context = LoggingContext.current_context() - ru_utime, ru_stime = context.get_resource_usage() - db_txn_count = context.db_txn_count - db_txn_duration = context.db_txn_duration - except: - ru_utime, ru_stime = (0, 0) - db_txn_count, db_txn_duration = (0, 0) - - self.site.access_logger.info( - "%s - %s - {%s}" - " Processed request: %dms (%dms, %dms) (%dms/%d)" - " %sB %s \"%s %s %s\" \"%s\"", - self.getClientIP(), - self.site.site_tag, - self.authenticated_entity, - int(time.time() * 1000) - self.start_time, - int(ru_utime * 1000), - int(ru_stime * 1000), - int(db_txn_duration * 1000), - int(db_txn_count), - self.sentLength, - self.code, - self.method, - self.get_redacted_uri(), - self.clientproto, - self.get_user_agent(), - ) - - @contextlib.contextmanager - def processing(self): - self.started_processing() - yield - self.finished_processing() - - -class XForwardedForRequest(SynapseRequest): - def __init__(self, *args, **kw): - SynapseRequest.__init__(self, *args, **kw) - - """ - Add a layer on top of another request that only uses the value of an - X-Forwarded-For header as the result of C{getClientIP}. - """ - def getClientIP(self): - """ - @return: The client address (the first address) in the value of the - I{X-Forwarded-For header}. If the header is not present, return - C{b"-"}. - """ - return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip() - - -class SynapseRequestFactory(object): - def __init__(self, site, x_forwarded_for): - self.site = site - self.x_forwarded_for = x_forwarded_for - - def __call__(self, *args, **kwargs): - if self.x_forwarded_for: - return XForwardedForRequest(self.site, *args, **kwargs) - else: - return SynapseRequest(self.site, *args, **kwargs) - - -class SynapseSite(Site): - """ - Subclass of a twisted http Site that does access logging with python's - standard logging - """ - def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs): - Site.__init__(self, resource, *args, **kwargs) - - self.site_tag = site_tag - - proxied = config.get("x_forwarded", False) - self.requestFactory = SynapseRequestFactory(self, proxied) - self.access_logger = logging.getLogger(logger_name) - - def log(self, request): - pass - - -def create_resource_tree(desired_tree, redirect_root_to_web_client=True): - """Create the resource tree for this Home Server. - - This in unduly complicated because Twisted does not support putting - child resources more than 1 level deep at a time. - - Args: - web_client (bool): True to enable the web client. - redirect_root_to_web_client (bool): True to redirect '/' to the - location of the web client. This does nothing if web_client is not - True. - """ - if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree: - root_resource = RootRedirect(WEB_CLIENT_PREFIX) - else: - root_resource = Resource() - - # ideally we'd just use getChild and putChild but getChild doesn't work - # unless you give it a Request object IN ADDITION to the name :/ So - # instead, we'll store a copy of this mapping so we can actually add - # extra resources to existing nodes. See self._resource_id for the key. - resource_mappings = {} - for full_path, res in desired_tree.items(): - logger.info("Attaching %s to path %s", res, full_path) - last_resource = root_resource - for path_seg in full_path.split('/')[1:-1]: - if path_seg not in last_resource.listNames(): - # resource doesn't exist, so make a "dummy resource" - child_resource = Resource() - last_resource.putChild(path_seg, child_resource) - res_id = _resource_id(last_resource, path_seg) - resource_mappings[res_id] = child_resource - last_resource = child_resource - else: - # we have an existing Resource, use that instead. - res_id = _resource_id(last_resource, path_seg) - last_resource = resource_mappings[res_id] - - # =========================== - # now attach the actual desired resource - last_path_seg = full_path.split('/')[-1] - - # if there is already a resource here, thieve its children and - # replace it - res_id = _resource_id(last_resource, last_path_seg) - if res_id in resource_mappings: - # there is a dummy resource at this path already, which needs - # to be replaced with the desired resource. - existing_dummy_resource = resource_mappings[res_id] - for child_name in existing_dummy_resource.listNames(): - child_res_id = _resource_id( - existing_dummy_resource, child_name - ) - child_resource = resource_mappings[child_res_id] - # steal the children - res.putChild(child_name, child_resource) - - # finally, insert the desired resource in the right place - last_resource.putChild(last_path_seg, res) - res_id = _resource_id(last_resource, last_path_seg) - resource_mappings[res_id] = res - - return root_resource - - -def _resource_id(resource, path_seg): - """Construct an arbitrary resource ID so you can retrieve the mapping - later. - - If you want to represent resource A putChild resource B with path C, - the mapping should looks like _resource_id(A,C) = B. - - Args: - resource (Resource): The *parent* Resourceb - path_seg (str): The name of the child Resource to be attached. - Returns: - str: A unique string which can be a key to the child Resource. - """ - return "%s-%s" % (resource, path_seg) - - def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -676,6 +385,8 @@ def run(hs): start_time = hs.get_clock().time() + stats = {} + @defer.inlineCallbacks def phone_stats_home(): logger.info("Gathering stats for reporting") @@ -684,7 +395,10 @@ def run(hs): if uptime < 0: uptime = 0 - stats = {} + # If the stats directory is empty then this is the first time we've + # reported stats. + first_time = not stats + stats["homeserver"] = hs.config.server_name stats["timestamp"] = now stats["uptime_seconds"] = uptime @@ -697,6 +411,25 @@ def run(hs): daily_messages = yield hs.get_datastore().count_daily_messages() if daily_messages is not None: stats["daily_messages"] = daily_messages + else: + stats.pop("daily_messages", None) + + if first_time: + # Add callbacks to report the synapse stats as metrics whenever + # prometheus requests them, typically every 30s. + # As some of the stats are expensive to calculate we only update + # them when synapse phones home to matrix.org every 24 hours. + metrics = get_metrics_for("synapse.usage") + metrics.add_callback("timestamp", lambda: stats["timestamp"]) + metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"]) + metrics.add_callback("total_users", lambda: stats["total_users"]) + metrics.add_callback("total_room_count", lambda: stats["total_room_count"]) + metrics.add_callback( + "daily_active_users", lambda: stats["daily_active_users"] + ) + metrics.add_callback( + "daily_messages", lambda: stats.get("daily_messages", 0) + ) logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: @@ -717,6 +450,8 @@ def run(hs): # sys.settrace(logcontext_tracer) with LoggingContext("run"): change_resource_limit(hs.config.soft_file_limit) + if hs.config.gc_thresholds: + gc.set_threshold(*hs.config.gc_thresholds) reactor.run() if hs.config.daemonize: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py new file mode 100644 index 0000000000..44c19a1bef --- /dev/null +++ b/synapse/app/media_repository.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.http.site import SynapseSite +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.rest.media.v0.content_repository import ContentRepoResource +from synapse.rest.media.v1.media_repository import MediaRepositoryResource +from synapse.server import HomeServer +from synapse.storage.client_ips import ClientIpStore +from synapse.storage.engines import create_engine +from synapse.storage.media_repository import MediaRepositoryStore +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string +from synapse.api.urls import ( + CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX +) +from synapse.crypto import context_factory + +from synapse import events + + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.media_repository") + + +class MediaRepositorySlavedStore( + SlavedApplicationServiceStore, + SlavedRegistrationStore, + BaseSlavedStore, + MediaRepositoryStore, + ClientIpStore, +): + pass + + +class MediaRepositoryServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "media": + media_repo = MediaRepositoryResource(self) + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse media repository now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + + while True: + try: + args = store.stream_positions() + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(5) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse media repository", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.media_repository" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = MediaRepositoryServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.get_handlers() + ss.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + ss.replicate() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-media-repository", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py new file mode 100644 index 0000000000..a0e765c54f --- /dev/null +++ b/synapse/app/pusher.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.server import HomeServer +from synapse.config._base import ConfigError +from synapse.config.logger import setup_logging +from synapse.config.homeserver import HomeServerConfig +from synapse.http.site import SynapseSite +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.storage.roommember import RoomMemberStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.pushers import SlavedPusherStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.storage.engines import create_engine +from synapse.storage import DataStore +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string + +from synapse import events + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.pusher") + + +class PusherSlaveStore( + SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, + SlavedAccountDataStore +): + update_pusher_last_stream_ordering_and_success = ( + DataStore.update_pusher_last_stream_ordering_and_success.__func__ + ) + + update_pusher_failing_since = ( + DataStore.update_pusher_failing_since.__func__ + ) + + update_pusher_last_stream_ordering = ( + DataStore.update_pusher_last_stream_ordering.__func__ + ) + + get_throttle_params_by_room = ( + DataStore.get_throttle_params_by_room.__func__ + ) + + set_throttle_params = ( + DataStore.set_throttle_params.__func__ + ) + + get_time_of_last_push_action_before = ( + DataStore.get_time_of_last_push_action_before.__func__ + ) + + get_profile_displayname = ( + DataStore.get_profile_displayname.__func__ + ) + + who_forgot_in_room = ( + RoomMemberStore.__dict__["who_forgot_in_room"] + ) + + +class PusherServer(HomeServer): + + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = PusherSlaveStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def remove_pusher(self, app_id, push_key, user_id): + http_client = self.get_simple_http_client() + replication_url = self.config.worker_replication_url + url = replication_url + "/remove_pushers" + return http_client.post_json_get_json(url, { + "remove": [{ + "app_id": app_id, + "push_key": push_key, + "user_id": user_id, + }] + }) + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse pusher now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + pusher_pool = self.get_pusherpool() + + def stop_pusher(user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + pushers_for_user = pusher_pool.pushers.get(user_id, {}) + pusher = pushers_for_user.pop(key, None) + if pusher is None: + return + logger.info("Stopping pusher %r / %r", user_id, key) + pusher.on_stop() + + def start_pusher(user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + logger.info("Starting pusher %r / %r", user_id, key) + return pusher_pool._refresh_pusher(app_id, pushkey, user_id) + + @defer.inlineCallbacks + def poke_pushers(results): + pushers_rows = set( + map(tuple, results.get("pushers", {}).get("rows", [])) + ) + deleted_pushers_rows = set( + map(tuple, results.get("deleted_pushers", {}).get("rows", [])) + ) + for row in sorted(pushers_rows | deleted_pushers_rows): + if row in deleted_pushers_rows: + user_id, app_id, pushkey = row[1:4] + stop_pusher(user_id, app_id, pushkey) + elif row in pushers_rows: + user_id = row[1] + app_id = row[5] + pushkey = row[8] + yield start_pusher(user_id, app_id, pushkey) + + stream = results.get("events") + if stream and stream["rows"]: + min_stream_id = stream["rows"][0][0] + max_stream_id = stream["position"] + preserve_fn(pusher_pool.on_new_notifications)( + min_stream_id, max_stream_id + ) + + stream = results.get("receipts") + if stream and stream["rows"]: + rows = stream["rows"] + affected_room_ids = set(row[1] for row in rows) + min_stream_id = rows[0][0] + max_stream_id = stream["position"] + preserve_fn(pusher_pool.on_new_receipts)( + min_stream_id, max_stream_id, affected_room_ids + ) + + while True: + try: + args = store.stream_positions() + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + poke_pushers(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(30) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse pusher", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.pusher" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + if config.start_pushers: + sys.stderr.write( + "\nThe pushers must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``start_pushers: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.start_pushers = True + + database_engine = create_engine(config.database_config) + + ps = PusherServer( + config.server_name, + db_config=config.database_config, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ps.setup() + ps.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ps.replicate() + ps.get_pusherpool().start() + ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-pusher", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + ps = start(sys.argv[1:]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py new file mode 100644 index 0000000000..bf1b995dc2 --- /dev/null +++ b/synapse/app/synchrotron.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.api.constants import EventTypes, PresenceState +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.events import FrozenEvent +from synapse.handlers.presence import PresenceHandler +from synapse.http.site import SynapseSite +from synapse.http.server import JsonResource +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v1 import events +from synapse.rest.client.v1.room import RoomInitialSyncRestServlet +from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.filtering import SlavedFilteringStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.server import HomeServer +from synapse.storage.client_ips import ClientIpStore +from synapse.storage.engines import create_engine +from synapse.storage.presence import PresenceStore, UserPresenceState +from synapse.storage.roommember import RoomMemberStore +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.stringutils import random_string +from synapse.util.versionstring import get_version_string + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import contextlib +import gc +import ujson as json + +logger = logging.getLogger("synapse.app.synchrotron") + + +class SynchrotronSlavedStore( + SlavedPushRuleStore, + SlavedEventStore, + SlavedReceiptsStore, + SlavedAccountDataStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + SlavedFilteringStore, + SlavedPresenceStore, + SlavedDeviceInboxStore, + RoomStore, + BaseSlavedStore, + ClientIpStore, # After BaseSlavedStore because the constructor is different +): + who_forgot_in_room = ( + RoomMemberStore.__dict__["who_forgot_in_room"] + ) + + # XXX: This is a bit broken because we don't persist the accepted list in a + # way that can be replicated. This means that we don't have a way to + # invalidate the cache correctly. + get_presence_list_accepted = PresenceStore.__dict__[ + "get_presence_list_accepted" + ] + get_presence_list_observers_accepted = PresenceStore.__dict__[ + "get_presence_list_observers_accepted" + ] + + +UPDATE_SYNCING_USERS_MS = 10 * 1000 + + +class SynchrotronPresence(object): + def __init__(self, hs): + self.is_mine_id = hs.is_mine_id + self.http_client = hs.get_simple_http_client() + self.store = hs.get_datastore() + self.user_to_num_current_syncs = {} + self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + + active_presence = self.store.take_presence_startup_info() + self.user_to_current_state = { + state.user_id: state + for state in active_presence + } + + self.process_id = random_string(16) + logger.info("Presence process_id is %r", self.process_id) + + self._sending_sync = False + self._need_to_send_sync = False + self.clock.looping_call( + self._send_syncing_users_regularly, + UPDATE_SYNCING_USERS_MS, + ) + + reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + + def set_state(self, user, state, ignore_status_msg=False): + # TODO Hows this supposed to work? + pass + + get_states = PresenceHandler.get_states.__func__ + get_state = PresenceHandler.get_state.__func__ + _get_interested_parties = PresenceHandler._get_interested_parties.__func__ + current_state_for_users = PresenceHandler.current_state_for_users.__func__ + + @defer.inlineCallbacks + def user_syncing(self, user_id, affect_presence): + if affect_presence: + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 + prev_states = yield self.current_state_for_users([user_id]) + if prev_states[user_id].state == PresenceState.OFFLINE: + # TODO: Don't block the sync request on this HTTP hit. + yield self._send_syncing_users_now() + + def _end(): + # We check that the user_id is in user_to_num_current_syncs because + # user_to_num_current_syncs may have been cleared if we are + # shutting down. + if affect_presence and user_id in self.user_to_num_current_syncs: + self.user_to_num_current_syncs[user_id] -= 1 + + @contextlib.contextmanager + def _user_syncing(): + try: + yield + finally: + _end() + + defer.returnValue(_user_syncing()) + + @defer.inlineCallbacks + def _on_shutdown(self): + # When the synchrotron is shutdown tell the master to clear the in + # progress syncs for this process + self.user_to_num_current_syncs.clear() + yield self._send_syncing_users_now() + + def _send_syncing_users_regularly(self): + # Only send an update if we aren't in the middle of sending one. + if not self._sending_sync: + preserve_fn(self._send_syncing_users_now)() + + @defer.inlineCallbacks + def _send_syncing_users_now(self): + if self._sending_sync: + # We don't want to race with sending another update. + # Instead we wait for that update to finish and send another + # update afterwards. + self._need_to_send_sync = True + return + + # Flag that we are sending an update. + self._sending_sync = True + + yield self.http_client.post_json_get_json(self.syncing_users_url, { + "process_id": self.process_id, + "syncing_users": [ + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count > 0 + ], + }) + + # Unset the flag as we are no longer sending an update. + self._sending_sync = False + if self._need_to_send_sync: + # If something happened while we were sending the update then + # we might need to send another update. + # TODO: Check if the update that was sent matches the current state + # as we only need to send an update if they are different. + self._need_to_send_sync = False + yield self._send_syncing_users_now() + + @defer.inlineCallbacks + def notify_from_replication(self, states, stream_id): + parties = yield self._get_interested_parties( + states, calculate_remote_hosts=False + ) + room_ids_to_states, users_to_states, _ = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=users_to_states.keys() + ) + + @defer.inlineCallbacks + def process_replication(self, result): + stream = result.get("presence", {"rows": []}) + states = [] + for row in stream["rows"]: + ( + position, user_id, state, last_active_ts, + last_federation_update_ts, last_user_sync_ts, status_msg, + currently_active + ) = row + state = UserPresenceState( + user_id, state, last_active_ts, + last_federation_update_ts, last_user_sync_ts, status_msg, + currently_active + ) + self.user_to_current_state[user_id] = state + states.append(state) + + if states and "position" in stream: + stream_id = int(stream["position"]) + yield self.notify_from_replication(states, stream_id) + + +class SynchrotronTyping(object): + def __init__(self, hs): + self._latest_room_serial = 0 + self._room_serials = {} + self._room_typing = {} + + def stream_positions(self): + # We must update this typing token from the response of the previous + # sync. In particular, the stream id may "reset" back to zero/a low + # value which we *must* use for the next replication request. + return {"typing": self._latest_room_serial} + + def process_replication(self, result): + stream = result.get("typing") + if stream: + self._latest_room_serial = int(stream["position"]) + + for row in stream["rows"]: + position, room_id, typing_json = row + typing = json.loads(typing_json) + self._room_serials[room_id] = position + self._room_typing[room_id] = typing + + +class SynchrotronApplicationService(object): + def notify_interested_services(self, event): + pass + + +class SynchrotronServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + sync.register_servlets(self, resource) + events.register_servlets(self, resource) + InitialSyncRestServlet(self).register(resource) + RoomInitialSyncRestServlet(self).register(resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse synchrotron now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + notifier = self.get_notifier() + presence_handler = self.get_presence_handler() + typing_handler = self.get_typing_handler() + + def notify_from_stream( + result, stream_name, stream_key, room=None, user=None + ): + stream = result.get(stream_name) + if stream: + position_index = stream["field_names"].index("position") + if room: + room_index = stream["field_names"].index(room) + if user: + user_index = stream["field_names"].index(user) + + users = () + rooms = () + for row in stream["rows"]: + position = row[position_index] + + if user: + users = (row[user_index],) + + if room: + rooms = (row[room_index],) + + notifier.on_new_event( + stream_key, position, users=users, rooms=rooms + ) + + def notify(result): + stream = result.get("events") + if stream: + max_position = stream["position"] + for row in stream["rows"]: + position = row[0] + internal = json.loads(row[1]) + event_json = json.loads(row[2]) + event = FrozenEvent(event_json, internal_metadata_dict=internal) + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + notifier.on_new_room_event( + event, position, max_position, extra_users + ) + + notify_from_stream( + result, "push_rules", "push_rules_key", user="user_id" + ) + notify_from_stream( + result, "user_account_data", "account_data_key", user="user_id" + ) + notify_from_stream( + result, "room_account_data", "account_data_key", user="user_id" + ) + notify_from_stream( + result, "tag_account_data", "account_data_key", user="user_id" + ) + notify_from_stream( + result, "receipts", "receipt_key", room="room_id" + ) + notify_from_stream( + result, "typing", "typing_key", room="room_id" + ) + notify_from_stream( + result, "to_device", "to_device_key", user="user_id" + ) + + while True: + try: + args = store.stream_positions() + args.update(typing_handler.stream_positions()) + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + typing_handler.process_replication(result) + yield presence_handler.process_replication(result) + notify(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(5) + + def build_presence_handler(self): + return SynchrotronPresence(self) + + def build_typing_handler(self): + return SynchrotronTyping(self) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse synchrotron", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.synchrotron" + + setup_logging(config.worker_log_config, config.worker_log_file) + + synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + ss = SynchrotronServer( + config.server_name, + db_config=config.database_config, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + application_service_handler=SynchrotronApplicationService(), + ) + + ss.setup() + ss.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ss.get_datastore().start_profiling() + ss.replicate() + ss.get_state_handler().start_caching() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-synchrotron", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index ab3a31d7b7..c045588866 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -14,70 +14,198 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys +import argparse +import collections +import glob import os import os.path -import subprocess import signal +import subprocess +import sys import yaml -SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] +SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] GREEN = "\x1b[1;32m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" +def write(message, colour=NORMAL, stream=sys.stdout): + if colour == NORMAL: + stream.write(message + "\n") + else: + stream.write(colour + message + NORMAL + "\n") + + def start(configfile): - print ("Starting ...") + write("Starting ...") args = SYNAPSE args.extend(["--daemonize", "-c", configfile]) try: subprocess.check_call(args) - print (GREEN + "started" + NORMAL) + write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) + except subprocess.CalledProcessError as e: + write( + "error starting (exit code: %d); see above for logs" % e.returncode, + colour=RED, + ) + + +def start_worker(app, configfile, worker_configfile): + args = [ + "python", "-B", + "-m", app, + "-c", configfile, + "-c", worker_configfile + ] + + try: + subprocess.check_call(args) + write("started %s(%r)" % (app, worker_configfile), colour=GREEN) except subprocess.CalledProcessError as e: - print ( - RED + - "error starting (exit code: %d); see above for logs" % e.returncode + - NORMAL + write( + "error starting %s(%r) (exit code: %d); see above for logs" % ( + app, worker_configfile, e.returncode, + ), + colour=RED, ) -def stop(pidfile): +def stop(pidfile, app): if os.path.exists(pidfile): pid = int(open(pidfile).read()) os.kill(pid, signal.SIGTERM) - print (GREEN + "stopped" + NORMAL) + write("stopped %s" % (app,), colour=GREEN) + + +Worker = collections.namedtuple("Worker", [ + "app", "configfile", "pidfile", "cache_factor" +]) def main(): - configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml" + + parser = argparse.ArgumentParser() + + parser.add_argument( + "action", + choices=["start", "stop", "restart"], + help="whether to start, stop or restart the synapse", + ) + parser.add_argument( + "configfile", + nargs="?", + default="homeserver.yaml", + help="the homeserver config file, defaults to homserver.yaml", + ) + parser.add_argument( + "-w", "--worker", + metavar="WORKERCONFIG", + help="start or stop a single worker", + ) + parser.add_argument( + "-a", "--all-processes", + metavar="WORKERCONFIGDIR", + help="start or stop all the workers in the given directory" + " and the main synapse process", + ) + + options = parser.parse_args() + + if options.worker and options.all_processes: + write( + 'Cannot use "--worker" with "--all-processes"', + stream=sys.stderr + ) + sys.exit(1) + + configfile = options.configfile if not os.path.exists(configfile): - sys.stderr.write( + write( "No config file found\n" "To generate a config file, run '%s -c %s --generate-config" " --server-name=<server name>'\n" % ( - " ".join(SYNAPSE), configfile - ) + " ".join(SYNAPSE), options.configfile + ), + stream=sys.stderr, ) sys.exit(1) - config = yaml.load(open(configfile)) + with open(configfile) as stream: + config = yaml.load(stream) + pidfile = config["pid_file"] + cache_factor = config.get("synctl_cache_factor") + start_stop_synapse = True - action = sys.argv[1] if sys.argv[1:] else "usage" - if action == "start": - start(configfile) - elif action == "stop": - stop(pidfile) - elif action == "restart": - stop(pidfile) - start(configfile) - else: - sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],)) - sys.exit(1) + if cache_factor: + os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) + + worker_configfiles = [] + if options.worker: + start_stop_synapse = False + worker_configfile = options.worker + if not os.path.exists(worker_configfile): + write( + "No worker config found at %r" % (worker_configfile,), + stream=sys.stderr, + ) + sys.exit(1) + worker_configfiles.append(worker_configfile) + + if options.all_processes: + worker_configdir = options.all_processes + if not os.path.isdir(worker_configdir): + write( + "No worker config directory found at %r" % (worker_configdir,), + stream=sys.stderr, + ) + sys.exit(1) + worker_configfiles.extend(sorted(glob.glob( + os.path.join(worker_configdir, "*.yaml") + ))) + + workers = [] + for worker_configfile in worker_configfiles: + with open(worker_configfile) as stream: + worker_config = yaml.load(stream) + worker_app = worker_config["worker_app"] + worker_pidfile = worker_config["worker_pid_file"] + worker_daemonize = worker_config["worker_daemonize"] + assert worker_daemonize # TODO print something more user friendly + worker_cache_factor = worker_config.get("synctl_cache_factor") + workers.append(Worker( + worker_app, worker_configfile, worker_pidfile, worker_cache_factor, + )) + + action = options.action + + if action == "stop" or action == "restart": + for worker in workers: + stop(worker.pidfile, worker.app) + + if start_stop_synapse: + stop(pidfile, "synapse.app.homeserver") + + # TODO: Wait for synapse to actually shutdown before starting it again + + if action == "start" or action == "restart": + if start_stop_synapse: + start(configfile) + + for worker in workers: + if worker.cache_factor: + os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) + + start_worker(worker.app, configfile, worker.configfile) + + if cache_factor: + os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) + else: + os.environ.pop("SYNAPSE_CACHE_FACTOR", None) if __name__ == "__main__": diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index f7178ea0d3..91471f7e89 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from synapse.api.constants import EventTypes +from twisted.internet import defer + import logging import re @@ -79,7 +81,7 @@ class ApplicationService(object): NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] def __init__(self, token, url=None, namespaces=None, hs_token=None, - sender=None, id=None): + sender=None, id=None, protocols=None, rate_limited=True): self.token = token self.url = url self.hs_token = hs_token @@ -87,6 +89,14 @@ class ApplicationService(object): self.namespaces = self._check_namespaces(namespaces) self.id = id + # .protocols is a publicly visible field + if protocols: + self.protocols = set(protocols) + else: + self.protocols = set() + + self.rate_limited = rate_limited + def _check_namespaces(self, namespaces): # Sanity check that it is of the form: # { @@ -138,65 +148,66 @@ class ApplicationService(object): return regex_obj["exclusive"] return False - def _matches_user(self, event, member_list): - if (hasattr(event, "sender") and - self.is_interested_in_user(event.sender)): - return True + @defer.inlineCallbacks + def _matches_user(self, event, store): + if not event: + defer.returnValue(False) + + if self.is_interested_in_user(event.sender): + defer.returnValue(True) # also check m.room.member state key - if (hasattr(event, "type") and event.type == EventTypes.Member - and hasattr(event, "state_key") - and self.is_interested_in_user(event.state_key)): - return True + if (event.type == EventTypes.Member and + self.is_interested_in_user(event.state_key)): + defer.returnValue(True) + + if not store: + defer.returnValue(False) + + member_list = yield store.get_users_in_room(event.room_id) + # check joined member events for user_id in member_list: if self.is_interested_in_user(user_id): - return True - return False + defer.returnValue(True) + defer.returnValue(False) def _matches_room_id(self, event): if hasattr(event, "room_id"): return self.is_interested_in_room(event.room_id) return False - def _matches_aliases(self, event, alias_list): + @defer.inlineCallbacks + def _matches_aliases(self, event, store): + if not store or not event: + defer.returnValue(False) + + alias_list = yield store.get_aliases_for_room(event.room_id) for alias in alias_list: if self.is_interested_in_alias(alias): - return True - return False + defer.returnValue(True) + defer.returnValue(False) - def is_interested(self, event, restrict_to=None, aliases_for_event=None, - member_list=None): + @defer.inlineCallbacks + def is_interested(self, event, store=None): """Check if this service is interested in this event. Args: event(Event): The event to check. - restrict_to(str): The namespace to restrict regex tests to. - aliases_for_event(list): A list of all the known room aliases for - this event. - member_list(list): A list of all joined user_ids in this room. + store(DataStore) Returns: bool: True if this service would like to know about this event. """ - if aliases_for_event is None: - aliases_for_event = [] - if member_list is None: - member_list = [] - - if restrict_to and restrict_to not in ApplicationService.NS_LIST: - # this is a programming error, so fail early and raise a general - # exception - raise Exception("Unexpected restrict_to value: %s". restrict_to) - - if not restrict_to: - return (self._matches_user(event, member_list) - or self._matches_aliases(event, aliases_for_event) - or self._matches_room_id(event)) - elif restrict_to == ApplicationService.NS_ALIASES: - return self._matches_aliases(event, aliases_for_event) - elif restrict_to == ApplicationService.NS_ROOMS: - return self._matches_room_id(event) - elif restrict_to == ApplicationService.NS_USERS: - return self._matches_user(event, member_list) + # Do cheap checks first + if self._matches_room_id(event): + defer.returnValue(True) + + if (yield self._matches_aliases(event, store)): + defer.returnValue(True) + + if (yield self._matches_user(event, store)): + defer.returnValue(True) + + defer.returnValue(False) def is_interested_in_user(self, user_id): return ( @@ -216,11 +227,17 @@ class ApplicationService(object): or user_id == self.sender ) + def is_interested_in_protocol(self, protocol): + return protocol in self.protocols + def is_exclusive_alias(self, alias): return self._is_exclusive(ApplicationService.NS_ALIASES, alias) def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) + def is_rate_limited(self): + return self.rate_limited + def __str__(self): return "ApplicationService: %s" % (self.__dict__,) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bc90605324..b0eb0c6d9d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,9 +14,11 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.util.caches.response_cache import ResponseCache import logging import urllib @@ -24,6 +26,42 @@ import urllib logger = logging.getLogger(__name__) +HOUR_IN_MS = 60 * 60 * 1000 + + +APP_SERVICE_PREFIX = "/_matrix/app/unstable" + + +def _is_valid_3pe_metadata(info): + if "instances" not in info: + return False + if not isinstance(info["instances"], list): + return False + return True + + +def _is_valid_3pe_result(r, field): + if not isinstance(r, dict): + return False + + for k in (field, "protocol"): + if k not in r: + return False + if not isinstance(r[k], str): + return False + + if "fields" not in r: + return False + fields = r["fields"] + if not isinstance(fields, dict): + return False + for k in fields.keys(): + if not isinstance(fields[k], str): + return False + + return True + + class ApplicationServiceApi(SimpleHttpClient): """This class manages HS -> AS communications, including querying and pushing. @@ -33,8 +71,12 @@ class ApplicationServiceApi(SimpleHttpClient): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() + self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS) + @defer.inlineCallbacks def query_user(self, service, user_id): + if service.url is None: + defer.returnValue(False) uri = service.url + ("/users/%s" % urllib.quote(user_id)) response = None try: @@ -54,6 +96,8 @@ class ApplicationServiceApi(SimpleHttpClient): @defer.inlineCallbacks def query_alias(self, service, alias): + if service.url is None: + defer.returnValue(False) uri = service.url + ("/rooms/%s" % urllib.quote(alias)) response = None try: @@ -72,7 +116,83 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(False) @defer.inlineCallbacks + def query_3pe(self, service, kind, protocol, fields): + if kind == ThirdPartyEntityKind.USER: + required_field = "userid" + elif kind == ThirdPartyEntityKind.LOCATION: + required_field = "alias" + else: + raise ValueError( + "Unrecognised 'kind' argument %r to query_3pe()", kind + ) + if service.url is None: + defer.returnValue([]) + + uri = "%s%s/thirdparty/%s/%s" % ( + service.url, + APP_SERVICE_PREFIX, + kind, + urllib.quote(protocol) + ) + try: + response = yield self.get_json(uri, fields) + if not isinstance(response, list): + logger.warning( + "query_3pe to %s returned an invalid response %r", + uri, response + ) + defer.returnValue([]) + + ret = [] + for r in response: + if _is_valid_3pe_result(r, field=required_field): + ret.append(r) + else: + logger.warning( + "query_3pe to %s returned an invalid result %r", + uri, r + ) + + defer.returnValue(ret) + except Exception as ex: + logger.warning("query_3pe to %s threw exception %s", uri, ex) + defer.returnValue([]) + + def get_3pe_protocol(self, service, protocol): + if service.url is None: + defer.returnValue({}) + + @defer.inlineCallbacks + def _get(): + uri = "%s%s/thirdparty/protocol/%s" % ( + service.url, + APP_SERVICE_PREFIX, + urllib.quote(protocol) + ) + try: + info = yield self.get_json(uri, {}) + + if not _is_valid_3pe_metadata(info): + logger.warning("query_3pe_protocol to %s did not return a" + " valid result", uri) + defer.returnValue(None) + + defer.returnValue(info) + except Exception as ex: + logger.warning("query_3pe_protocol to %s threw exception %s", + uri, ex) + defer.returnValue(None) + + key = (service.id, protocol) + return self.protocol_meta_cache.get(key) or ( + self.protocol_meta_cache.set(key, _get()) + ) + + @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): + if service.url is None: + defer.returnValue(True) + events = self._serialize(events) if txn_id is None: @@ -100,11 +220,6 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("push_bulk to %s threw exception %s", uri, ex) defer.returnValue(False) - @defer.inlineCallbacks - def push(self, service, event, txn_id=None): - response = yield self.push_bulk(service, [event], txn_id) - defer.returnValue(response) - def _serialize(self, events): time_now = self.clock.time_msec() return [ diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 47a4e9f864..68a9de17b8 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,32 +48,35 @@ UP & quit +---------- YES SUCCESS This is all tied together by the AppServiceScheduler which DIs the required components. """ +from twisted.internet import defer from synapse.appservice import ApplicationServiceState -from twisted.internet import defer +from synapse.util.logcontext import preserve_fn +from synapse.util.metrics import Measure + import logging logger = logging.getLogger(__name__) -class AppServiceScheduler(object): +class ApplicationServiceScheduler(object): """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this case is a simple array. """ - def __init__(self, clock, store, as_api): - self.clock = clock - self.store = store - self.as_api = as_api + def __init__(self, hs): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.as_api = hs.get_application_service_api() def create_recoverer(service, callback): - return _Recoverer(clock, store, as_api, service, callback) + return _Recoverer(self.clock, self.store, self.as_api, service, callback) self.txn_ctrl = _TransactionController( - clock, store, as_api, create_recoverer + self.clock, self.store, self.as_api, create_recoverer ) - self.queuer = _ServiceQueuer(self.txn_ctrl) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) @defer.inlineCallbacks def start(self): @@ -94,38 +97,36 @@ class _ServiceQueuer(object): this schedules any other events in the queue to run. """ - def __init__(self, txn_ctrl): + def __init__(self, txn_ctrl, clock): self.queued_events = {} # dict of {service_id: [events]} - self.pending_requests = {} # dict of {service_id: Deferred} + self.requests_in_flight = set() self.txn_ctrl = txn_ctrl + self.clock = clock def enqueue(self, service, event): # if this service isn't being sent something - if not self.pending_requests.get(service.id): - self._send_request(service, [event]) - else: - # add to queue for this service - if service.id not in self.queued_events: - self.queued_events[service.id] = [] - self.queued_events[service.id].append(event) - - def _send_request(self, service, events): - # send request and add callbacks - d = self.txn_ctrl.send(service, events) - d.addBoth(self._on_request_finish) - d.addErrback(self._on_request_fail) - self.pending_requests[service.id] = d - - def _on_request_finish(self, service): - self.pending_requests[service.id] = None - # if there are queued events, then send them. - if (service.id in self.queued_events - and len(self.queued_events[service.id]) > 0): - self._send_request(service, self.queued_events[service.id]) - self.queued_events[service.id] = [] - - def _on_request_fail(self, err): - logger.error("AS request failed: %s", err) + self.queued_events.setdefault(service.id, []).append(event) + preserve_fn(self._send_request)(service) + + @defer.inlineCallbacks + def _send_request(self, service): + if service.id in self.requests_in_flight: + return + + self.requests_in_flight.add(service.id) + try: + while True: + events = self.queued_events.pop(service.id, []) + if not events: + return + + with Measure(self.clock, "servicequeuer.send"): + try: + yield self.txn_ctrl.send(service, events) + except: + logger.exception("AS request failed") + finally: + self.requests_in_flight.discard(service.id) class _TransactionController(object): @@ -149,14 +150,12 @@ class _TransactionController(object): if service_is_up: sent = yield txn.send(self.as_api) if sent: - txn.complete(self.store) + yield txn.complete(self.store) else: - self._start_recoverer(service) + preserve_fn(self._start_recoverer)(service) except Exception as e: logger.exception(e) - self._start_recoverer(service) - # request has finished - defer.returnValue(service) + preserve_fn(self._start_recoverer)(service) @defer.inlineCallbacks def on_recovered(self, recoverer): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 7449f36491..1ab5593c6e 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -64,11 +64,12 @@ class Config(object): if isinstance(value, int) or isinstance(value, long): return value second = 1000 - hour = 60 * 60 * second + minute = 60 * second + hour = 60 * minute day = 24 * hour week = 7 * day year = 365 * day - sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year} + sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year} size = 1 suffix = value[-1] if suffix in sizes: @@ -157,9 +158,40 @@ class Config(object): return default_config, config @classmethod - def load_config(cls, description, argv, generate_section=None): + def load_config(cls, description, argv): + config_parser = argparse.ArgumentParser( + description=description, + ) + config_parser.add_argument( + "-c", "--config-path", + action="append", + metavar="CONFIG_FILE", + help="Specify config file. Can be given multiple times and" + " may specify directories containing *.yaml files." + ) + + config_parser.add_argument( + "--keys-directory", + metavar="DIRECTORY", + help="Where files such as certs and signing keys are stored when" + " their location is given explicitly in the config." + " Defaults to the directory containing the last config file", + ) + + config_args = config_parser.parse_args(argv) + + config_files = find_config_files(search_paths=config_args.config_path) + obj = cls() + obj.read_config_files( + config_files, + keys_directory=config_args.keys_directory, + generate_keys=False, + ) + return obj + @classmethod + def load_or_generate_config(cls, description, argv): config_parser = argparse.ArgumentParser(add_help=False) config_parser.add_argument( "-c", "--config-path", @@ -176,7 +208,7 @@ class Config(object): config_parser.add_argument( "--report-stats", action="store", - help="Stuff", + help="Whether the generated config reports anonymized usage statistics", choices=["yes", "no"] ) config_parser.add_argument( @@ -197,36 +229,11 @@ class Config(object): ) config_args, remaining_args = config_parser.parse_known_args(argv) + config_files = find_config_files(search_paths=config_args.config_path) + generate_keys = config_args.generate_keys - config_files = [] - if config_args.config_path: - for config_path in config_args.config_path: - if os.path.isdir(config_path): - # We accept specifying directories as config paths, we search - # inside that directory for all files matching *.yaml, and then - # we apply them in *sorted* order. - files = [] - for entry in os.listdir(config_path): - entry_path = os.path.join(config_path, entry) - if not os.path.isfile(entry_path): - print ( - "Found subdirectory in config directory: %r. IGNORING." - ) % (entry_path, ) - continue - - if not entry.endswith(".yaml"): - print ( - "Found file in config directory that does not" - " end in '.yaml': %r. IGNORING." - ) % (entry_path, ) - continue - - files.append(entry_path) - - config_files.extend(sorted(files)) - else: - config_files.append(config_path) + obj = cls() if config_args.generate_config: if config_args.report_stats is None: @@ -299,28 +306,43 @@ class Config(object): " -c CONFIG-FILE\"" ) - if config_args.keys_directory: - config_dir_path = config_args.keys_directory - else: - config_dir_path = os.path.dirname(config_args.config_path[-1]) - config_dir_path = os.path.abspath(config_dir_path) + obj.read_config_files( + config_files, + keys_directory=config_args.keys_directory, + generate_keys=generate_keys, + ) + + if generate_keys: + return None + + obj.invoke_all("read_arguments", args) + + return obj + + def read_config_files(self, config_files, keys_directory=None, + generate_keys=False): + if not keys_directory: + keys_directory = os.path.dirname(config_files[-1]) + + config_dir_path = os.path.abspath(keys_directory) specified_config = {} for config_file in config_files: - yaml_config = cls.read_config_file(config_file) + yaml_config = self.read_config_file(config_file) specified_config.update(yaml_config) if "server_name" not in specified_config: raise ConfigError(MISSING_SERVER_NAME) server_name = specified_config["server_name"] - _, config = obj.generate_config( + _, config = self.generate_config( config_dir_path=config_dir_path, server_name=server_name, is_generating_file=False, ) config.pop("log_config") config.update(specified_config) + if "report_stats" not in config: raise ConfigError( MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + @@ -328,11 +350,51 @@ class Config(object): ) if generate_keys: - obj.invoke_all("generate_files", config) + self.invoke_all("generate_files", config) return - obj.invoke_all("read_config", config) - - obj.invoke_all("read_arguments", args) - - return obj + self.invoke_all("read_config", config) + + +def find_config_files(search_paths): + """Finds config files using a list of search paths. If a path is a file + then that file path is added to the list. If a search path is a directory + then all the "*.yaml" files in that directory are added to the list in + sorted order. + + Args: + search_paths(list(str)): A list of paths to search. + + Returns: + list(str): A list of file paths. + """ + + config_files = [] + if search_paths: + for config_path in search_paths: + if os.path.isdir(config_path): + # We accept specifying directories as config paths, we search + # inside that directory for all files matching *.yaml, and then + # we apply them in *sorted* order. + files = [] + for entry in os.listdir(config_path): + entry_path = os.path.join(config_path, entry) + if not os.path.isfile(entry_path): + print ( + "Found subdirectory in config directory: %r. IGNORING." + ) % (entry_path, ) + continue + + if not entry.endswith(".yaml"): + print ( + "Found file in config directory that does not" + " end in '.yaml': %r. IGNORING." + ) % (entry_path, ) + continue + + files.append(entry_path) + + config_files.extend(sorted(files)) + else: + config_files.append(config_path) + return config_files diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 3bed542c4f..82c50b8240 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -12,16 +12,153 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError + +from synapse.appservice import ApplicationService +from synapse.types import UserID + +import urllib +import yaml +import logging + +logger = logging.getLogger(__name__) class AppServiceConfig(Config): def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) + self.notify_appservices = config.get("notify_appservices", True) def default_config(cls, **kwargs): return """\ # A list of application service config file to use app_service_config_files: [] """ + + +def load_appservices(hostname, config_files): + """Returns a list of Application Services from the config files.""" + if not isinstance(config_files, list): + logger.warning( + "Expected %s to be a list of AS config files.", config_files + ) + return [] + + # Dicts of value -> filename + seen_as_tokens = {} + seen_ids = {} + + appservices = [] + + for config_file in config_files: + try: + with open(config_file, 'r') as f: + appservice = _load_appservice( + hostname, yaml.load(f), config_file + ) + if appservice.id in seen_ids: + raise ConfigError( + "Cannot reuse ID across application services: " + "%s (files: %s, %s)" % ( + appservice.id, config_file, seen_ids[appservice.id], + ) + ) + seen_ids[appservice.id] = config_file + if appservice.token in seen_as_tokens: + raise ConfigError( + "Cannot reuse as_token across application services: " + "%s (files: %s, %s)" % ( + appservice.token, + config_file, + seen_as_tokens[appservice.token], + ) + ) + seen_as_tokens[appservice.token] = config_file + logger.info("Loaded application service: %s", appservice) + appservices.append(appservice) + except Exception as e: + logger.error("Failed to load appservice from '%s'", config_file) + logger.exception(e) + raise + return appservices + + +def _load_appservice(hostname, as_info, config_filename): + required_string_fields = [ + "id", "as_token", "hs_token", "sender_localpart" + ] + for field in required_string_fields: + if not isinstance(as_info.get(field), basestring): + raise KeyError("Required string field: '%s' (%s)" % ( + field, config_filename, + )) + + # 'url' must either be a string or explicitly null, not missing + # to avoid accidentally turning off push for ASes. + if (not isinstance(as_info.get("url"), basestring) and + as_info.get("url", "") is not None): + raise KeyError( + "Required string field or explicit null: 'url' (%s)" % (config_filename,) + ) + + localpart = as_info["sender_localpart"] + if urllib.quote(localpart) != localpart: + raise ValueError( + "sender_localpart needs characters which are not URL encoded." + ) + user = UserID(localpart, hostname) + user_id = user.to_string() + + # Rate limiting for users of this AS is on by default (excludes sender) + rate_limited = True + if isinstance(as_info.get("rate_limited"), bool): + rate_limited = as_info.get("rate_limited") + + # namespace checks + if not isinstance(as_info.get("namespaces"), dict): + raise KeyError("Requires 'namespaces' object.") + for ns in ApplicationService.NS_LIST: + # specific namespaces are optional + if ns in as_info["namespaces"]: + # expect a list of dicts with exclusive and regex keys + for regex_obj in as_info["namespaces"][ns]: + if not isinstance(regex_obj, dict): + raise ValueError( + "Expected namespace entry in %s to be an object," + " but got %s", ns, regex_obj + ) + if not isinstance(regex_obj.get("regex"), basestring): + raise ValueError( + "Missing/bad type 'regex' key in %s", regex_obj + ) + if not isinstance(regex_obj.get("exclusive"), bool): + raise ValueError( + "Missing/bad type 'exclusive' key in %s", regex_obj + ) + # protocols check + protocols = as_info.get("protocols") + if protocols: + # Because strings are lists in python + if isinstance(protocols, str) or not isinstance(protocols, list): + raise KeyError("Optional 'protocols' must be a list if present.") + for p in protocols: + if not isinstance(p, str): + raise KeyError("Bad value for 'protocols' item") + + if as_info["url"] is None: + logger.info( + "(%s) Explicitly empty 'url' provided. This application service" + " will not receive events or queries.", + config_filename, + ) + return ApplicationService( + token=as_info["as_token"], + url=as_info["url"], + namespaces=as_info["namespaces"], + hs_token=as_info["hs_token"], + sender=user_id, + id=as_info["id"], + protocols=protocols, + rate_limited=rate_limited + ) diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index b54dbabbee..7ba0c2de6a 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -27,6 +27,7 @@ class CaptchaConfig(Config): def default_config(self, **kwargs): return """\ ## Captcha ## + # See docs/CAPTCHA_SETUP for full details of configuring this. # This Home Server's ReCAPTCHA public key. recaptcha_public_key: "YOUR_PUBLIC_KEY" diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py new file mode 100644 index 0000000000..a187161272 --- /dev/null +++ b/synapse/config/emailconfig.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file can't be called email.py because if it is, we cannot: +import email.utils + +from ._base import Config + + +class EmailConfig(Config): + def read_config(self, config): + self.email_enable_notifs = False + + email_config = config.get("email", {}) + self.email_enable_notifs = email_config.get("enable_notifs", False) + + if self.email_enable_notifs: + # make sure we can import the required deps + import jinja2 + import bleach + # prevent unused warnings + jinja2 + bleach + + required = [ + "smtp_host", + "smtp_port", + "notif_from", + "template_dir", + "notif_template_html", + "notif_template_text", + ] + + missing = [] + for k in required: + if k not in email_config: + missing.append(k) + + if (len(missing) > 0): + raise RuntimeError( + "email.enable_notifs is True but required keys are missing: %s" % + (", ".join(["email." + k for k in missing]),) + ) + + if config.get("public_baseurl") is None: + raise RuntimeError( + "email.enable_notifs is True but no public_baseurl is set" + ) + + self.email_smtp_host = email_config["smtp_host"] + self.email_smtp_port = email_config["smtp_port"] + self.email_notif_from = email_config["notif_from"] + self.email_template_dir = email_config["template_dir"] + self.email_notif_template_html = email_config["notif_template_html"] + self.email_notif_template_text = email_config["notif_template_text"] + self.email_notif_for_new_users = email_config.get( + "notif_for_new_users", True + ) + if "app_name" in email_config: + self.email_app_name = email_config["app_name"] + else: + self.email_app_name = "Matrix" + + # make sure it's valid + parsed = email.utils.parseaddr(self.email_notif_from) + if parsed[1] == '': + raise RuntimeError("Invalid notif_from address") + else: + self.email_enable_notifs = False + # Not much point setting defaults for the rest: it would be an + # error for them to be used. + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # Enable sending emails for notification events + #email: + # enable_notifs: false + # smtp_host: "localhost" + # smtp_port: 25 + # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>" + # app_name: Matrix + # template_dir: res/templates + # notif_template_html: notif_mail.html + # notif_template_text: notif_mail.txt + # notif_for_new_users: True + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index a08c170f1d..0f890fc04a 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -29,13 +29,18 @@ from .key import KeyConfig from .saml2 import SAML2Config from .cas import CasConfig from .password import PasswordConfig +from .jwt import JWTConfig +from .password_auth_providers import PasswordAuthProviderConfig +from .emailconfig import EmailConfig +from .workers import WorkerConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - PasswordConfig,): + JWTConfig, PasswordConfig, EmailConfig, + WorkerConfig, PasswordAuthProviderConfig,): pass diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py new file mode 100644 index 0000000000..47f145c589 --- /dev/null +++ b/synapse/config/jwt.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 Niklas Riekenbrauck +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config, ConfigError + + +MISSING_JWT = ( + """Missing jwt library. This is required for jwt login. + + Install by running: + pip install pyjwt + """ +) + + +class JWTConfig(Config): + def read_config(self, config): + jwt_config = config.get("jwt_config", None) + if jwt_config: + self.jwt_enabled = jwt_config.get("enabled", False) + self.jwt_secret = jwt_config["secret"] + self.jwt_algorithm = jwt_config["algorithm"] + + try: + import jwt + jwt # To stop unused lint. + except ImportError: + raise ConfigError(MISSING_JWT) + else: + self.jwt_enabled = False + self.jwt_secret = None + self.jwt_algorithm = None + + def default_config(self, **kwargs): + return """\ + # The JWT needs to contain a globally unique "sub" (subject) claim. + # + # jwt_config: + # enabled: true + # secret: "a secret" + # algorithm: "HS256" + """ diff --git a/synapse/config/key.py b/synapse/config/key.py index a072aec714..6ee643793e 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -57,6 +57,8 @@ class KeyConfig(Config): seed = self.signing_key[0].seed self.macaroon_secret_key = hashlib.sha256(seed) + self.expire_access_token = config.get("expire_access_token", False) + def default_config(self, config_dir_path, server_name, is_generating_file=False, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) @@ -69,6 +71,9 @@ class KeyConfig(Config): return """\ macaroon_secret_key: "%(macaroon_secret_key)s" + # Used to enable access token expiration. + expire_access_token: False + ## Signing Keys ## # Path to the signing key to sign messages with diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 5047db898f..ec72c95436 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -50,6 +50,7 @@ handlers: console: class: logging.StreamHandler formatter: precise + filters: [context] loggers: synapse: @@ -126,54 +127,58 @@ class LoggingConfig(Config): ) def setup_logging(self): - log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" - " - %(message)s" - ) - if self.log_config is None: - - level = logging.INFO - level_for_storage = logging.INFO - if self.verbosity: - level = logging.DEBUG - if self.verbosity > 1: - level_for_storage = logging.DEBUG - - # FIXME: we need a logging.WARN for a -q quiet option - logger = logging.getLogger('') - logger.setLevel(level) - - logging.getLogger('synapse.storage').setLevel(level_for_storage) - - formatter = logging.Formatter(log_format) - if self.log_file: - # TODO: Customisable file size / backup count - handler = logging.handlers.RotatingFileHandler( - self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 - ) - - def sighup(signum, stack): - logger.info("Closing log file due to SIGHUP") - handler.doRollover() - logger.info("Opened new log file due to SIGHUP") - - # TODO(paul): obviously this is a terrible mechanism for - # stealing SIGHUP, because it means no other part of synapse - # can use it instead. If we want to catch SIGHUP anywhere - # else as well, I'd suggest we find a nicer way to broadcast - # it around. - if getattr(signal, "SIGHUP"): - signal.signal(signal.SIGHUP, sighup) - else: - handler = logging.StreamHandler() - handler.setFormatter(formatter) - - handler.addFilter(LoggingContextFilter(request="")) - - logger.addHandler(handler) + setup_logging(self.log_config, self.log_file, self.verbosity) + + +def setup_logging(log_config=None, log_file=None, verbosity=None): + log_format = ( + "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" + " - %(message)s" + ) + if log_config is None: + + level = logging.INFO + level_for_storage = logging.INFO + if verbosity: + level = logging.DEBUG + if verbosity > 1: + level_for_storage = logging.DEBUG + + # FIXME: we need a logging.WARN for a -q quiet option + logger = logging.getLogger('') + logger.setLevel(level) + + logging.getLogger('synapse.storage').setLevel(level_for_storage) + + formatter = logging.Formatter(log_format) + if log_file: + # TODO: Customisable file size / backup count + handler = logging.handlers.RotatingFileHandler( + log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 + ) + + def sighup(signum, stack): + logger.info("Closing log file due to SIGHUP") + handler.doRollover() + logger.info("Opened new log file due to SIGHUP") + + # TODO(paul): obviously this is a terrible mechanism for + # stealing SIGHUP, because it means no other part of synapse + # can use it instead. If we want to catch SIGHUP anywhere + # else as well, I'd suggest we find a nicer way to broadcast + # it around. + if getattr(signal, "SIGHUP"): + signal.signal(signal.SIGHUP, sighup) else: - with open(self.log_config, 'r') as f: - logging.config.dictConfig(yaml.load(f)) + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + handler.addFilter(LoggingContextFilter(request="")) + + logger.addHandler(handler) + else: + with open(log_config, 'r') as f: + logging.config.dictConfig(yaml.load(f)) - observer = PythonLoggingObserver() - observer.start() + observer = PythonLoggingObserver() + observer.start() diff --git a/synapse/config/password.py b/synapse/config/password.py index dec801ef41..a4bd171399 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -23,10 +23,14 @@ class PasswordConfig(Config): def read_config(self, config): password_config = config.get("password_config", {}) self.password_enabled = password_config.get("enabled", True) + self.password_pepper = password_config.get("pepper", "") def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable password for login. password_config: enabled: true + # Uncomment and change to a secret random string for extra security. + # DO NOT CHANGE THIS AFTER INITIAL SETUP! + #pepper: "" """ diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py new file mode 100644 index 0000000000..83762d089a --- /dev/null +++ b/synapse/config/password_auth_providers.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 Openmarket +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config, ConfigError + +import importlib + + +class PasswordAuthProviderConfig(Config): + def read_config(self, config): + self.password_providers = [] + + # We want to be backwards compatible with the old `ldap_config` + # param. + ldap_config = config.get("ldap_config", {}) + self.ldap_enabled = ldap_config.get("enabled", False) + if self.ldap_enabled: + from ldap_auth_provider import LdapAuthProvider + parsed_config = LdapAuthProvider.parse_config(ldap_config) + self.password_providers.append((LdapAuthProvider, parsed_config)) + + providers = config.get("password_providers", []) + for provider in providers: + # This is for backwards compat when the ldap auth provider resided + # in this package. + if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": + from ldap_auth_provider import LdapAuthProvider + provider_class = LdapAuthProvider + else: + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. + module, clz = provider['module'].rsplit(".", 1) + module = importlib.import_module(module) + provider_class = getattr(module, clz) + + try: + provider_config = provider_class.parse_config(provider["config"]) + except Exception as e: + raise ConfigError( + "Failed to parse config for %r: %r" % (provider['module'], e) + ) + self.password_providers.append((provider_class, provider_config)) + + def default_config(self, **kwargs): + return """\ + # password_providers: + # - module: "ldap_auth_provider.LdapAuthProvider" + # config: + # enabled: true + # uri: "ldap://ldap.example.com:389" + # start_tls: true + # base: "ou=users,dc=example,dc=com" + # attributes: + # uid: "cn" + # mail: "email" + # name: "givenName" + # #bind_dn: + # #bind_password: + # #filter: "(objectClass=posixAccount)" + """ diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2e96c09013..2c6f57168e 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -13,9 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError from collections import namedtuple + +MISSING_NETADDR = ( + "Missing netaddr library. This is required for URL preview API." +) + +MISSING_LXML = ( + """Missing lxml library. This is required for URL preview API. + + Install by running: + pip install lxml + + Requires libxslt1-dev system package. + """ +) + + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -23,7 +39,7 @@ ThumbnailRequirement = namedtuple( def parse_thumbnail_requirements(thumbnail_sizes): """ Takes a list of dictionaries with "width", "height", and "method" keys - and creates a map from image media types to the thumbnail size, thumnailing + and creates a map from image media types to the thumbnail size, thumbnailing method, and thumbnail media type to precalculate Args: @@ -53,12 +69,44 @@ class ContentRepositoryConfig(Config): def read_config(self, config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) + self.max_spider_size = self.parse_size(config["max_spider_size"]) self.media_store_path = self.ensure_directory(config["media_store_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] self.thumbnail_requirements = parse_thumbnail_requirements( config["thumbnail_sizes"] ) + self.url_preview_enabled = config.get("url_preview_enabled", False) + if self.url_preview_enabled: + try: + import lxml + lxml # To stop unused lint. + except ImportError: + raise ConfigError(MISSING_LXML) + + try: + from netaddr import IPSet + except ImportError: + raise ConfigError(MISSING_NETADDR) + + if "url_preview_ip_range_blacklist" in config: + self.url_preview_ip_range_blacklist = IPSet( + config["url_preview_ip_range_blacklist"] + ) + else: + raise ConfigError( + "For security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + + self.url_preview_ip_range_whitelist = IPSet( + config.get("url_preview_ip_range_whitelist", ()) + ) + + self.url_preview_url_blacklist = config.get( + "url_preview_url_blacklist", () + ) def default_config(self, **kwargs): media_store = self.default_path("media_store") @@ -80,7 +128,7 @@ class ContentRepositoryConfig(Config): # the resolution requested by the client. If true then whenever # a new resolution is requested by the client the server will # generate a new thumbnail. If false the server will pick a thumbnail - # from a precalcualted list. + # from a precalculated list. dynamic_thumbnails: false # List of thumbnail to precalculate when an image is uploaded. @@ -100,4 +148,73 @@ class ContentRepositoryConfig(Config): - width: 800 height: 600 method: scale + + # Is the preview URL API enabled? If enabled, you *must* specify + # an explicit url_preview_ip_range_blacklist of IPs that the spider is + # denied from accessing. + url_preview_enabled: False + + # List of IP address CIDR ranges that the URL preview spider is denied + # from accessing. There are no defaults: you must explicitly + # specify a list for URL previewing to work. You should specify any + # internal services in your network that you do not want synapse to try + # to connect to, otherwise anyone in any Matrix room could cause your + # synapse to issue arbitrary GET requests to your internal services, + # causing serious security issues. + # + # url_preview_ip_range_blacklist: + # - '127.0.0.0/8' + # - '10.0.0.0/8' + # - '172.16.0.0/12' + # - '192.168.0.0/16' + # - '100.64.0.0/10' + # - '169.254.0.0/16' + # + # List of IP address CIDR ranges that the URL preview spider is allowed + # to access even if they are specified in url_preview_ip_range_blacklist. + # This is useful for specifying exceptions to wide-ranging blacklisted + # target IP ranges - e.g. for enabling URL previews for a specific private + # website only visible in your network. + # + # url_preview_ip_range_whitelist: + # - '192.168.1.1' + + # Optional list of URL matches that the URL preview spider is + # denied from accessing. You should use url_preview_ip_range_blacklist + # in preference to this, otherwise someone could define a public DNS + # entry that points to a private IP address and circumvent the blacklist. + # This is more useful if you know there is an entire shape of URL that + # you know that will never want synapse to try to spider. + # + # Each list entry is a dictionary of url component attributes as returned + # by urlparse.urlsplit as applied to the absolute form of the URL. See + # https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit + # The values of the dictionary are treated as an filename match pattern + # applied to that component of URLs, unless they start with a ^ in which + # case they are treated as a regular expression match. If all the + # specified component matches for a given list item succeed, the URL is + # blacklisted. + # + # url_preview_url_blacklist: + # # blacklist any URL with a username in its URI + # - username: '*' + # + # # blacklist all *.google.com URLs + # - netloc: 'google.com' + # - netloc: '*.google.com' + # + # # blacklist all plain HTTP URLs + # - scheme: 'http' + # + # # blacklist http(s)://www.acme.com/foo + # - netloc: 'www.acme.com' + # path: '/foo' + # + # # blacklist any URL with a literal IPv4 address + # - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$' + + # The largest allowed URL preview spidering size in bytes + max_spider_size: "10M" + + """ % locals() diff --git a/synapse/config/server.py b/synapse/config/server.py index df4707e1d1..634d8e6fe5 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError class ServerConfig(Config): @@ -27,10 +27,23 @@ class ServerConfig(Config): self.daemonize = config.get("daemonize") self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") - self.use_frozen_dicts = config.get("use_frozen_dicts", True) + self.use_frozen_dicts = config.get("use_frozen_dicts", False) + self.public_baseurl = config.get("public_baseurl") + + # Whether to send federation traffic out in this process. This only + # applies to some federation traffic, and so shouldn't be used to + # "disable" federation + self.send_federation = config.get("send_federation", True) + + if self.public_baseurl is not None: + if self.public_baseurl[-1] != '/': + self.public_baseurl += '/' + self.start_pushers = config.get("start_pushers", True) self.listeners = config.get("listeners", []) + self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) + bind_port = config.get("bind_port") if bind_port: self.listeners = [] @@ -98,26 +111,6 @@ class ServerConfig(Config): ] }) - # Attempt to guess the content_addr for the v0 content repostitory - content_addr = config.get("content_addr") - if not content_addr: - for listener in self.listeners: - if listener["type"] == "http" and not listener.get("tls", False): - unsecure_port = listener["port"] - break - else: - raise RuntimeError("Could not determine 'content_addr'") - - host = self.server_name - if ':' not in host: - host = "%s:%d" % (host, unsecure_port) - else: - host = host.split(':')[0] - host = "%s:%d" % (host, unsecure_port) - content_addr = "http://%s" % (host,) - - self.content_addr = content_addr - def default_config(self, server_name, **kwargs): if ":" in server_name: bind_port = int(server_name.split(":")[1]) @@ -142,11 +135,17 @@ class ServerConfig(Config): # Whether to serve a web client from the HTTP/HTTPS root resource. web_client: True + # The public-facing base URL for the client API (not including _matrix/...) + # public_baseurl: https://example.com:8448/ + # Set the soft limit on the number of file descriptors synapse can use # Zero is used to indicate synapse should set the soft limit to the # hard limit. soft_file_limit: 0 + # The GC threshold parameters to pass to `gc.set_threshold`, if defined + # gc_thresholds: [700, 10, 10] + # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: @@ -228,3 +227,20 @@ class ServerConfig(Config): type=int, help="Turn on the twisted telnet manhole" " service on the given port.") + + +def read_gc_thresholds(thresholds): + """Reads the three integer thresholds for garbage collection. Ensures that + the thresholds are integers if thresholds are supplied. + """ + if thresholds is None: + return None + try: + assert len(thresholds) == 3 + return ( + int(thresholds[0]), int(thresholds[1]), int(thresholds[2]), + ) + except: + raise ConfigError( + "Value of `gc_threshold` must be a list of three integers if set" + ) diff --git a/synapse/config/tls.py b/synapse/config/tls.py index fac8550823..3c58d2de17 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -19,6 +19,9 @@ from OpenSSL import crypto import subprocess import os +from hashlib import sha256 +from unpaddedbase64 import encode_base64 + GENERATE_DH_PARAMS = False @@ -42,6 +45,19 @@ class TlsConfig(Config): config.get("tls_dh_params_path"), "tls_dh_params" ) + self.tls_fingerprints = config["tls_fingerprints"] + + # Check that our own certificate is included in the list of fingerprints + # and include it if it is not. + x509_certificate_bytes = crypto.dump_certificate( + crypto.FILETYPE_ASN1, + self.tls_certificate + ) + sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) + sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) + if sha256_fingerprint not in sha256_fingerprints: + self.tls_fingerprints.append({u"sha256": sha256_fingerprint}) + # This config option applies to non-federation HTTP clients # (e.g. for talking to recaptcha, identity servers, and such) # It should never be used in production, and is intended for @@ -73,6 +89,28 @@ class TlsConfig(Config): # Don't bind to the https port no_tls: False + + # List of allowed TLS fingerprints for this server to publish along + # with the signing keys for this server. Other matrix servers that + # make HTTPS requests to this server will check that the TLS + # certificates returned by this server match one of the fingerprints. + # + # Synapse automatically adds its the fingerprint of its own certificate + # to the list. So if federation traffic is handle directly by synapse + # then no modification to the list is required. + # + # If synapse is run behind a load balancer that handles the TLS then it + # will be necessary to add the fingerprints of the certificates used by + # the loadbalancers to this list if they are different to the one + # synapse is using. + # + # Homeservers are permitted to cache the list of TLS fingerprints + # returned in the key responses up to the "valid_until_ts" returned in + # key. It may be necessary to publish the fingerprints of a new + # certificate and wait until the "valid_until_ts" of the previous key + # responses have passed before deploying it. + tls_fingerprints: [] + # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] """ % locals() def read_tls_certificate(self, cert_path): diff --git a/synapse/config/workers.py b/synapse/config/workers.py new file mode 100644 index 0000000000..904789d155 --- /dev/null +++ b/synapse/config/workers.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class WorkerConfig(Config): + """The workers are processes run separately to the main synapse process. + They have their own pid_file and listener configuration. They use the + replication_url to talk to the main synapse process.""" + + def read_config(self, config): + self.worker_app = config.get("worker_app") + self.worker_listeners = config.get("worker_listeners") + self.worker_daemonize = config.get("worker_daemonize") + self.worker_pid_file = config.get("worker_pid_file") + self.worker_log_file = config.get("worker_log_file") + self.worker_log_config = config.get("worker_log_config") + self.worker_replication_url = config.get("worker_replication_url") diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 54b83da9d8..c2bd64d6c2 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient): def __init__(self): self.remote_key = defer.Deferred() self.host = None + self._peer = None def connectionMade(self): - self.host = self.transport.getHost() - logger.debug("Connected to %s", self.host) + self._peer = self.transport.getPeer() + logger.debug("Connected to %s", self._peer) + self.sendCommand(b"GET", self.path) if self.host: self.sendHeader(b"Host", self.host) @@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient): self.timer.cancel() def on_timeout(self): - logger.debug("Timeout waiting for response from %s", self.host) + logger.debug( + "Timeout waiting for response from %s: %s", + self.host, self._peer, + ) self.errback(IOError("Timeout waiting for response")) self.transport.abortConnection() @@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory): def protocol(self): protocol = SynapseKeyClientProtocol() protocol.path = self.path + protocol.host = self.host return protocol diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index d08ee0aa91..d7211ee9b3 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -22,6 +22,7 @@ from synapse.util.logcontext import ( preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, preserve_fn ) +from synapse.util.metrics import Measure from twisted.internet import defer @@ -44,7 +45,25 @@ import logging logger = logging.getLogger(__name__) -KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids")) +VerifyKeyRequest = namedtuple("VerifyRequest", ( + "server_name", "key_ids", "json_object", "deferred" +)) +""" +A request for a verify key to verify a JSON object. + +Attributes: + server_name(str): The name of the server to verify against. + key_ids(set(str)): The set of key_ids to that could be used to verify the + JSON object + json_object(dict): The JSON object to verify. + deferred(twisted.internet.defer.Deferred): + A deferred (server_name, key_id, verify_key) tuple that resolves when + a verify key has been fetched +""" + + +class KeyLookupError(ValueError): + pass class Keyring(object): @@ -74,39 +93,32 @@ class Keyring(object): list of deferreds indicating success or failure to verify each json object's signature for the given server_name. """ - group_id_to_json = {} - group_id_to_group = {} - group_ids = [] - - next_group_id = 0 - deferreds = {} + verify_requests = [] for server_name, json_object in server_and_json: logger.debug("Verifying for %s", server_name) - group_id = next_group_id - next_group_id += 1 - group_ids.append(group_id) key_ids = signature_ids(json_object, server_name) if not key_ids: - deferreds[group_id] = defer.fail(SynapseError( + deferred = defer.fail(SynapseError( 400, "Not signed with a supported algorithm", Codes.UNAUTHORIZED, )) else: - deferreds[group_id] = defer.Deferred() + deferred = defer.Deferred() - group = KeyGroup(server_name, group_id, key_ids) + verify_request = VerifyKeyRequest( + server_name, key_ids, json_object, deferred + ) - group_id_to_group[group_id] = group - group_id_to_json[group_id] = json_object + verify_requests.append(verify_request) @defer.inlineCallbacks - def handle_key_deferred(group, deferred): - server_name = group.server_name + def handle_key_deferred(verify_request): + server_name = verify_request.server_name try: - _, _, key_id, verify_key = yield deferred + _, key_id, verify_key = yield verify_request.deferred except IOError as e: logger.warn( "Got IOError when downloading keys for %s: %s %s", @@ -128,7 +140,7 @@ class Keyring(object): Codes.UNAUTHORIZED, ) - json_object = group_id_to_json[group.group_id] + json_object = verify_request.json_object try: verify_signed_json(json_object, server_name, verify_key) @@ -157,36 +169,34 @@ class Keyring(object): # Actually start fetching keys. wait_on_deferred.addBoth( - lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) + lambda _: self.get_server_verify_keys(verify_requests) ) # When we've finished fetching all the keys for a given server_name, # resolve the deferred passed to `wait_for_previous_lookups` so that # any lookups waiting will proceed. - server_to_gids = {} + server_to_request_ids = {} - def remove_deferreds(res, server_name, group_id): - server_to_gids[server_name].discard(group_id) - if not server_to_gids[server_name]: + def remove_deferreds(res, server_name, verify_request): + request_id = id(verify_request) + server_to_request_ids[server_name].discard(request_id) + if not server_to_request_ids[server_name]: d = server_to_deferred.pop(server_name, None) if d: d.callback(None) return res - for g_id, deferred in deferreds.items(): - server_name = group_id_to_group[g_id].server_name - server_to_gids.setdefault(server_name, set()).add(g_id) - deferred.addBoth(remove_deferreds, server_name, g_id) + for verify_request in verify_requests: + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids.setdefault(server_name, set()).add(request_id) + deferred.addBoth(remove_deferreds, server_name, verify_request) # Pass those keys to handle_key_deferred so that the json object # signatures can be verified return [ - preserve_context_over_fn( - handle_key_deferred, - group_id_to_group[g_id], - deferreds[g_id], - ) - for g_id in group_ids + preserve_context_over_fn(handle_key_deferred, verify_request) + for verify_request in verify_requests ] @defer.inlineCallbacks @@ -220,7 +230,7 @@ class Keyring(object): d.addBoth(rm, server_name) - def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): + def get_server_verify_keys(self, verify_requests): """Takes a dict of KeyGroups and tries to find at least one key for each group. """ @@ -234,76 +244,79 @@ class Keyring(object): @defer.inlineCallbacks def do_iterations(): - merged_results = {} - - missing_keys = {} - for group in group_id_to_group.values(): - missing_keys.setdefault(group.server_name, set()).update( - group.key_ids - ) + with Measure(self.clock, "get_server_verify_keys"): + merged_results = {} - for fn in key_fetch_fns: - results = yield fn(missing_keys.items()) - merged_results.update(results) - - # We now need to figure out which groups we have keys for - # and which we don't - missing_groups = {} - for group in group_id_to_group.values(): - for key_id in group.key_ids: - if key_id in merged_results[group.server_name]: - with PreserveLoggingContext(): - group_id_to_deferred[group.group_id].callback(( - group.group_id, - group.server_name, - key_id, - merged_results[group.server_name][key_id], - )) - break - else: - missing_groups.setdefault( - group.server_name, [] - ).append(group) - - if not missing_groups: - break - - missing_keys = { - server_name: set( - key_id for group in groups for key_id in group.key_ids + missing_keys = {} + for verify_request in verify_requests: + missing_keys.setdefault(verify_request.server_name, set()).update( + verify_request.key_ids ) - for server_name, groups in missing_groups.items() - } - for group in missing_groups.values(): - group_id_to_deferred[group.group_id].errback(SynapseError( - 401, - "No key for %s with id %s" % ( - group.server_name, group.key_ids, - ), - Codes.UNAUTHORIZED, - )) + for fn in key_fetch_fns: + results = yield fn(missing_keys.items()) + merged_results.update(results) + + # We now need to figure out which verify requests we have keys + # for and which we don't + missing_keys = {} + requests_missing_keys = [] + for verify_request in verify_requests: + server_name = verify_request.server_name + result_keys = merged_results[server_name] + + if verify_request.deferred.called: + # We've already called this deferred, which probably + # means that we've already found a key for it. + continue + + for key_id in verify_request.key_ids: + if key_id in result_keys: + with PreserveLoggingContext(): + verify_request.deferred.callback(( + server_name, + key_id, + result_keys[key_id], + )) + break + else: + # The else block is only reached if the loop above + # doesn't break. + missing_keys.setdefault(server_name, set()).update( + verify_request.key_ids + ) + requests_missing_keys.append(verify_request) + + if not missing_keys: + break + + for verify_request in requests_missing_keys.values(): + verify_request.deferred.errback(SynapseError( + 401, + "No key for %s with id %s" % ( + verify_request.server_name, verify_request.key_ids, + ), + Codes.UNAUTHORIZED, + )) def on_err(err): - for deferred in group_id_to_deferred.values(): - if not deferred.called: - deferred.errback(err) + for verify_request in verify_requests: + if not verify_request.deferred.called: + verify_request.deferred.errback(err) do_iterations().addErrback(on_err) - return group_id_to_deferred - @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): - res = yield defer.gatherResults( + res = yield preserve_context_over_deferred(defer.gatherResults( [ - self.store.get_server_verify_keys( + preserve_fn(self.store.get_server_verify_keys)( server_name, key_ids ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(dict(res)) @@ -324,13 +337,13 @@ class Keyring(object): ) defer.returnValue({}) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - get_key(p_name, p_keys) + preserve_fn(get_key)(p_name, p_keys) for p_name, p_keys in self.perspective_servers.items() ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) union_of_keys = {} for result in results: @@ -356,7 +369,7 @@ class Keyring(object): ) except Exception as e: logger.info( - "Unable to getting key %r for %r directly: %s %s", + "Unable to get key %r for %r directly: %s %s", key_ids, server_name, type(e).__name__, str(e.message), ) @@ -370,13 +383,13 @@ class Keyring(object): defer.returnValue(keys) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - get_key(server_name, key_ids) + preserve_fn(get_key)(server_name, key_ids) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) merged = {} for result in results: @@ -418,7 +431,7 @@ class Keyring(object): for response in responses: if (u"signatures" not in response or perspective_name not in response[u"signatures"]): - raise ValueError( + raise KeyLookupError( "Key response not signed by perspective server" " %r" % (perspective_name,) ) @@ -441,21 +454,21 @@ class Keyring(object): list(response[u"signatures"][perspective_name]), list(perspective_keys) ) - raise ValueError( + raise KeyLookupError( "Response not signed with a known key for perspective" " server %r" % (perspective_name,) ) processed_response = yield self.process_v2_response( - perspective_name, response + perspective_name, response, only_from_server=False ) for server_name, response_keys in processed_response.items(): keys.setdefault(server_name, {}).update(response_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ - self.store_keys( + preserve_fn(self.store_keys)( server_name=server_name, from_server=perspective_name, verify_keys=response_keys, @@ -463,7 +476,7 @@ class Keyring(object): for server_name, response_keys in keys.items() ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(keys) @@ -484,10 +497,10 @@ class Keyring(object): if (u"signatures" not in response or server_name not in response[u"signatures"]): - raise ValueError("Key response not signed by remote server") + raise KeyLookupError("Key response not signed by remote server") if "tls_fingerprints" not in response: - raise ValueError("Key response missing TLS fingerprints") + raise KeyLookupError("Key response missing TLS fingerprints") certificate_bytes = crypto.dump_certificate( crypto.FILETYPE_ASN1, tls_certificate @@ -501,7 +514,7 @@ class Keyring(object): response_sha256_fingerprints.add(fingerprint[u"sha256"]) if sha256_fingerprint_b64 not in response_sha256_fingerprints: - raise ValueError("TLS certificate not allowed by fingerprints") + raise KeyLookupError("TLS certificate not allowed by fingerprints") response_keys = yield self.process_v2_response( from_server=server_name, @@ -511,7 +524,7 @@ class Keyring(object): keys.update(response_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store_keys)( server_name=key_server_name, @@ -521,13 +534,13 @@ class Keyring(object): for key_server_name, verify_keys in keys.items() ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(keys) @defer.inlineCallbacks def process_v2_response(self, from_server, response_json, - requested_ids=[]): + requested_ids=[], only_from_server=True): time_now_ms = self.clock.time_msec() response_keys = {} verify_keys = {} @@ -551,9 +564,16 @@ class Keyring(object): results = {} server_name = response_json["server_name"] + if only_from_server: + if server_name != from_server: + raise KeyLookupError( + "Expected a response for server %r not %r" % ( + from_server, server_name + ) + ) for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: - raise ValueError( + raise KeyLookupError( "Key response must include verification keys for all" " signatures" ) @@ -580,7 +600,7 @@ class Keyring(object): response_keys.update(verify_keys) response_keys.update(old_verify_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store.store_server_keys_json)( server_name=server_name, @@ -593,7 +613,7 @@ class Keyring(object): for key_id in updated_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) results[server_name] = response_keys @@ -621,15 +641,15 @@ class Keyring(object): if ("signatures" not in response or server_name not in response["signatures"]): - raise ValueError("Key response not signed by remote server") + raise KeyLookupError("Key response not signed by remote server") if "tls_certificate" not in response: - raise ValueError("Key response missing TLS certificate") + raise KeyLookupError("Key response missing TLS certificate") tls_certificate_b64 = response["tls_certificate"] if encode_base64(x509_certificate_bytes) != tls_certificate_b64: - raise ValueError("TLS certificate doesn't match") + raise KeyLookupError("TLS certificate doesn't match") # Cache the result in the datastore. @@ -645,7 +665,7 @@ class Keyring(object): for key_id in response["signatures"][server_name]: if key_id not in response["verify_keys"]: - raise ValueError( + raise KeyLookupError( "Key response must include verification keys for all" " signatures" ) @@ -682,7 +702,7 @@ class Keyring(object): A deferred that completes when the keys are stored. """ # TODO(markjh): Store whether the keys have expired. - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store.store_server_verify_key)( server_name, server_name, key.time_added, key @@ -690,4 +710,4 @@ class Keyring(object): for key_id, key in verify_keys.items() ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index bbfa5a7265..bcb8f33a58 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.util.frozenutils import freeze +from synapse.util.caches import intern_dict # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents @@ -30,7 +31,10 @@ class _EventInternalMetadata(object): return dict(self.__dict__) def is_outlier(self): - return hasattr(self, "outlier") and self.outlier + return getattr(self, "outlier", False) + + def is_invite_from_remote(self): + return getattr(self, "invite_from_remote", False) def _event_dict_property(key): @@ -95,7 +99,7 @@ class EventBase(object): return d - def get(self, key, default): + def get(self, key, default=None): return self._event_dict.get(key, default) def get_internal_metadata_dict(self): @@ -140,6 +144,10 @@ class FrozenEvent(EventBase): unsigned = dict(event_dict.pop("unsigned", {})) + # We intern these strings because they turn up a lot (especially when + # caching). + event_dict = intern_dict(event_dict) + if USE_FROZEN_DICTS: frozen_dict = freeze(event_dict) else: @@ -168,5 +176,7 @@ class FrozenEvent(EventBase): def __repr__(self): return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % ( - self.event_id, self.type, self.get("state_key", None), + self.get("event_id", None), + self.get("type", None), + self.get("state_key", None), ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 8a475417a6..11605b34a3 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,9 +15,30 @@ class EventContext(object): + __slots__ = [ + "current_state_ids", + "prev_state_ids", + "state_group", + "rejected", + "push_actions", + "prev_group", + "delta_ids", + "prev_state_events", + ] - def __init__(self, current_state=None): - self.current_state = current_state + def __init__(self): + # The current state including the current event + self.current_state_ids = None + # The current state excluding the current event + self.prev_state_ids = None self.state_group = None + self.rejected = False self.push_actions = [] + + # A previously persisted state group and a delta between that + # and this state. + self.prev_group = None + self.delta_ids = None + + self.prev_state_events = None diff --git a/synapse/events/utils.py b/synapse/events/utils.py index aab18d7f71..5bbaef8187 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -16,6 +16,17 @@ from synapse.api.constants import EventTypes from . import EventBase +from frozendict import frozendict + +import re + +# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' +# (?<!stuff) matches if the current position in the string is not preceded +# by a match for 'stuff'. +# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as +# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" +SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.') + def prune_event(event): """ Returns a pruned version of the given event, which removes all keys we @@ -88,6 +99,8 @@ def prune_event(event): if "age_ts" in event.unsigned: allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] + if "replaces_state" in event.unsigned: + allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"] return type(event)( allowed_fields, @@ -95,6 +108,83 @@ def prune_event(event): ) +def _copy_field(src, dst, field): + """Copy the field in 'src' to 'dst'. + + For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] + then dst={"foo":{"bar":5}}. + + Args: + src(dict): The dict to read from. + dst(dict): The dict to modify. + field(list<str>): List of keys to drill down to in 'src'. + """ + if len(field) == 0: # this should be impossible + return + if len(field) == 1: # common case e.g. 'origin_server_ts' + if field[0] in src: + dst[field[0]] = src[field[0]] + return + + # Else is a nested field e.g. 'content.body' + # Pop the last field as that's the key to move across and we need the + # parent dict in order to access the data. Drill down to the right dict. + key_to_move = field.pop(-1) + sub_dict = src + for sub_field in field: # e.g. sub_field => "content" + if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]: + sub_dict = sub_dict[sub_field] + else: + return + + if key_to_move not in sub_dict: + return + + # Insert the key into the output dictionary, creating nested objects + # as required. We couldn't do this any earlier or else we'd need to delete + # the empty objects if the key didn't exist. + sub_out_dict = dst + for sub_field in field: + sub_out_dict = sub_out_dict.setdefault(sub_field, {}) + sub_out_dict[key_to_move] = sub_dict[key_to_move] + + +def only_fields(dictionary, fields): + """Return a new dict with only the fields in 'dictionary' which are present + in 'fields'. + + If there are no event fields specified then all fields are included. + The entries may include '.' charaters to indicate sub-fields. + So ['content.body'] will include the 'body' field of the 'content' object. + A literal '.' character in a field name may be escaped using a '\'. + + Args: + dictionary(dict): The dictionary to read from. + fields(list<str>): A list of fields to copy over. Only shallow refs are + taken. + Returns: + dict: A new dictionary with only the given fields. If fields was empty, + the same dictionary is returned. + """ + if len(fields) == 0: + return dictionary + + # for each field, convert it: + # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]] + split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] + + # for each element of the output array of arrays: + # remove escaping so we can use the right key names. + split_fields[:] = [ + [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields + ] + + output = {} + for field_array in split_fields: + _copy_field(dictionary, output, field_array) + return output + + def format_event_raw(d): return d @@ -135,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d): def serialize_event(e, time_now_ms, as_client_event=True, event_format=format_event_for_client_v1, - token_id=None): + token_id=None, only_event_fields=None): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e @@ -162,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True, d["unsigned"]["transaction_id"] = txn_id if as_client_event: - return event_format(d) - else: - return d + d = event_format(d) + + if only_event_fields: + if (not isinstance(only_event_fields, list) or + not all(isinstance(f, basestring) for f in only_event_fields)): + raise TypeError("only_event_fields must be a list of strings") + d = only_fields(d, only_event_fields) + + return d diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 979fdf2431..2e32d245ba 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -17,10 +17,9 @@ """ from .replication import ReplicationLayer -from .transport.client import TransportLayerClient -def initialize_http_replication(homeserver): - transport = TransportLayerClient(homeserver) +def initialize_http_replication(hs): + transport = hs.get_federation_transport_client() - return ReplicationLayer(homeserver, transport) + return ReplicationLayer(hs, transport) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a0b7cb7963..2339cc9034 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash from synapse.api.errors import SynapseError from synapse.util import unwrapFirstError +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -31,6 +32,9 @@ logger = logging.getLogger(__name__) class FederationBase(object): + def __init__(self, hs): + pass + @defer.inlineCallbacks def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, include_none=False): @@ -99,10 +103,10 @@ class FederationBase(object): warn, pdu ) - valid_pdus = yield defer.gatherResults( + valid_pdus = yield preserve_context_over_deferred(defer.gatherResults( deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) if include_none: defer.returnValue(valid_pdus) @@ -126,7 +130,7 @@ class FederationBase(object): for pdu in pdus ] - deferreds = self.keyring.verify_json_objects_for_server([ + deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ (p.origin, p.get_pdu_json()) for p in redacted_pdus ]) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 83c1f46586..b255709165 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,7 +18,6 @@ from twisted.internet import defer from .federation_base import FederationBase from synapse.api.constants import Membership -from .units import Edu from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, @@ -26,7 +25,9 @@ from synapse.api.errors import ( from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.events import FrozenEvent +from synapse.types import get_domain_from_id import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -43,14 +44,37 @@ logger = logging.getLogger(__name__) # synapse.federation.federation_client is a silly name metrics = synapse.metrics.get_metrics_for("synapse.federation.client") -sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations") +sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) -sent_edus_counter = metrics.register_counter("sent_edus") -sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) +PDU_RETRY_TIME_MS = 1 * 60 * 1000 class FederationClient(FederationBase): + def __init__(self, hs): + super(FederationClient, self).__init__(hs) + + self.pdu_destination_tried = {} + self._clock.looping_call( + self._clear_tried_cache, 60 * 1000, + ) + self.state = hs.get_state_handler() + + def _clear_tried_cache(self): + """Clear pdu_destination_tried cache""" + now = self._clock.time_msec() + + old_dict = self.pdu_destination_tried + self.pdu_destination_tried = {} + + for event_id, destination_dict in old_dict.items(): + destination_dict = { + dest: time + for dest, time in destination_dict.items() + if time + PDU_RETRY_TIME_MS > now + } + if destination_dict: + self.pdu_destination_tried[event_id] = destination_dict def start_get_pdu_cache(self): self._get_pdu_cache = ExpiringCache( @@ -64,55 +88,6 @@ class FederationClient(FederationBase): self._get_pdu_cache.start() @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - sent_pdus_destination_dist.inc_by(len(destinations)) - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - @log_function - def send_edu(self, destination, edu_type, content): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - sent_edus_counter.inc() - - # TODO, add errback, etc. - self._transaction_queue.enqueue_edu(edu) - return defer.succeed(None) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False): """Sends a federation Query to a remote homeserver of the given type @@ -136,7 +111,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content): + def query_client_keys(self, destination, content, timeout): """Query device keys for a device hosted on a remote server. Args: @@ -148,10 +123,12 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys(destination, content) + return self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def claim_client_keys(self, destination, content): + def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. Args: @@ -163,7 +140,9 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys(destination, content) + return self.transport_layer.claim_client_keys( + destination, content, timeout + ) @defer.inlineCallbacks @log_function @@ -198,10 +177,10 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield defer.gatherResults( + pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( self._check_sigs_and_hashes(pdus), consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(pdus) @@ -233,12 +212,19 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. if self._get_pdu_cache: - e = self._get_pdu_cache.get(event_id) - if e: - defer.returnValue(e) + ev = self._get_pdu_cache.get(event_id) + if ev: + defer.returnValue(ev) - pdu = None + pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + + signed_pdu = None for destination in destinations: + now = self._clock.time_msec() + last_attempt = pdu_attempts.get(destination, 0) + if last_attempt + PDU_RETRY_TIME_MS > now: + continue + try: limiter = yield get_retry_limiter( destination, @@ -262,39 +248,33 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - pdu = yield self._check_sigs_and_hashes([pdu])[0] + signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] break - except SynapseError: - logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, - ) - continue - except CodeMessageException as e: - if 400 <= e.code < 500: - raise + pdu_attempts[destination] = now + except SynapseError as e: logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, ) - continue except NotRetryingDestination as e: logger.info(e.message) continue except Exception as e: + pdu_attempts[destination] = now + logger.info( "Failed to get PDU %s from %s because %s", event_id, destination, e, ) continue - if self._get_pdu_cache is not None and pdu: - self._get_pdu_cache[event_id] = pdu + if self._get_pdu_cache is not None and signed_pdu: + self._get_pdu_cache[event_id] = signed_pdu - defer.returnValue(pdu) + defer.returnValue(signed_pdu) @defer.inlineCallbacks @log_function @@ -311,6 +291,42 @@ class FederationClient(FederationBase): Deferred: Results in a list of PDUs. """ + try: + # First we try and ask for just the IDs, as thats far quicker if + # we have most of the state and auth_chain already. + # However, this may 404 if the other side has an old synapse. + result = yield self.transport_layer.get_room_state_ids( + destination, room_id, event_id=event_id, + ) + + state_event_ids = result["pdu_ids"] + auth_event_ids = result.get("auth_chain_ids", []) + + fetched_events, failed_to_fetch = yield self.get_events( + [destination], room_id, set(state_event_ids + auth_event_ids) + ) + + if failed_to_fetch: + logger.warn("Failed to get %r", failed_to_fetch) + + event_map = { + ev.event_id: ev for ev in fetched_events + } + + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] + auth_chain = [ + event_map[e_id] for e_id in auth_event_ids if e_id in event_map + ] + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue((pdus, auth_chain)) + except HttpResponseException as e: + if e.code == 400 or e.code == 404: + logger.info("Failed to use get_room_state_ids API, falling back") + else: + raise e + result = yield self.transport_layer.get_room_state( destination, room_id, event_id=event_id, ) @@ -324,12 +340,26 @@ class FederationClient(FederationBase): for p in result.get("auth_chain", []) ] + seen_events = yield self.store.get_events([ + ev.event_id for ev in itertools.chain(pdus, auth_chain) + ]) + signed_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, pdus, outlier=True + destination, + [p for p in pdus if p.event_id not in seen_events], + outlier=True + ) + signed_pdus.extend( + seen_events[p.event_id] for p in pdus if p.event_id in seen_events ) signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True + destination, + [p for p in auth_chain if p.event_id not in seen_events], + outlier=True + ) + signed_auth.extend( + seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events ) signed_auth.sort(key=lambda e: e.depth) @@ -337,6 +367,69 @@ class FederationClient(FederationBase): defer.returnValue((signed_pdus, signed_auth)) @defer.inlineCallbacks + def get_events(self, destinations, room_id, event_ids, return_local=True): + """Fetch events from some remote destinations, checking if we already + have them. + + Args: + destinations (list) + room_id (str) + event_ids (list) + return_local (bool): Whether to include events we already have in + the DB in the returned list of events + + Returns: + Deferred: A deferred resolving to a 2-tuple where the first is a list of + events and the second is a list of event ids that we failed to fetch. + """ + if return_local: + seen_events = yield self.store.get_events(event_ids, allow_rejected=True) + signed_events = seen_events.values() + else: + seen_events = yield self.store.have_events(event_ids) + signed_events = [] + + failed_to_fetch = set() + + missing_events = set(event_ids) + for k in seen_events: + missing_events.discard(k) + + if not missing_events: + defer.returnValue((signed_events, failed_to_fetch)) + + def random_server_list(): + srvs = list(destinations) + random.shuffle(srvs) + return srvs + + batch_size = 20 + missing_events = list(missing_events) + for i in xrange(0, len(missing_events), batch_size): + batch = set(missing_events[i:i + batch_size]) + + deferreds = [ + preserve_fn(self.get_pdu)( + destinations=random_server_list(), + event_id=e_id, + ) + for e_id in batch + ] + + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) + for success, result in res: + if success and result: + signed_events.append(result) + batch.discard(result.event_id) + + # We removed all events we successfully fetched from `batch` + failed_to_fetch.update(batch) + + defer.returnValue((signed_events, failed_to_fetch)) + + @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): res = yield self.transport_layer.get_event_auth( @@ -411,8 +504,14 @@ class FederationClient(FederationBase): (destination, self.event_from_pdu_json(pdu_dict)) ) break - except CodeMessageException: - raise + except CodeMessageException as e: + if not 500 <= e.code < 600: + raise + else: + logger.warn( + "Failed to make_%s via %s: %s", + membership, destination, e.message + ) except Exception as e: logger.warn( "Failed to make_%s via %s: %s", @@ -489,8 +588,14 @@ class FederationClient(FederationBase): "auth_chain": signed_auth, "origin": destination, }) - except CodeMessageException: - raise + except CodeMessageException as e: + if not 500 <= e.code < 600: + raise + else: + logger.exception( + "Failed to send_join via %s: %s", + destination, e.message + ) except Exception as e: logger.exception( "Failed to send_join via %s: %s", @@ -549,6 +654,15 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") + def get_public_rooms(self, destination, limit=None, since_token=None, + search_filter=None): + if destination == self.server_name: + return + + return self.transport_layer.get_public_rooms( + destination, limit, since_token, search_filter + ) + @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): """ @@ -638,7 +752,8 @@ class FederationClient(FederationBase): if len(signed_events) >= limit: defer.returnValue(signed_events) - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) servers = set(servers) servers.discard(self.server_name) @@ -683,14 +798,16 @@ class FederationClient(FederationBase): return srvs deferreds = [ - self.get_pdu( + preserve_fn(self.get_pdu)( destinations=random_server_list(), event_id=e_id, ) for e_id, depth in ordered_missing[:limit - len(signed_events)] ] - res = yield defer.DeferredList(deferreds, consumeErrors=True) + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) for (result, val), (e_id, _) in zip(res, ordered_missing): if result and val: signed_events.append(val) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e8bfbe7cb5..3fa7b2315c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -19,11 +19,13 @@ from twisted.internet import defer from .federation_base import FederationBase from .units import Transaction, Edu +from synapse.util.async import Linearizer from synapse.util.logutils import log_function +from synapse.util.caches.response_cache import ResponseCache from synapse.events import FrozenEvent import synapse.metrics -from synapse.api.errors import FederationError, SynapseError +from synapse.api.errors import AuthError, FederationError, SynapseError from synapse.crypto.event_signing import compute_event_signature @@ -44,6 +46,18 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[ class FederationServer(FederationBase): + def __init__(self, hs): + super(FederationServer, self).__init__(hs) + + self.auth = hs.get_auth() + + self._room_pdu_linearizer = Linearizer() + self._server_linearizer = Linearizer() + + # We cache responses to state queries, as they take a while and often + # come in waves. + self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) + def set_handler(self, handler): """Sets the handler that the replication layer will use to communicate receipt of new PDUs from other home servers. The required methods are @@ -83,11 +97,14 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, versions, limit): - pdus = yield self.handler.on_backfill_request( - origin, room_id, versions, limit - ) + with (yield self._server_linearizer.queue((origin, room_id))): + pdus = yield self.handler.on_backfill_request( + origin, room_id, versions, limit + ) + + res = self._transaction_from_pdus(pdus).get_dict() - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + defer.returnValue((200, res)) @defer.inlineCallbacks @log_function @@ -137,8 +154,8 @@ class FederationServer(FederationBase): logger.exception("Failed to handle PDU") if hasattr(transaction, "edus"): - for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu( + for edu in (Edu(**x) for x in transaction.edus): + yield self.received_edu( transaction.origin, edu.edu_type, edu.content @@ -161,26 +178,74 @@ class FederationServer(FederationBase): ) defer.returnValue((200, response)) + @defer.inlineCallbacks def received_edu(self, origin, edu_type, content): received_edus_counter.inc() if edu_type in self.edu_handlers: - self.edu_handlers[edu_type](origin, content) + try: + yield self.edu_handlers[edu_type](origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception as e: + logger.exception("Failed to handle edu %r", edu_type) else: logger.warn("Received EDU of type %s with no handler", edu_type) @defer.inlineCallbacks @log_function def on_context_state_request(self, origin, room_id, event_id): - if event_id: - pdus = yield self.handler.get_state_for_pdu( - origin, room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) + if not event_id: + raise NotImplementedError("Specify an event") + + in_room = yield self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + result = self._state_resp_cache.get((room_id, event_id)) + if not result: + with (yield self._server_linearizer.queue((origin, room_id))): + resp = yield self._state_resp_cache.set( + (room_id, event_id), + self._on_context_state_request_compute(room_id, event_id) + ) + else: + resp = yield result + + defer.returnValue((200, resp)) - for event in auth_chain: + @defer.inlineCallbacks + def on_state_ids_request(self, origin, room_id, event_id): + if not event_id: + raise NotImplementedError("Specify an event") + + in_room = yield self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + state_ids = yield self.handler.get_state_ids_for_pdu( + room_id, event_id, + ) + auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + + defer.returnValue((200, { + "pdu_ids": state_ids, + "auth_chain_ids": auth_chain_ids, + })) + + @defer.inlineCallbacks + def _on_context_state_request_compute(self, room_id, event_id): + pdus = yield self.handler.get_state_for_pdu( + room_id, event_id, + ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) + + for event in auth_chain: + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): event.signatures.update( compute_event_signature( event, @@ -188,13 +253,11 @@ class FederationServer(FederationBase): self.hs.config.signing_key[0] ) ) - else: - raise NotImplementedError("Specify an event") - defer.returnValue((200, { + defer.returnValue({ "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - })) + }) @defer.inlineCallbacks @log_function @@ -268,14 +331,16 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): - time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) - defer.returnValue((200, { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - })) + with (yield self._server_linearizer.queue((origin, room_id))): + time_now = self._clock.time_msec() + auth_pdus = yield self.handler.on_event_auth(event_id) + res = { + "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], + } + defer.returnValue((200, res)) @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, event_id): + def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -294,58 +359,41 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - auth_chain = [ - self.event_from_pdu_json(e) - for e in content["auth_chain"] - ] + with (yield self._server_linearizer.queue((origin, room_id))): + auth_chain = [ + self.event_from_pdu_json(e) + for e in content["auth_chain"] + ] + + signed_auth = yield self._check_sigs_and_hash_and_fetch( + origin, auth_chain, outlier=True + ) - signed_auth = yield self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True - ) + ret = yield self.handler.on_query_auth( + origin, + event_id, + signed_auth, + content.get("rejects", []), + content.get("missing", []), + ) - ret = yield self.handler.on_query_auth( - origin, - event_id, - signed_auth, - content.get("rejects", []), - content.get("missing", []), - ) - - time_now = self._clock.time_msec() - send_content = { - "auth_chain": [ - e.get_pdu_json(time_now) - for e in ret["auth_chain"] - ], - "rejects": ret.get("rejects", []), - "missing": ret.get("missing", []), - } + time_now = self._clock.time_msec() + send_content = { + "auth_chain": [ + e.get_pdu_json(time_now) + for e in ret["auth_chain"] + ], + "rejects": ret.get("rejects", []), + "missing": ret.get("missing", []), + } defer.returnValue( (200, send_content) ) - @defer.inlineCallbacks @log_function def on_query_client_keys(self, origin, content): - query = [] - for user_id, device_ids in content.get("device_keys", {}).items(): - if not device_ids: - query.append((user_id, None)) - else: - for device_id in device_ids: - query.append((user_id, device_id)) - - results = yield self.store.get_e2e_device_keys(query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[device_id] = json.loads( - json_bytes - ) - - defer.returnValue({"device_keys": json_result}) + return self.on_query_request("client_keys", content) @defer.inlineCallbacks @log_function @@ -371,17 +419,35 @@ class FederationServer(FederationBase): @log_function def on_get_missing_events(self, origin, room_id, earliest_events, latest_events, limit, min_depth): - missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, min_depth - ) + with (yield self._server_linearizer.queue((origin, room_id))): + logger.info( + "on_get_missing_events: earliest_events: %r, latest_events: %r," + " limit: %d, min_depth: %d", + earliest_events, latest_events, limit, min_depth + ) + missing_events = yield self.handler.on_get_missing_events( + origin, room_id, earliest_events, latest_events, limit, min_depth + ) - time_now = self._clock.time_msec() + if len(missing_events) < 5: + logger.info( + "Returning %d events: %r", len(missing_events), missing_events + ) + else: + logger.info("Returning %d events", len(missing_events)) + + time_now = self._clock.time_msec() defer.returnValue({ "events": [ev.get_pdu_json(time_now) for ev in missing_events], }) @log_function + def on_openid_userinfo(self, token): + ts_now_ms = self._clock.time_msec() + return self.store.get_user_id_for_open_id_token(token, ts_now_ms) + + @log_function def _get_persisted_pdu(self, origin, event_id, do_auth=True): """ Get a PDU from the database with given origin and id. @@ -470,42 +536,59 @@ class FederationServer(FederationBase): pdu.internal_metadata.outlier = True elif min_depth and pdu.depth > min_depth: if get_missing and prevs - seen: - latest = yield self.store.get_latest_event_ids_in_room( - pdu.room_id - ) - - # We add the prev events that we have seen to the latest - # list to ensure the remote server doesn't give them to us - latest = set(latest) - latest |= seen - - missing_events = yield self.get_missing_events( - origin, - pdu.room_id, - earliest_events_ids=list(latest), - latest_events=[pdu], - limit=10, - min_depth=min_depth, - ) - - # We want to sort these by depth so we process them and - # tell clients about them in order. - missing_events.sort(key=lambda x: x.depth) - - for e in missing_events: - yield self._handle_new_pdu( - origin, - e, - get_missing=False - ) - - have_seen = yield self.store.have_events( - [ev for ev, _ in pdu.prev_events] - ) + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + # We recalculate seen, since it may have changed. + have_seen = yield self.store.have_events(prevs) + seen = set(have_seen.keys()) + + if prevs - seen: + latest = yield self.store.get_latest_event_ids_in_room( + pdu.room_id + ) + + # We add the prev events that we have seen to the latest + # list to ensure the remote server doesn't give them to us + latest = set(latest) + latest |= seen + + logger.info( + "Missing %d events for room %r: %r...", + len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] + ) + + missing_events = yield self.get_missing_events( + origin, + pdu.room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + ) + + # We want to sort these by depth so we process them and + # tell clients about them in order. + missing_events.sort(key=lambda x: x.depth) + + for e in missing_events: + yield self._handle_new_pdu( + origin, + e, + get_missing=False + ) + + have_seen = yield self.store.have_events( + [ev for ev, _ in pdu.prev_events] + ) prevs = {e_id for e_id, _ in pdu.prev_events} seen = set(have_seen.keys()) if prevs - seen: + logger.info( + "Still missing %d events for room %r: %r...", + len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] + ) fetch_state = True if fetch_state: @@ -520,12 +603,11 @@ class FederationServer(FederationBase): origin, pdu.room_id, pdu.event_id, ) except: - logger.warn("Failed to get state for event: %s", pdu.event_id) + logger.exception("Failed to get state for event: %s", pdu.event_id) yield self.handler.on_receive_pdu( origin, pdu, - backfilled=False, state=state, auth_chain=auth_chain, ) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 3e062a5eab..62d865ec4b 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -20,8 +20,6 @@ a given transport. from .federation_client import FederationClient from .federation_server import FederationServer -from .transaction_queue import TransactionQueue - from .persistence import TransactionActions import logging @@ -66,11 +64,10 @@ class ReplicationLayer(FederationClient, FederationServer): self._clock = hs.get_clock() self.transaction_actions = TransactionActions(self.store) - self._transaction_queue = TransactionQueue(hs, transport_layer) - - self._order = 0 self.hs = hs + super(ReplicationLayer, self).__init__(hs) + def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py new file mode 100644 index 0000000000..5c9f7a86f0 --- /dev/null +++ b/synapse/federation/send_queue.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A federation sender that forwards things to be sent across replication to +a worker process. + +It assumes there is a single worker process feeding off of it. + +Each row in the replication stream consists of a type and some json, where the +types indicate whether they are presence, or edus, etc. + +Ephemeral or non-event data are queued up in-memory. When the worker requests +updates since a particular point, all in-memory data since before that point is +dropped. We also expire things in the queue after 5 minutes, to ensure that a +dead worker doesn't cause the queues to grow limitlessly. + +Events are replicated via a separate events stream. +""" + +from .units import Edu + +from synapse.util.metrics import Measure +import synapse.metrics + +from blist import sorteddict +import ujson + + +metrics = synapse.metrics.get_metrics_for(__name__) + + +PRESENCE_TYPE = "p" +KEYED_EDU_TYPE = "k" +EDU_TYPE = "e" +FAILURE_TYPE = "f" +DEVICE_MESSAGE_TYPE = "d" + + +class FederationRemoteSendQueue(object): + """A drop in replacement for TransactionQueue""" + + def __init__(self, hs): + self.server_name = hs.hostname + self.clock = hs.get_clock() + + self.presence_map = {} + self.presence_changed = sorteddict() + + self.keyed_edu = {} + self.keyed_edu_changed = sorteddict() + + self.edus = sorteddict() + + self.failures = sorteddict() + + self.device_messages = sorteddict() + + self.pos = 1 + self.pos_time = sorteddict() + + # EVERYTHING IS SAD. In particular, python only makes new scopes when + # we make a new function, so we need to make a new function so the inner + # lambda binds to the queue rather than to the name of the queue which + # changes. ARGH. + def register(name, queue): + metrics.register_callback( + queue_name + "_size", + lambda: len(queue), + ) + + for queue_name in [ + "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", + "edus", "failures", "device_messages", "pos_time", + ]: + register(queue_name, getattr(self, queue_name)) + + self.clock.looping_call(self._clear_queue, 30 * 1000) + + def _next_pos(self): + pos = self.pos + self.pos += 1 + self.pos_time[self.clock.time_msec()] = pos + return pos + + def _clear_queue(self): + """Clear the queues for anything older than N minutes""" + + FIVE_MINUTES_AGO = 5 * 60 * 1000 + now = self.clock.time_msec() + + keys = self.pos_time.keys() + time = keys.bisect_left(now - FIVE_MINUTES_AGO) + if not keys[:time]: + return + + position_to_delete = max(keys[:time]) + for key in keys[:time]: + del self.pos_time[key] + + self._clear_queue_before_pos(position_to_delete) + + def _clear_queue_before_pos(self, position_to_delete): + """Clear all the queues from before a given position""" + with Measure(self.clock, "send_queue._clear"): + # Delete things out of presence maps + keys = self.presence_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.presence_changed[key] + + user_ids = set( + user_id for uids in self.presence_changed.values() for _, user_id in uids + ) + + to_del = [ + user_id for user_id in self.presence_map if user_id not in user_ids + ] + for user_id in to_del: + del self.presence_map[user_id] + + # Delete things out of keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.keyed_edu_changed[key] + + live_keys = set() + for edu_key in self.keyed_edu_changed.values(): + live_keys.add(edu_key) + + to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys] + for edu_key in to_del: + del self.keyed_edu[edu_key] + + # Delete things out of edu map + keys = self.edus.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.edus[key] + + # Delete things out of failure map + keys = self.failures.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.failures[key] + + # Delete things out of device map + keys = self.device_messages.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.device_messages[key] + + def notify_new_events(self, current_id): + """As per TransactionQueue""" + # We don't need to replicate this as it gets sent down a different + # stream. + pass + + def send_edu(self, destination, edu_type, content, key=None): + """As per TransactionQueue""" + pos = self._next_pos() + + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + if key: + assert isinstance(key, tuple) + self.keyed_edu[(destination, key)] = edu + self.keyed_edu_changed[pos] = (destination, key) + else: + self.edus[pos] = edu + + def send_presence(self, destination, states): + """As per TransactionQueue""" + pos = self._next_pos() + + self.presence_map.update({ + state.user_id: state + for state in states + }) + + self.presence_changed[pos] = [ + (destination, state.user_id) for state in states + ] + + def send_failure(self, failure, destination): + """As per TransactionQueue""" + pos = self._next_pos() + + self.failures[pos] = (destination, str(failure)) + + def send_device_messages(self, destination): + """As per TransactionQueue""" + pos = self._next_pos() + self.device_messages[pos] = destination + + def get_current_token(self): + return self.pos - 1 + + def get_replication_rows(self, token, limit, federation_ack=None): + """ + Args: + token (int) + limit (int) + federation_ack (int): Optional. The position where the worker is + explicitly acknowledged it has handled. Allows us to drop + data from before that point + """ + # TODO: Handle limit. + + # To handle restarts where we wrap around + if token > self.pos: + token = -1 + + rows = [] + + # There should be only one reader, so lets delete everything its + # acknowledged its seen. + if federation_ack: + self._clear_queue_before_pos(federation_ack) + + # Fetch changed presence + keys = self.presence_changed.keys() + i = keys.bisect_right(token) + dest_user_ids = set( + (pos, dest_user_id) + for pos in keys[i:] + for dest_user_id in self.presence_changed[pos] + ) + + for (key, (dest, user_id)) in dest_user_ids: + rows.append((key, PRESENCE_TYPE, ujson.dumps({ + "destination": dest, + "state": self.presence_map[user_id].as_dict(), + }))) + + # Fetch changes keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_right(token) + keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:]) + + for (pos, (destination, edu_key)) in keyed_edus: + rows.append( + (pos, KEYED_EDU_TYPE, ujson.dumps({ + "key": edu_key, + "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), + })) + ) + + # Fetch changed edus + keys = self.edus.keys() + i = keys.bisect_right(token) + edus = set((k, self.edus[k]) for k in keys[i:]) + + for (pos, edu) in edus: + rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + + # Fetch changed failures + keys = self.failures.keys() + i = keys.bisect_right(token) + failures = set((k, self.failures[k]) for k in keys[i:]) + + for (pos, (destination, failure)) in failures: + rows.append((pos, FAILURE_TYPE, ujson.dumps({ + "destination": destination, + "failure": failure, + }))) + + # Fetch changed device messages + keys = self.device_messages.keys() + i = keys.bisect_right(token) + device_messages = set((k, self.device_messages[k]) for k in keys[i:]) + + for (pos, destination) in device_messages: + rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ + "destination": destination, + }))) + + # Sort rows based on pos + rows.sort() + + return rows diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 1928da03b3..51b656d74a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -17,14 +17,18 @@ from twisted.internet import defer from .persistence import TransactionActions -from .units import Transaction +from .units import Transaction, Edu +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import HttpResponseException -from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.async import run_on_reactor +from synapse.util.logcontext import preserve_context_over_fn from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) +from synapse.util.metrics import measure_func +from synapse.types import get_domain_from_id +from synapse.handlers.presence import format_user_presence_state import synapse.metrics import logging @@ -34,6 +38,12 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) +client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client") +sent_pdus_destination_dist = client_metrics.register_distribution( + "sent_pdu_destinations" +) +sent_edus_counter = client_metrics.register_counter("sent_edus") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -42,15 +52,16 @@ class TransactionQueue(object): It batches pending PDUs into single transactions. """ - def __init__(self, hs, transport_layer): + def __init__(self, hs): self.server_name = hs.hostname self.store = hs.get_datastore() + self.state = hs.get_state_handler() self.transaction_actions = TransactionActions(self.store) - self.transport_layer = transport_layer + self.transport_layer = hs.get_federation_transport_client() - self._clock = hs.get_clock() + self.clock = hs.get_clock() # Is a mapping from destinations -> deferreds. Used to keep track # of which destinations have transactions in flight and when they are @@ -68,20 +79,35 @@ class TransactionQueue(object): # destination -> list of tuple(edu, deferred) self.pending_edus_by_dest = edus = {} + # Presence needs to be separate as we send single aggragate EDUs + self.pending_presence_by_dest = presence = {} + self.pending_edus_keyed_by_dest = edus_keyed = {} + metrics.register_callback( "pending_pdus", lambda: sum(map(len, pdus.values())), ) metrics.register_callback( "pending_edus", - lambda: sum(map(len, edus.values())), + lambda: ( + sum(map(len, edus.values())) + + sum(map(len, presence.values())) + + sum(map(len, edus_keyed.values())) + ), ) # destination -> list of tuple(failure, deferred) self.pending_failures_by_dest = {} + self.last_device_stream_id_by_dest = {} + # HACK to get unique tx id - self._next_txn_id = int(self._clock.time_msec()) + self._next_txn_id = int(self.clock.time_msec()) + + self._order = 1 + + self._is_processing = False + self._last_poked_id = -1 def can_send_to(self, destination): """Can we send messages to the given server? @@ -103,11 +129,61 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - def enqueue_pdu(self, pdu, destinations, order): + @defer.inlineCallbacks + def notify_new_events(self, current_id): + """This gets called when we have some new events we might want to + send out to other servers. + """ + self._last_poked_id = max(current_id, self._last_poked_id) + + if self._is_processing: + return + + try: + self._is_processing = True + while True: + last_token = yield self.store.get_federation_out_pos("events") + next_token, events = yield self.store.get_all_new_events_stream( + last_token, self._last_poked_id, limit=20, + ) + + logger.debug("Handling %s -> %s", last_token, next_token) + + if not events and next_token >= self._last_poked_id: + break + + for event in events: + users_in_room = yield self.state.get_current_user_in_room( + event.room_id, latest_event_ids=[event.event_id], + ) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + ) + + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(event.state_key)) + + logger.debug("Sending %s to %r", event, destinations) + + self._send_pdu(event, destinations) + + yield self.store.update_federation_out_pos( + "events", next_token + ) + + finally: + self._is_processing = False + + def _send_pdu(self, pdu, destinations): # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. + order = self._order + self._order += 1 + destinations = set(destinations) destinations = set( dest for dest in destinations if self.can_send_to(dest) @@ -118,86 +194,83 @@ class TransactionQueue(object): if not destinations: return - deferreds = [] + sent_pdus_destination_dist.inc_by(len(destinations)) for destination in destinations: - deferred = defer.Deferred() self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, deferred, order) + (pdu, order) ) - def chain(failure): - if not deferred.called: - deferred.errback(failure) - - def log_failure(f): - logger.warn("Failed to send pdu to %s: %s", destination, f.value) - - deferred.addErrback(log_failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) - - deferreds.append(deferred) - - # NO inlineCallbacks - def enqueue_edu(self, edu): - destination = edu.destination + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) + def send_presence(self, destination, states): if not self.can_send_to(destination): return - deferred = defer.Deferred() - self.pending_edus_by_dest.setdefault(destination, []).append( - (edu, deferred) + self.pending_presence_by_dest.setdefault(destination, {}).update({ + state.user_id: state for state in states + }) + + preserve_context_over_fn( + self._attempt_new_transaction, destination ) - def chain(failure): - if not deferred.called: - deferred.errback(failure) + def send_edu(self, destination, edu_type, content, key=None): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) - def log_failure(f): - logger.warn("Failed to send edu to %s: %s", destination, f.value) + if not self.can_send_to(destination): + return - deferred.addErrback(log_failure) + sent_edus_counter.inc() - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) + if key: + self.pending_edus_keyed_by_dest.setdefault( + destination, {} + )[(edu.edu_type, key)] = edu + else: + self.pending_edus_by_dest.setdefault(destination, []).append(edu) - return deferred + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) - @defer.inlineCallbacks - def enqueue_failure(self, failure, destination): + def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": return - deferred = defer.Deferred() - if not self.can_send_to(destination): return self.pending_failures_by_dest.setdefault( destination, [] - ).append( - (failure, deferred) - ) + ).append(failure) - def chain(f): - if not deferred.called: - deferred.errback(f) + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) - def log_failure(f): - logger.warn("Failed to send failure to %s: %s", destination, f.value) + def send_device_messages(self, destination): + if destination == self.server_name or destination == "localhost": + return - deferred.addErrback(log_failure) + if not self.can_send_to(destination): + return - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(chain) + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) - yield deferred + def get_current_token(self): + return 0 @defer.inlineCallbacks - @log_function def _attempt_new_transaction(self, destination): # list of (pending_pdu, deferred, order) if destination in self.pending_transactions: @@ -211,55 +284,128 @@ class TransactionQueue(object): ) return - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - return - try: self.pending_transactions[destination] = 1 - logger.debug("TX [%s] _attempt_new_transaction", destination) + yield run_on_reactor() - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) + while True: + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_failures = self.pending_failures_by_dest.pop(destination, []) - pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) - txn_id = str(self._next_txn_id) + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, + ) - limiter = yield get_retry_limiter( + device_message_edus, device_stream_id = ( + yield self._get_new_device_messages(destination) + ) + + pending_edus.extend(device_message_edus) + if pending_presence: + pending_edus.append( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.clock.time_msec() + ) + for presence in pending_presence.values() + ] + }, + ) + ) + + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) + + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + self.last_device_stream_id_by_dest[destination] = ( + device_stream_id + ) + return + + success = yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures, + device_stream_id, + should_delete_from_device_stream=bool(device_message_edus), + limiter=limiter, + ) + if not success: + break + except NotRetryingDestination: + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", destination, - self._clock, - self.store, ) + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) + + @defer.inlineCallbacks + def _get_new_device_messages(self, destination): + last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) + to_device_stream_id = self.store.get_to_device_stream_token() + contents, stream_id = yield self.store.get_new_device_msgs_for_remote( + destination, last_device_stream_id, to_device_stream_id + ) + edus = [ + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.direct_to_device", + content=content, + ) + for content in contents + ] + defer.returnValue((edus, stream_id)) + + @measure_func("_send_new_transaction") + @defer.inlineCallbacks + def _send_new_transaction(self, destination, pending_pdus, pending_edus, + pending_failures, device_stream_id, + should_delete_from_device_stream, limiter): + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[1]) + pdus = [x[0] for x in pending_pdus] + edus = pending_edus + failures = [x.get_dict() for x in pending_failures] + + success = True + + try: + logger.debug("TX [%s] _attempt_new_transaction", destination) + + txn_id = str(self._next_txn_id) logger.debug( "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d, failures: %d)", destination, txn_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures) + len(pdus), + len(edus), + len(failures) ) logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( - origin_server_ts=int(self._clock.time_msec()), + origin_server_ts=int(self.clock.time_msec()), transaction_id=txn_id, origin=self.server_name, destination=destination, @@ -278,9 +424,9 @@ class TransactionQueue(object): " (PDUs: %d, EDUs: %d, failures: %d)", destination, txn_id, transaction.transaction_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures), + len(pdus), + len(edus), + len(failures), ) with limiter: @@ -290,7 +436,7 @@ class TransactionQueue(object): # keys work def json_data_cb(): data = transaction.get_dict() - now = int(self._clock.time_msec()) + now = int(self.clock.time_msec()) if "pdus" in data: for p in data["pdus"]: if "age_ts" in p: @@ -316,6 +462,13 @@ class TransactionQueue(object): code = e.code response = e.response + if e.code == 429 or 500 <= e.code: + logger.info( + "TX [%s] {%s} got %d response", + destination, txn_id, code + ) + raise e + logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code @@ -330,28 +483,19 @@ class TransactionQueue(object): logger.debug("TX [%s] Marked as delivered", destination) - logger.debug("TX [%s] Yielding to callbacks...", destination) - - for deferred in deferreds: - if code == 200: - deferred.callback(None) - else: - deferred.errback(RuntimeError("Got status %d" % code)) - - # Ensures we don't continue until all callbacks on that - # deferred have fired - try: - yield deferred - except: - pass - - logger.debug("TX [%s] Yielded to callbacks", destination) - except NotRetryingDestination: - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) + if code != 200: + for p in pdus: + logger.info( + "Failed to send event %s to %s", p.event_id, destination + ) + success = False + else: + # Remove the acknowledged device messages from the database + if should_delete_from_device_stream: + yield self.store.delete_device_msgs_for_remote( + destination, device_stream_id + ) + self.last_device_stream_id_by_dest[destination] = device_stream_id except RuntimeError as e: # We capture this here as there as nothing actually listens # for this finishing functions deferred. @@ -360,6 +504,11 @@ class TransactionQueue(object): destination, e, ) + + success = False + + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) except Exception as e: # We capture this here as there as nothing actually listens # for this finishing functions deferred. @@ -369,13 +518,9 @@ class TransactionQueue(object): e, ) - for deferred in deferreds: - if not deferred.called: - deferred.errback(e) + success = False - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) - # Check to see if there is anything else to send. - self._attempt_new_transaction(destination) + defer.returnValue(success) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 2237e3413c..db45c7826c 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -55,6 +55,28 @@ class TransportLayerClient(object): ) @log_function + def get_room_state_ids(self, destination, room_id, event_id): + """ Requests all state for a given room from the given server at the + given event. Returns the state's event_id's + + Args: + destination (str): The host name of the remote home server we want + to get the state from. + context (str): The name of the context we want the state of + event_id (str): The event we want the context at. + + Returns: + Deferred: Results in a dict received from the remote homeserver. + """ + logger.debug("get_room_state_ids dest=%s, room=%s", + destination, room_id) + + path = PREFIX + "/state_ids/%s/" % room_id + return self.client.get_json( + destination, path=path, args={"event_id": event_id}, + ) + + @log_function def get_event(self, destination, event_id, timeout=None): """ Requests the pdu with give id and origin from the given server. @@ -179,7 +201,8 @@ class TransportLayerClient(object): content = yield self.client.get_json( destination=destination, path=path, - retry_on_dns_fail=True, + retry_on_dns_fail=False, + timeout=20000, ) defer.returnValue(content) @@ -225,6 +248,28 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def get_public_rooms(self, remote_server, limit, since_token, + search_filter=None): + path = PREFIX + "/publicRooms" + + args = {} + if limit: + args["limit"] = [str(limit)] + if since_token: + args["since"] = [since_token] + + # TODO(erikj): Actually send the search_filter across federation. + + response = yield self.client.get_json( + destination=remote_server, + path=path, + args=args, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,) @@ -263,7 +308,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content): + def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -292,12 +337,13 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content): + def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -328,6 +374,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 208bff8d4f..fec337be64 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,13 +18,16 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import ( + parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, +) from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.versionstring import get_version_string import functools import logging -import simplejson as json import re +import synapse logger = logging.getLogger(__name__) @@ -37,7 +40,7 @@ class TransportLayerServer(JsonResource): self.hs = hs self.clock = hs.get_clock() - super(TransportLayerServer, self).__init__(hs) + super(TransportLayerServer, self).__init__(hs, canonical_json=False) self.authenticator = Authenticator(hs) self.ratelimiter = FederationRateLimiter( @@ -60,6 +63,16 @@ class TransportLayerServer(JsonResource): ) +class AuthenticationError(SynapseError): + """There was a problem authenticating the request""" + pass + + +class NoAuthenticationError(AuthenticationError): + """The request had no authentication information""" + pass + + class Authenticator(object): def __init__(self, hs): self.keyring = hs.get_keyring() @@ -67,7 +80,7 @@ class Authenticator(object): # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks - def authenticate_request(self, request): + def authenticate_request(self, request, content): json_request = { "method": request.method, "uri": request.uri, @@ -75,17 +88,10 @@ class Authenticator(object): "signatures": {}, } - content = None - origin = None + if content is not None: + json_request["content"] = content - if request.method in ["PUT", "POST"]: - # TODO: Handle other method types? other content types? - try: - content_bytes = request.content.read() - content = json.loads(content_bytes) - json_request["content"] = content - except: - raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON) + origin = None def parse_auth_header(header_str): try: @@ -103,14 +109,14 @@ class Authenticator(object): sig = strip_quotes(param_dict["sig"]) return (origin, key, sig) except: - raise SynapseError( + raise AuthenticationError( 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") if not auth_headers: - raise SynapseError( + raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED, ) @@ -121,7 +127,7 @@ class Authenticator(object): json_request["signatures"].setdefault(origin, {})[key] = sig if not json_request["signatures"]: - raise SynapseError( + raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED, ) @@ -130,38 +136,59 @@ class Authenticator(object): logger.info("Request from %s", origin) request.authenticated_entity = origin - defer.returnValue((origin, content)) + defer.returnValue(origin) class BaseFederationServlet(object): - def __init__(self, handler, authenticator, ratelimiter, server_name): + REQUIRE_AUTH = True + + def __init__(self, handler, authenticator, ratelimiter, server_name, + room_list_handler): self.handler = handler self.authenticator = authenticator self.ratelimiter = ratelimiter + self.room_list_handler = room_list_handler - def _wrap(self, code): + def _wrap(self, func): authenticator = self.authenticator ratelimiter = self.ratelimiter @defer.inlineCallbacks - @functools.wraps(code) - def new_code(request, *args, **kwargs): + @functools.wraps(func) + def new_func(request, *args, **kwargs): + content = None + if request.method in ["PUT", "POST"]: + # TODO: Handle other method types? other content types? + content = parse_json_object_from_request(request) + try: - (origin, content) = yield authenticator.authenticate_request(request) + origin = yield authenticator.authenticate_request(request, content) + except NoAuthenticationError: + origin = None + if self.REQUIRE_AUTH: + logger.exception("authenticate_request failed") + raise + except: + logger.exception("authenticate_request failed") + raise + + if origin: with ratelimiter.ratelimit(origin) as d: yield d - response = yield code( + response = yield func( origin, content, request.args, *args, **kwargs ) - except: - logger.exception("authenticate_request failed") - raise + else: + response = yield func( + origin, content, request.args, *args, **kwargs + ) + defer.returnValue(response) # Extra logic that functools.wraps() doesn't finish - new_code.__self__ = code.__self__ + new_func.__self__ = func.__self__ - return new_code + return new_func def register(self, server): pattern = re.compile("^" + PREFIX + self.PATH + "$") @@ -175,7 +202,7 @@ class BaseFederationServlet(object): class FederationSendServlet(BaseFederationServlet): - PATH = "/send/([^/]*)/" + PATH = "/send/(?P<transaction_id>[^/]*)/" def __init__(self, handler, server_name, **kwargs): super(FederationSendServlet, self).__init__( @@ -250,7 +277,7 @@ class FederationPullServlet(BaseFederationServlet): class FederationEventServlet(BaseFederationServlet): - PATH = "/event/([^/]*)/" + PATH = "/event/(?P<event_id>[^/]*)/" # This is when someone asks for a data item for a given server data_id pair. def on_GET(self, origin, content, query, event_id): @@ -258,7 +285,7 @@ class FederationEventServlet(BaseFederationServlet): class FederationStateServlet(BaseFederationServlet): - PATH = "/state/([^/]*)/" + PATH = "/state/(?P<context>[^/]*)/" # This is when someone asks for all data for a given context. def on_GET(self, origin, content, query, context): @@ -269,8 +296,19 @@ class FederationStateServlet(BaseFederationServlet): ) +class FederationStateIdsServlet(BaseFederationServlet): + PATH = "/state_ids/(?P<room_id>[^/]*)/" + + def on_GET(self, origin, content, query, room_id): + return self.handler.on_state_ids_request( + origin, + room_id, + query.get("event_id", [None])[0], + ) + + class FederationBackfillServlet(BaseFederationServlet): - PATH = "/backfill/([^/]*)/" + PATH = "/backfill/(?P<context>[^/]*)/" def on_GET(self, origin, content, query, context): versions = query["v"] @@ -285,7 +323,7 @@ class FederationBackfillServlet(BaseFederationServlet): class FederationQueryServlet(BaseFederationServlet): - PATH = "/query/([^/]*)" + PATH = "/query/(?P<query_type>[^/]*)" # This is when we receive a server-server Query def on_GET(self, origin, content, query, query_type): @@ -296,7 +334,7 @@ class FederationQueryServlet(BaseFederationServlet): class FederationMakeJoinServlet(BaseFederationServlet): - PATH = "/make_join/([^/]*)/([^/]*)" + PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): @@ -305,7 +343,7 @@ class FederationMakeJoinServlet(BaseFederationServlet): class FederationMakeLeaveServlet(BaseFederationServlet): - PATH = "/make_leave/([^/]*)/([^/]*)" + PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)" @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): @@ -314,7 +352,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet): class FederationSendLeaveServlet(BaseFederationServlet): - PATH = "/send_leave/([^/]*)/([^/]*)" + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, room_id, txid): @@ -323,14 +361,14 @@ class FederationSendLeaveServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet): - PATH = "/event_auth/([^/]*)/([^/]*)" + PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" def on_GET(self, origin, content, query, context, event_id): return self.handler.on_event_auth(origin, context, event_id) class FederationSendJoinServlet(BaseFederationServlet): - PATH = "/send_join/([^/]*)/([^/]*)" + PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, context, event_id): @@ -341,7 +379,7 @@ class FederationSendJoinServlet(BaseFederationServlet): class FederationInviteServlet(BaseFederationServlet): - PATH = "/invite/([^/]*)/([^/]*)" + PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, context, event_id): @@ -352,7 +390,7 @@ class FederationInviteServlet(BaseFederationServlet): class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): - PATH = "/exchange_third_party_invite/([^/]*)" + PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, room_id): @@ -365,10 +403,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet): PATH = "/user/keys/query" - @defer.inlineCallbacks def on_POST(self, origin, content, query): - response = yield self.handler.on_query_client_keys(origin, content) - defer.returnValue((200, response)) + return self.handler.on_query_client_keys(origin, content) class FederationClientKeysClaimServlet(BaseFederationServlet): @@ -381,12 +417,12 @@ class FederationClientKeysClaimServlet(BaseFederationServlet): class FederationQueryAuthServlet(BaseFederationServlet): - PATH = "/query_auth/([^/]*)/([^/]*)" + PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks def on_POST(self, origin, content, query, context, event_id): new_content = yield self.handler.on_query_auth_request( - origin, content, event_id + origin, content, context, event_id ) defer.returnValue((200, new_content)) @@ -394,7 +430,7 @@ class FederationQueryAuthServlet(BaseFederationServlet): class FederationGetMissingEventsServlet(BaseFederationServlet): # TODO(paul): Why does this path alone end with "/?" optional? - PATH = "/get_missing_events/([^/]*)/?" + PATH = "/get_missing_events/(?P<room_id>[^/]*)/?" @defer.inlineCallbacks def on_POST(self, origin, content, query, room_id): @@ -418,9 +454,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): class On3pidBindServlet(BaseFederationServlet): PATH = "/3pid/onbind" + REQUIRE_AUTH = False + @defer.inlineCallbacks - def on_POST(self, request): - content = parse_json_object_from_request(request) + def on_POST(self, origin, content, query): if "invites" in content: last_exception = None for invite in content["invites"]: @@ -442,10 +479,103 @@ class On3pidBindServlet(BaseFederationServlet): raise last_exception defer.returnValue((200, {})) - # Avoid doing remote HS authorization checks which are done by default by - # BaseFederationServlet. - def _wrap(self, code): - return code + +class OpenIdUserInfo(BaseFederationServlet): + """ + Exchange a bearer token for information about a user. + + The response format should be compatible with: + http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse + + GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1 + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "sub": "@userpart:example.org", + } + """ + + PATH = "/openid/userinfo" + + REQUIRE_AUTH = False + + @defer.inlineCallbacks + def on_GET(self, origin, content, query): + token = query.get("access_token", [None])[0] + if token is None: + defer.returnValue((401, { + "errcode": "M_MISSING_TOKEN", "error": "Access Token required" + })) + return + + user_id = yield self.handler.on_openid_userinfo(token) + + if user_id is None: + defer.returnValue((401, { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired" + })) + + defer.returnValue((200, {"sub": user_id})) + + +class PublicRoomList(BaseFederationServlet): + """ + Fetch the public room list for this server. + + This API returns information in the same format as /publicRooms on the + client API, but will only ever include local public rooms and hence is + intended for consumption by other home servers. + + GET /publicRooms HTTP/1.1 + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "chunk": [ + { + "aliases": [ + "#test:localhost" + ], + "guest_can_join": false, + "name": "test room", + "num_joined_members": 3, + "room_id": "!whkydVegtvatLfXmPN:localhost", + "world_readable": false + } + ], + "end": "END", + "start": "START" + } + """ + + PATH = "/publicRooms" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query): + limit = parse_integer_from_args(query, "limit", 0) + since_token = parse_string_from_args(query, "since", None) + data = yield self.room_list_handler.get_local_public_room_list( + limit, since_token + ) + defer.returnValue((200, data)) + + +class FederationVersionServlet(BaseFederationServlet): + PATH = "/version" + + REQUIRE_AUTH = False + + def on_GET(self, origin, content, query): + return defer.succeed((200, { + "server": { + "name": "Synapse", + "version": get_version_string(synapse) + }, + })) SERVLET_CLASSES = ( @@ -453,6 +583,7 @@ SERVLET_CLASSES = ( FederationPullServlet, FederationEventServlet, FederationStateServlet, + FederationStateIdsServlet, FederationBackfillServlet, FederationQueryServlet, FederationMakeJoinServlet, @@ -468,6 +599,9 @@ SERVLET_CLASSES = ( FederationClientKeysClaimServlet, FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, + OpenIdUserInfo, + PublicRoomList, + FederationVersionServlet, ) @@ -478,4 +612,5 @@ def register_servlets(hs, resource, authenticator, ratelimiter): authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, + room_list_handler=hs.get_room_list_handler(), ).register(resource) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 66d2c01123..5ad408f549 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -13,34 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice.scheduler import AppServiceScheduler -from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, + RoomCreationHandler, RoomContextHandler, ) +from .room_member import RoomMemberHandler from .message import MessageHandler -from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler -from .presence import PresenceHandler from .directory import DirectoryHandler -from .typing import TypingNotificationHandler from .admin import AdminHandler -from .appservice import ApplicationServicesHandler -from .sync import SyncHandler -from .auth import AuthHandler from .identity import IdentityHandler -from .receipts import ReceiptsHandler from .search import SearchHandler class Handlers(object): - """ A collection of all the event handlers. + """ Deprecated. A collection of handlers. - There's no need to lazily create these; we'll just make them all eagerly - at construction time. + At some point most of the classes whose name ended "Handler" were + accessed through this class. + + However this makes it painful to unit test the handlers and to run cut + down versions of synapse that only use specific handlers because using a + single handler required creating all of the handlers. So some of the + handlers have been lifted out of the Handlers object and are now accessed + directly through the homeserver object itself. + + Any new handlers should follow the new pattern of being accessed through + the homeserver object and should not be added to the Handlers object. + + The remaining handlers should be moved out of the handlers object. """ def __init__(self, hs): @@ -48,26 +51,10 @@ class Handlers(object): self.message_handler = MessageHandler(hs) self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) - self.event_stream_handler = EventStreamHandler(hs) - self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) - self.presence_handler = PresenceHandler(hs) - self.room_list_handler = RoomListHandler(hs) self.directory_handler = DirectoryHandler(hs) - self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) - self.receipts_handler = ReceiptsHandler(hs) - asapi = ApplicationServiceApi(hs) - self.appservice_handler = ApplicationServicesHandler( - hs, asapi, AppServiceScheduler( - clock=hs.get_clock(), - store=hs.get_datastore(), - as_api=asapi - ) - ) - self.sync_handler = SyncHandler(hs) - self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb7..90f96209f8 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -13,39 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import LimitExceededError, SynapseError, AuthError -from synapse.crypto.event_signing import add_hashes_and_signatures +import synapse.types from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, RoomAlias, Requester -from synapse.push.action_generator import ActionGenerator - -from synapse.util.logcontext import PreserveLoggingContext - -import logging +from synapse.api.errors import LimitExceededError +from synapse.types import UserID logger = logging.getLogger(__name__) -VISIBILITY_PRIORITY = ( - "world_readable", - "shared", - "invited", - "joined", -) - - class BaseHandler(object): """ Common base class for the event handlers. - :type store: synapse.storage.events.StateStore - :type state_handler: synapse.state.StateHandler + Attributes: + store (synapse.storage.DataStore): + state_handler (synapse.state.StateHandler): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -55,141 +49,26 @@ class BaseHandler(object): self.clock = hs.get_clock() self.hs = hs - self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() - @defer.inlineCallbacks - def filter_events_for_clients(self, user_tuples, events, event_id_to_state): - """ Returns dict of user_id -> list of events that user is allowed to - see. - - :param (str, bool) user_tuples: (user id, is_peeking) for each - user to be checked. is_peeking should be true if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events - """ - forgotten = yield defer.gatherResults([ - self.store.who_forgot_in_room( - room_id, - ) - for room_id in frozenset(e.room_id for e in events) - ], consumeErrors=True) - - # Set of membership event_ids that have been forgotten - event_id_forgotten = frozenset( - row["event_id"] for rows in forgotten for row in rows - ) - - def allowed(event, user_id, is_peeking): - state = event_id_to_state[event.event_id] - - # get the room_visibility at the time of the event. - visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) - if visibility_event: - visibility = visibility_event.content.get("history_visibility", "shared") - else: - visibility = "shared" - - if visibility not in VISIBILITY_PRIORITY: - visibility = "shared" - - # if it was world_readable, it's easy: everyone can read it - if visibility == "world_readable": - return True - - # Always allow history visibility events on boundaries. This is done - # by setting the effective visibility to the least restrictive - # of the old vs new. - if event.type == EventTypes.RoomHistoryVisibility: - prev_content = event.unsigned.get("prev_content", {}) - prev_visibility = prev_content.get("history_visibility", None) - - if prev_visibility not in VISIBILITY_PRIORITY: - prev_visibility = "shared" - - new_priority = VISIBILITY_PRIORITY.index(visibility) - old_priority = VISIBILITY_PRIORITY.index(prev_visibility) - if old_priority < new_priority: - visibility = prev_visibility - - # get the user's membership at the time of the event. (or rather, - # just *after* the event. Which means that people can see their - # own join events, but not (currently) their own leave events.) - membership_event = state.get((EventTypes.Member, user_id), None) - if membership_event: - if membership_event.event_id in event_id_forgotten: - membership = None - else: - membership = membership_event.membership - else: - membership = None - - # if the user was a member of the room at the time of the event, - # they can see it. - if membership == Membership.JOIN: - return True - - if visibility == "joined": - # we weren't a member at the time of the event, so we can't - # see this event. - return False - - elif visibility == "invited": - # user can also see the event if they were *invited* at the time - # of the event. - return membership == Membership.INVITE + def ratelimit(self, requester): + time_now = self.clock.time() + user_id = requester.user.to_string() - else: - # visibility is shared: user can also see the event if they have - # become a member since the event - # - # XXX: if the user has subsequently joined and then left again, - # ideally we would share history up to the point they left. But - # we don't know when they left. - return not is_peeking + # The AS user itself is never rate limited. + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service is not None: + return # do not ratelimit app service senders - defer.returnValue({ - user_id: [ - event - for event in events - if allowed(event, user_id, is_peeking) - ] - for user_id, is_peeking in user_tuples - }) + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return - @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events, is_peeking=False): - """ - Check which events a user is allowed to see - - :param str user_id: user id to be checked - :param [synapse.events.EventBase] events: list of events to be checked - :param bool is_peeking should be True if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events - :rtype [synapse.events.EventBase] - """ - types = ( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=types - ) - res = yield self.filter_events_for_clients( - [(user_id, is_peeking)], events, event_id_to_state - ) - defer.returnValue(res.get(user_id, [])) - - def ratelimit(self, requester): - time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( - requester.user.to_string(), time_now, + user_id, time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, burst_count=self.hs.config.rc_message_burst_count, ) @@ -199,252 +78,20 @@ class BaseHandler(object): ) @defer.inlineCallbacks - def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 - else: - depth = 1 - - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] - - builder.prev_events = prev_events - builder.depth = depth - - state_handler = self.state_handler - - context = yield state_handler.compute_event_context(builder) - - # If we've received an invite over federation, there are no latest - # events in the room, because we don't know enough about the graph - # fragment we received to treat it like a graph, so the above returned - # no relevant events. It may have returned some events (if we have - # joined and left the room), but not useful ones, like the invite. - if ( - not self.is_host_in_room(context.current_state) and - builder.type == EventTypes.Member - ): - prev_member_event = yield self.store.get_room_member( - builder.sender, builder.room_id - ) - - # The prev_member_event may already be in context.current_state, - # despite us not being present in the room; in particular, if - # inviting user, and all other local users, have already left. - # - # In that case, we have all the information we need, and we don't - # want to drop "context" - not least because we may need to handle - # the invite locally, which will require us to have the whole - # context (not just prev_member_event) to auth it. - # - context_event_ids = ( - e.event_id for e in context.current_state.values() - ) - - if ( - prev_member_event and - prev_member_event.event_id not in context_event_ids - ): - # The prev_member_event is missing from context, so it must - # have arrived over federation and is an outlier. We forcibly - # set our context to the invite we received over federation - builder.prev_events = ( - prev_member_event.event_id, - prev_member_event.prev_events - ) - - context = yield state_handler.compute_event_context( - builder, - old_state=(prev_member_event,), - outlier=True - ) - - if builder.is_state(): - builder.prev_state = yield self.store.add_event_hashes( - context.prev_state_events - ) - - yield self.auth.add_auth_events(builder, context) - - add_hashes_and_signatures( - builder, self.server_name, self.signing_key - ) - - event = builder.build() - - logger.debug( - "Created event %s with current state: %s", - event.event_id, context.current_state, - ) - - defer.returnValue( - (event, context,) - ) - - def is_host_in_room(self, current_state): - room_members = [ - (state_key, event.membership) - for ((event_type, state_key), event) in current_state.items() - if event_type == EventTypes.Member - ] - if len(room_members) == 0: - # Have we just created the room, and is this about to be the very - # first member event? - create_event = current_state.get(("m.room.create", "")) - if create_event: - return True - for (state_key, membership) in room_members: - if ( - UserID.from_string(state_key).domain == self.hs.hostname - and membership == Membership.JOIN - ): - return True - return False - - @defer.inlineCallbacks - def handle_new_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[] - ): - # We now need to go and hit out to wherever we need to hit out to. - - if ratelimit: - self.ratelimit(requester) - - self.auth.check(event, auth_events=context.current_state) - - yield self.maybe_kick_guest_users(event, context.current_state.values()) - - if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) - room_alias_str = event.content.get("alias", None) - if room_alias_str: - room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" % ( - room_alias_str, - ) - ) - - federation_handler = self.hs.get_handlers().federation_handler - - if event.type == EventTypes.Member: - if event.content["membership"] == Membership.INVITE: - def is_inviter_member_event(e): - return ( - e.type == EventTypes.Member and - e.sender == event.sender - ) - - event.unsigned["invite_room_state"] = [ - { - "type": e.type, - "state_key": e.state_key, - "content": e.content, - "sender": e.sender, - } - for k, e in context.current_state.items() - if e.type in self.hs.config.room_invite_state_types - or is_inviter_member_event(e) - ] - - invitee = UserID.from_string(event.state_key) - if not self.hs.is_mine(invitee): - # TODO: Can we add signature from remote server in a nicer - # way? If we have been invited by a remote server, we need - # to get them to sign the event. - - returned_invite = yield federation_handler.send_invite( - invitee.domain, - event, - ) - - event.unsigned.pop("room_state", None) - - # TODO: Make sure the signatures actually are correct. - event.signatures.update( - returned_invite.signatures - ) - - if event.type == EventTypes.Redaction: - if self.auth.check_redaction(event, auth_events=context.current_state): - original_event = yield self.store.get_event( - event.redacts, - check_redacted=False, - get_prev_content=False, - allow_rejected=False, - allow_none=False - ) - if event.user_id != original_event.user_id: - raise AuthError( - 403, - "You don't have permission to redact events" - ) - - if event.type == EventTypes.Create and context.current_state: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) - - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( - event, context, self - ) - - (event_stream_id, max_stream_id) = yield self.store.persist_event( - event, context=context - ) - - destinations = set() - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add( - UserID.from_string(s.state_key).domain - ) - except SynapseError: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - - with PreserveLoggingContext(): - # Don't block waiting on waking up all the listeners. - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) - - # If invite, remove room_state from unsigned before sending. - event.unsigned.pop("invite_room_state", None) - - federation_handler.handle_new_event( - event, destinations=destinations, - ) - - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, current_state): + def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": + if context: + current_state = yield self.store.get_events( + context.current_state_ids.values() + ) + current_state = current_state.values() + else: + current_state = yield self.store.get_current_state(event.room_id) + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) @defer.inlineCallbacks @@ -477,7 +124,8 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = Requester(target_user, "", True) + requester = synapse.types.create_requester( + target_user, is_guest=True) handler = self.hs.get_handlers().room_member_handler yield handler.update_membership( requester, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 75fc74c797..05af54d31b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -16,8 +16,8 @@ from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.appservice import ApplicationService -from synapse.types import UserID +from synapse.util.metrics import Measure +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -35,47 +35,81 @@ def log_failure(failure): ) -# NB: Purposefully not inheriting BaseHandler since that contains way too much -# setup code which this handler does not need or use. This makes testing a lot -# easier. class ApplicationServicesHandler(object): - def __init__(self, hs, appservice_api, appservice_scheduler): + def __init__(self, hs): self.store = hs.get_datastore() - self.hs = hs - self.appservice_api = appservice_api - self.scheduler = appservice_scheduler + self.is_mine_id = hs.is_mine_id + self.appservice_api = hs.get_application_service_api() + self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False + self.clock = hs.get_clock() + self.notify_appservices = hs.config.notify_appservices + + self.current_max = 0 + self.is_processing = False @defer.inlineCallbacks - def notify_interested_services(self, event): + def notify_interested_services(self, current_id): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any prolonged length of time. Args: - event(Event): The event to push out to interested services. + current_id(int): The current maximum ID. """ - # Gather interested services - services = yield self._get_services_for_event(event) - if len(services) == 0: - return # no services need notifying - - # Do we know this user exists? If not, poke the user query API for - # all services which match that user regex. This needs to block as these - # user queries need to be made BEFORE pushing the event. - yield self._check_user_exists(event.sender) - if event.type == EventTypes.Member: - yield self._check_user_exists(event.state_key) - - if not self.started_scheduler: - self.scheduler.start().addErrback(log_failure) - self.started_scheduler = True - - # Fork off pushes to these services - for service in services: - self.scheduler.submit_event_for_as(service, event) + services = self.store.get_app_services() + if not services or not self.notify_appservices: + return + + self.current_max = max(self.current_max, current_id) + if self.is_processing: + return + + with Measure(self.clock, "notify_interested_services"): + self.is_processing = True + try: + upper_bound = self.current_max + limit = 100 + while True: + upper_bound, events = yield self.store.get_new_events_for_appservice( + upper_bound, limit + ) + + if not events: + break + + for event in events: + # Gather interested services + services = yield self._get_services_for_event(event) + if len(services) == 0: + continue # no services need notifying + + # Do we know this user exists? If not, poke the user + # query API for all services which match that user regex. + # This needs to block as these user queries need to be + # made BEFORE pushing the event. + yield self._check_user_exists(event.sender) + if event.type == EventTypes.Member: + yield self._check_user_exists(event.state_key) + + if not self.started_scheduler: + self.scheduler.start().addErrback(log_failure) + self.started_scheduler = True + + # Fork off pushes to these services + for service in services: + preserve_fn(self.scheduler.submit_event_for_as)( + service, event + ) + + yield self.store.set_appservice_last_pos(upper_bound) + + if len(events) < limit: + break + finally: + self.is_processing = False @defer.inlineCallbacks def query_user_exists(self, user_id): @@ -108,11 +142,12 @@ class ApplicationServicesHandler(object): association can be found. """ room_alias_str = room_alias.to_string() - alias_query_services = yield self._get_services_for_event( - event=None, - restrict_to=ApplicationService.NS_ALIASES, - alias_list=[room_alias_str] - ) + services = self.store.get_app_services() + alias_query_services = [ + s for s in services if ( + s.is_interested_in_alias(room_alias_str) + ) + ] for alias_service in alias_query_services: is_known_alias = yield self.appservice_api.query_alias( alias_service, room_alias_str @@ -125,52 +160,97 @@ class ApplicationServicesHandler(object): defer.returnValue(result) @defer.inlineCallbacks - def _get_services_for_event(self, event, restrict_to="", alias_list=None): + def query_3pe(self, kind, protocol, fields): + services = yield self._get_services_for_3pn(protocol) + + results = yield preserve_context_over_deferred(defer.DeferredList([ + preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) + for service in services + ], consumeErrors=True)) + + ret = [] + for (success, result) in results: + if success: + ret.extend(result) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_3pe_protocols(self, only_protocol=None): + services = self.store.get_app_services() + protocols = {} + + # Collect up all the individual protocol responses out of the ASes + for s in services: + for p in s.protocols: + if only_protocol is not None and p != only_protocol: + continue + + if p not in protocols: + protocols[p] = [] + + info = yield self.appservice_api.get_3pe_protocol(s, p) + + if info is not None: + protocols[p].append(info) + + def _merge_instances(infos): + if not infos: + return {} + + # Merge the 'instances' lists of multiple results, but just take + # the other fields from the first as they ought to be identical + # copy the result so as not to corrupt the cached one + combined = dict(infos[0]) + combined["instances"] = list(combined["instances"]) + + for info in infos[1:]: + combined["instances"].extend(info["instances"]) + + return combined + + for p in protocols.keys(): + protocols[p] = _merge_instances(protocols[p]) + + defer.returnValue(protocols) + + @defer.inlineCallbacks + def _get_services_for_event(self, event): """Retrieve a list of application services interested in this event. Args: event(Event): The event to check. Can be None if alias_list is not. - restrict_to(str): The namespace to restrict regex tests to. - alias_list: A list of aliases to get services for. If None, this - list is obtained from the database. Returns: list<ApplicationService>: A list of services interested in this event based on the service regex. """ - member_list = None - if hasattr(event, "room_id"): - # We need to know the aliases associated with this event.room_id, - # if any. - if not alias_list: - alias_list = yield self.store.get_aliases_for_room( - event.room_id - ) - # We need to know the members associated with this event.room_id, - # if any. - member_list = yield self.store.get_users_in_room(event.room_id) - - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( - s.is_interested(event, restrict_to, alias_list, member_list) + yield s.is_interested(event, self.store) ) ] defer.returnValue(interested_list) - @defer.inlineCallbacks def _get_services_for_user(self, user_id): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( s.is_interested_in_user(user_id) ) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) + + def _get_services_for_3pn(self, protocol): + services = self.store.get_app_services() + interested_list = [ + s for s in services if s.is_interested_in_protocol(protocol) + ] + return defer.succeed(interested_list) @defer.inlineCallbacks def _is_unknown_user(self, user_id): - user = UserID.from_string(user_id) - if not self.hs.is_mine(user): + if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. defer.returnValue(False) @@ -182,7 +262,7 @@ class ApplicationServicesHandler(object): return # user not found; could be the AS though, so check. - services = yield self.store.get_app_services() + services = self.store.get_app_services() service_list = [s for s in services if s.sender == user_id] defer.returnValue(len(service_list) == 0) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 82d458b424..3b146f09d6 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID -from synapse.api.errors import AuthError, LoginError, Codes +from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor from twisted.web.client import PartialDownloadError @@ -38,6 +38,10 @@ class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(AuthHandler, self).__init__(hs) self.checkers = { LoginType.PASSWORD: self._check_password_auth, @@ -47,7 +51,20 @@ class AuthHandler(BaseHandler): } self.bcrypt_rounds = hs.config.bcrypt_rounds self.sessions = {} - self.INVALID_TOKEN_HTTP_STATUS = 401 + + account_handler = _AccountHandler( + hs, check_user_exists=self.check_user_exists + ) + + self.password_providers = [ + module(config=config, account_handler=account_handler) + for module, config in hs.config.password_providers + ] + + logger.info("Extra password_providers: %r", self.password_providers) + + self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -118,21 +135,47 @@ class AuthHandler(BaseHandler): creds = session['creds'] # check auth type currently being presented + errordict = {} if 'type' in authdict: - if authdict['type'] not in self.checkers: + login_type = authdict['type'] + if login_type not in self.checkers: raise LoginError(400, "", Codes.UNRECOGNIZED) - result = yield self.checkers[authdict['type']](authdict, clientip) - if result: - creds[authdict['type']] = result - self._save_session(session) + try: + result = yield self.checkers[login_type](authdict, clientip) + if result: + creds[login_type] = result + self._save_session(session) + except LoginError, e: + if login_type == LoginType.EMAIL_IDENTITY: + # riot used to have a bug where it would request a new + # validation token (thus sending a new email) each time it + # got a 401 with a 'flows' field. + # (https://github.com/vector-im/vector-web/issues/2447). + # + # Grandfather in the old behaviour for now to avoid + # breaking old riot deployments. + raise e + + # this step failed. Merge the error dict into the response + # so that the client can have another go. + errordict = e.error_dict() for f in flows: if len(set(f) - set(creds.keys())) == 0: - logger.info("Auth completed with creds: %r", creds) + # it's very useful to know what args are stored, but this can + # include the password in the case of registering, so only log + # the keys (confusingly, clientdict may contain a password + # param, creds is just what the user authed as for UI auth + # and is not sensitive). + logger.info( + "Auth completed with creds: %r. Client dict has keys: %r", + creds, clientdict.keys() + ) defer.returnValue((True, creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() + ret.update(errordict) defer.returnValue((False, ret, clientdict, session['id'])) @defer.inlineCallbacks @@ -163,9 +206,13 @@ class AuthHandler(BaseHandler): def get_session_id(self, clientdict): """ Gets the session ID for a client given the client dictionary - :param clientdict: The dictionary sent by the client in the request - :return: The string session ID the client sent. If the client did not - send a session ID, returns None. + + Args: + clientdict: The dictionary sent by the client in the request + + Returns: + str|None: The string session ID the client sent. If the client did + not send a session ID, returns None. """ sid = None if clientdict and 'auth' in clientdict: @@ -179,9 +226,11 @@ class AuthHandler(BaseHandler): Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param value: (any) The data to store + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + value (any): The data to store """ sess = self._get_session_info(session_id) sess.setdefault('serverdict', {})[key] = value @@ -190,14 +239,15 @@ class AuthHandler(BaseHandler): def get_session_data(self, session_id, key, default=None): """ Retrieve data stored with set_session_data - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param default: (any) Value to return if the key has not been set + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) @@ -207,9 +257,7 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) - defer.returnValue(user_id) + return self._check_password(user_id, password) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -245,8 +293,17 @@ class AuthHandler(BaseHandler): data = pde.response resp_body = simplejson.loads(data) - if 'success' in resp_body and resp_body['success']: - defer.returnValue(True) + if 'success' in resp_body: + # Note that we do NOT check the hostname here: we explicitly + # intend the CAPTCHA to be presented by whatever client the + # user is using, we just care that they have completed a CAPTCHA. + logger.info( + "%s reCAPTCHA from hostname %s", + "Successful" if resp_body['success'] else "Failed", + resp_body.get('hostname') + ) + if resp_body['success']: + defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) @defer.inlineCallbacks @@ -313,147 +370,205 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def login_with_password(self, user_id, password): + def validate_password_login(self, user_id, password): """ Authenticates the user with their username and password. Used only by the v1 login API. Args: - user_id (str): User ID + user_id (str): complete @user:id password (str): Password Returns: - A tuple of: - The user's ID. - The access token for the user's session. - The refresh token for the user's session. + defer.Deferred: (str) canonical user id Raises: - StoreError if there was a problem storing the token. + StoreError if there was a problem accessing the database LoginError if there was an authentication problem. """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) - - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_access_token_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ - Gets login tuple for the user with the given user ID. + Creates a new access token for the user with the given user ID. + The user is assumed to have been authenticated by some other - machanism (e.g. CAS) + machanism (e.g. CAS), and the user_id converted to the canonical case. + + The device will be recorded in the table if it is not there already. Args: - user_id (str): User ID + user_id (str): canonical User ID + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: - A tuple of: - The user's ID. The access token for the user's session. - The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + defer.returnValue(access_token) @defer.inlineCallbacks - def does_user_exist(self, user_id): - try: - yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(True) - except LoginError: - defer.returnValue(False) + def check_user_exists(self, user_id): + """ + Checks to see if a user with the given id exists. Will check case + insensitively, but return None if there are multiple inexact matches. + + Args: + (str) user_id: complete @user:id + + Returns: + defer.Deferred: (str) canonical_user_id, or None if zero or + multiple matches + """ + res = yield self._find_user_id_and_pwd_hash(user_id) + if res is not None: + defer.returnValue(res[0]) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): """Checks to see if a user with the given id exists. Will check case - insensitively, but will throw if there are multiple inexact matches. + insensitively, but will return None if there are multiple inexact + matches. Returns: tuple: A 2-tuple of `(canonical_user_id, password_hash)` + None: if there is not exactly one match """ user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + + result = None if not user_infos: logger.warn("Attempted to login as %s but they do not exist", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - if len(user_infos) > 1: - if user_id not in user_infos: - logger.warn( - "Attempted to login as %s but it matches more than one user " - "inexactly: %r", - user_id, user_infos.keys() - ) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue((user_id, user_infos[user_id])) + elif len(user_infos) == 1: + # a single match (possibly not exact) + result = user_infos.popitem() + elif user_id in user_infos: + # multiple matches, but one is exact + result = (user_id, user_infos[user_id]) else: - defer.returnValue(user_infos.popitem()) + # multiple matches, none of them exact + logger.warn( + "Attempted to login as %s but it matches more than one user " + "inexactly: %r", + user_id, user_infos.keys() + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _check_password(self, user_id, password): + """Authenticate a user against the LDAP and local databases. - def _check_password(self, user_id, password, stored_hash): - """Checks that user_id has passed password, raises LoginError if not.""" - if not self.validate_hash(password, stored_hash): + user_id is checked case insensitively against the local database, but + will throw if there are multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id + Raises: + LoginError if login fails + """ + for provider in self.password_providers: + is_valid = yield provider.check_password(user_id, password) + if is_valid: + defer.returnValue(user_id) + + canonical_user_id = yield self._check_local_password(user_id, password) + + if canonical_user_id: + defer.returnValue(canonical_user_id) + + # unknown username or invalid password. We raise a 403 here, but note + # that if we're doing user-interactive login, it turns all LoginErrors + # into a 401 anyway. + raise LoginError( + 403, "Invalid password", + errcode=Codes.FORBIDDEN + ) + + @defer.inlineCallbacks + def _check_local_password(self, user_id, password): + """Authenticate a user against the local password database. + + user_id is checked case insensitively, but will return None if there are + multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id, or None if unknown user / bad password + """ + lookupres = yield self._find_user_id_and_pwd_hash(user_id) + if not lookupres: + defer.returnValue(None) + (user_id, password_hash) = lookupres + result = self.validate_hash(password, password_hash) + if not result: logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(None) + defer.returnValue(user_id) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) - @defer.inlineCallbacks - def issue_refresh_token(self, user_id): - refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) - defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) - macaroon.add_first_party_caveat("time < %d" % (expiry,)) + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_refresh_token(self, user_id): - m = self._generate_base_macaroon(user_id) - m.add_first_party_caveat("type = refresh") - # Important to add a nonce, because otherwise every refresh token for a - # user will be the same. - m.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - return m.serialize() - - def generate_short_term_login_token(self, user_id): + def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() - expiry = now + (2 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + def validate_short_term_login_token_and_get_user_id(self, login_token): + auth_api = self.hs.get_auth() try: macaroon = pymacaroons.Macaroon.deserialize(login_token) - auth_api = self.hs.get_auth() - auth_api.validate_macaroon(macaroon, "login", True) - return self.get_user_from_macaroon(macaroon) - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + user_id = auth_api.get_user_id_from_macaroon(macaroon) + auth_api.validate_macaroon(macaroon, "login", True, user_id) + return user_id + except Exception: + raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( @@ -464,32 +579,39 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def get_user_from_macaroon(self, macaroon): - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] - raise AuthError( - self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", - errcode=Codes.UNKNOWN_TOKEN - ) - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) - except_access_token_ids = [requester.access_token_id] if requester else [] + except_access_token_id = requester.access_token_id if requester else None - yield self.store.user_set_password_hash(user_id, password_hash) + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e yield self.store.user_delete_access_tokens( - user_id, except_access_token_ids + user_id, except_access_token_id ) yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_ids + user_id, except_access_token_id ) @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): + # 'Canonicalise' email addresses down to lower case. + # We've now moving towards the Home Server being the entity that + # is responsible for validating threepids used for resetting passwords + # on accounts, so in future Synapse will gain knowledge of specific + # types (mediums) of threepid. For now, we still use the existing + # infrastructure, but this is the start of synapse gaining knowledge + # of specific types of threepid (and fixes the fact that checking + # for the presenc eof an email address during password reset was + # case sensitive). + if medium == 'email': + address = address.lower() + yield self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() @@ -520,7 +642,8 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) + return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, + bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -532,4 +655,35 @@ class AuthHandler(BaseHandler): Returns: Whether self.hash(password) == stored_hash (bool). """ - return bcrypt.hashpw(password, stored_hash) == stored_hash + if stored_hash: + return bcrypt.hashpw(password + self.hs.config.password_pepper, + stored_hash.encode('utf-8')) == stored_hash + else: + return False + + +class _AccountHandler(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, check_user_exists): + self.hs = hs + + self._check_user_exists = check_user_exists + + def check_user_exists(self, user_id): + """Check if user exissts. + + Returns: + Deferred(bool) + """ + return self._check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..aa68755936 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api import errors +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name=None): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string(10).upper() + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except errors.StoreError: + attempts += 1 + + raise errors.StoreError(500, "Couldn't generate a device ID.") + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """ + Retrieve the given user's devices + + Args: + user_id (str): + Returns: + defer.Deferred: list[dict[str, X]]: info on each device + """ + + device_map = yield self.store.get_devices_by_user(user_id) + + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id) for device_id in device_map.keys()) + ) + + devices = device_map.values() + for device in devices: + _update_device_from_client_ips(device, ips) + + defer.returnValue(devices) + + @defer.inlineCallbacks + def get_device(self, user_id, device_id): + """ Retrieve the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: dict[str, X]: info on the device + Raises: + errors.NotFoundError: if the device was not found + """ + try: + device = yield self.store.get_device(user_id, device_id) + except errors.StoreError: + raise errors.NotFoundError + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id),) + ) + _update_device_from_client_ips(device, ips) + defer.returnValue(device) + + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens( + user_id, device_id=device_id, + delete_refresh_tokens=True, + ) + + yield self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=device_id + ) + + @defer.inlineCallbacks + def update_device(self, user_id, device_id, content): + """ Update the given device + + Args: + user_id (str): + device_id (str): + content (dict): body of update request + + Returns: + defer.Deferred: + """ + + try: + yield self.store.update_device( + user_id, + device_id, + new_display_name=content.get("display_name") + ) + except errors.StoreError, e: + if e.code == 404: + raise errors.NotFoundError() + else: + raise + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py new file mode 100644 index 0000000000..f7fad15c62 --- /dev/null +++ b/synapse/handlers/devicemessage.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.types import get_domain_from_id +from synapse.util.stringutils import random_string + + +logger = logging.getLogger(__name__) + + +class DeviceMessageHandler(object): + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.federation = hs.get_federation_sender() + + hs.get_replication_layer().register_edu_handler( + "m.direct_to_device", self.on_direct_to_device_edu + ) + + @defer.inlineCallbacks + def on_direct_to_device_edu(self, origin, content): + local_messages = {} + sender_user_id = content["sender"] + if origin != get_domain_from_id(sender_user_id): + logger.warn( + "Dropping device message from %r with spoofed sender %r", + origin, sender_user_id + ) + message_type = content["type"] + message_id = content["message_id"] + for user_id, by_device in content["messages"].items(): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + + stream_id = yield self.store.add_messages_from_remote_to_device_inbox( + origin, message_id, local_messages + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + @defer.inlineCallbacks + def send_device_message(self, sender_user_id, message_type, messages): + + local_messages = {} + remote_messages = {} + for user_id, by_device in messages.items(): + if self.is_mine_id(user_id): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + else: + destination = get_domain_from_id(user_id) + remote_messages.setdefault(destination, {})[user_id] = by_device + + message_id = random_string(16) + + remote_edu_contents = {} + for destination, messages in remote_messages.items(): + remote_edu_contents[destination] = { + "messages": messages, + "sender": sender_user_id, + "type": message_type, + "message_id": message_id, + } + + stream_id = yield self.store.add_messages_to_device_inbox( + local_messages, remote_edu_contents + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + for destination in remote_messages.keys(): + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + self.federation.send_device_messages(destination) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index c4aaa11918..c00274afc3 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,7 +19,7 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes -from synapse.types import RoomAlias, UserID +from synapse.types import RoomAlias, UserID, get_domain_from_id import logging import string @@ -32,6 +32,9 @@ class DirectoryHandler(BaseHandler): def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) + self.state = hs.get_state_handler() + self.appservice_handler = hs.get_application_service_handler() + self.federation = hs.get_replication_layer() self.federation.register_query_handler( "directory", self.on_directory_query @@ -52,7 +55,8 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) if not servers: raise SynapseError(400, "Failed to get server list") @@ -93,7 +97,7 @@ class DirectoryHandler(BaseHandler): yield self._create_association(room_alias, room_id, servers) @defer.inlineCallbacks - def delete_association(self, user_id, room_alias): + def delete_association(self, requester, user_id, room_alias): # association deletion for human users can_delete = yield self._user_can_delete_alias(room_alias, user_id) @@ -112,7 +116,25 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) - yield self._delete_association(room_alias) + room_id = yield self._delete_association(room_alias) + + try: + yield self.send_room_alias_update_event( + requester, + requester.user.to_string(), + room_id + ) + + yield self._update_canonical_alias( + requester, + requester.user.to_string(), + room_id, + room_alias, + ) + except AuthError as e: + logger.info("Failed to update alias events: %s", e) + + defer.returnValue(room_id) @defer.inlineCallbacks def delete_appservice_association(self, service, room_alias): @@ -129,11 +151,9 @@ class DirectoryHandler(BaseHandler): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") - yield self.store.delete_room_alias(room_alias) + room_id = yield self.store.delete_room_alias(room_alias) - # TODO - Looks like _update_room_alias_event has never been implemented - # if room_id: - # yield self._update_room_alias_events(user_id, room_id) + defer.returnValue(room_id) @defer.inlineCallbacks def get_association(self, room_alias): @@ -174,7 +194,8 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND ) - extra_servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + extra_servers = set(get_domain_from_id(u) for u in users) servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. @@ -234,23 +255,45 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks + def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + alias_event = yield self.state.get_current_state( + room_id, EventTypes.CanonicalAlias, "" + ) + + alias_str = room_alias.to_string() + if not alias_event or alias_event.content.get("alias", "") != alias_str: + return + + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": room_id, + "sender": user_id, + "content": {}, + }, + ratelimit=False + ) + + @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): result = yield self.store.get_association_from_room_alias( room_alias ) if not result: # Query AS to see if it exists - as_handler = self.hs.get_handlers().appservice_handler + as_handler = self.appservice_handler result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) - @defer.inlineCallbacks def can_modify_alias(self, alias, user_id=None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have # non-exclusive locks on the alias (or there are no interested services) - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_alias(alias.to_string()) ] @@ -258,14 +301,12 @@ class DirectoryHandler(BaseHandler): for service in interested_services: if user_id == service.sender: # this user IS the app service so they can do whatever they like - defer.returnValue(True) - return + return defer.succeed(True) elif service.is_exclusive_alias(alias.to_string()): # another service has an exclusive lock on this alias. - defer.returnValue(False) - return + return defer.succeed(False) # either no interested services, or no service with an exclusive lock - defer.returnValue(True) + return defer.succeed(True) @defer.inlineCallbacks def _user_can_delete_alias(self, alias, user_id): @@ -276,3 +317,25 @@ class DirectoryHandler(BaseHandler): is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) defer.returnValue(is_admin) + + @defer.inlineCallbacks + def edit_published_room_list(self, requester, room_id, visibility): + """Edit the entry of the room in the published room list. + + requester + room_id (str) + visibility (str): "public" or "private" + """ + if requester.is_guest: + raise AuthError(403, "Guests cannot edit the published room list") + + if visibility not in ["public", "private"]: + raise SynapseError(400, "Invalid visibility setting") + + room = yield self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Unknown room") + + yield self.auth.check_can_change_room_list(room_id, requester.user) + + yield self.store.set_room_is_public(room_id, visibility == "public") diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py new file mode 100644 index 0000000000..fd11935b40 --- /dev/null +++ b/synapse/handlers/e2e_keys.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ujson as json +import logging + +from canonicaljson import encode_canonical_json +from twisted.internet import defer + +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination + +logger = logging.getLogger(__name__) + + +class E2eKeysHandler(object): + def __init__(self, hs): + self.store = hs.get_datastore() + self.federation = hs.get_replication_layer() + self.device_handler = hs.get_device_handler() + self.is_mine_id = hs.is_mine_id + self.clock = hs.get_clock() + + # doesn't really work as part of the generic query API, because the + # query request requires an object POST, but we abuse the + # "query handler" interface. + self.federation.register_query_handler( + "client_keys", self.on_federation_query_client_keys + ) + + @defer.inlineCallbacks + def query_devices(self, query_body, timeout): + """ Handle a device key query from a client + + { + "device_keys": { + "<user_id>": ["<device_id>"] + } + } + -> + { + "device_keys": { + "<user_id>": { + "<device_id>": { + ... + } + } + } + } + """ + device_keys_query = query_body.get("device_keys", {}) + + # separate users by domain. + # make a map from domain to user_id to device_ids + local_query = {} + remote_queries = {} + + for user_id, device_ids in device_keys_query.items(): + if self.is_mine_id(user_id): + local_query[user_id] = device_ids + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_ids + + # do the queries + failures = {} + results = {} + if local_query: + local_result = yield self.query_local_devices(local_query) + for user_id, keys in local_result.items(): + if user_id in local_query: + results[user_id] = keys + + @defer.inlineCallbacks + def do_remote_query(destination): + destination_query = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout + ) + + for user_id, keys in remote_result["device_keys"].items(): + if user_id in destination_query: + results[user_id] = keys + + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(do_remote_query)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "device_keys": results, "failures": failures, + }) + + @defer.inlineCallbacks + def query_local_devices(self, query): + """Get E2E device keys for local users + + Args: + query (dict[string, list[string]|None): map from user_id to a list + of devices to query (None for all devices) + + Returns: + defer.Deferred: (resolves to dict[string, dict[string, dict]]): + map from user_id -> device_id -> device details + """ + local_query = [] + + result_dict = {} + for user_id, device_ids in query.items(): + if not self.is_mine_id(user_id): + logger.warning("Request for keys for non-local user %s", + user_id) + raise SynapseError(400, "Not a user here") + + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + + # make sure that each queried user appears in the result dict + result_dict[user_id] = {} + + results = yield self.store.get_e2e_device_keys(local_query) + + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section + for user_id, device_keys in results.items(): + for device_id, device_info in device_keys.items(): + r = json.loads(device_info["key_json"]) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + result_dict[user_id][device_id] = r + + defer.returnValue(result_dict) + + @defer.inlineCallbacks + def on_federation_query_client_keys(self, query_body): + """ Handle a device key query from a federated server + """ + device_keys_query = query_body.get("device_keys", {}) + res = yield self.query_local_devices(device_keys_query) + defer.returnValue({"device_keys": res}) + + @defer.inlineCallbacks + def claim_one_time_keys(self, query, timeout): + local_query = [] + remote_queries = {} + + for user_id, device_keys in query.get("one_time_keys", {}).items(): + if self.is_mine_id(user_id): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_keys + + results = yield self.store.claim_e2e_one_time_keys(local_query) + + json_result = {} + failures = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + + @defer.inlineCallbacks + def claim_client_keys(destination): + device_keys = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(claim_client_keys)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "one_time_keys": json_result, + "failures": failures + }) + + @defer.inlineCallbacks + def upload_keys_for_user(self, user_id, device_id, keys): + time_now = self.clock.time_msec() + + # TODO: Validate the JSON to make sure it has the right keys. + device_keys = keys.get("device_keys", None) + if device_keys: + logger.info( + "Updating device_keys for device %r for user %s at %d", + device_id, user_id, time_now + ) + # TODO: Sign the JSON with the server key + yield self.store.set_e2e_device_keys( + user_id, device_id, time_now, + encode_canonical_json(device_keys) + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + logger.info( + "Adding %d one_time_keys for device %r for user %r at %d", + len(one_time_keys), device_id, user_id, time_now + ) + key_list = [] + for key_id, key_json in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, encode_canonical_json(key_json) + )) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, key_list + ) + + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered(user_id, device_id) + + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + + defer.returnValue({"one_time_key_counts": result}) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f25a252523..d3685fb12a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -58,7 +59,7 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) - presence_handler = self.hs.get_handlers().presence_handler + presence_handler = self.hs.get_presence_handler() context = yield presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence, @@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.store.get_users_in_room(event.room_id) + users = yield self.state.get_current_user_in_room(event.room_id) states = yield presence_handler.get_states( users, as_event=True, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f599e817aa..771ab3bc43 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -26,20 +26,24 @@ from synapse.api.errors import ( from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred +) +from synapse.util.metrics import measure_func from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( compute_event_signature, add_hashes_and_signatures, ) -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination from synapse.push.action_generator import ActionGenerator +from synapse.util.distributor import user_joined_room from twisted.internet import defer @@ -49,10 +53,6 @@ import logging logger = logging.getLogger(__name__) -def user_joined_room(distributor, user, room_id): - return distributor.fire("user_joined_room", user, room_id) - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: @@ -69,10 +69,6 @@ class FederationHandler(BaseHandler): self.hs = hs - self.distributor.observe("user_joined_room", self.user_joined_room) - - self.waiting_for_join_list = {} - self.store = hs.get_datastore() self.replication_layer = hs.get_replication_layer() self.state_handler = hs.get_state_handler() @@ -84,28 +80,14 @@ class FederationHandler(BaseHandler): # When joining a room we need to queue any events for that room up self.room_queues = {} - def handle_new_event(self, event, destinations): - """ Takes in an event from the client to server side, that has already - been authed and handled by the state module, and sends it to any - remote home servers that may be interested. - - Args: - event: The event to send - destinations: A list of destinations to send it to - - Returns: - Deferred: Resolved when it has successfully been queued for - processing. - """ - - return self.replication_layer.send_pdu(event, destinations) - @log_function @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, backfilled, state=None, - auth_chain=None): + def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. + + auth_chain and state are None if we already have the necessary state + and prev_events in the db """ event = pdu @@ -123,17 +105,25 @@ class FederationHandler(BaseHandler): # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work - current_state = None - is_in_room = yield self.auth.check_host_in_room( - event.room_id, - self.server_name - ) - if not is_in_room and not event.internal_metadata.is_outlier(): - logger.debug("Got event for room we're not in.") + # If state and auth_chain are None, then we don't need to do this check + # as we already know we have enough state in the DB to handle this + # event. + if state and auth_chain and not event.internal_metadata.is_outlier(): + is_in_room = yield self.auth.check_host_in_room( + event.room_id, + self.server_name + ) + else: + is_in_room = True + if not is_in_room: + logger.info( + "Got event for room we're not in: %r %r", + event.room_id, event.event_id + ) try: event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) except AuthError as e: raise FederationError( @@ -175,19 +165,13 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) - yield self._handle_new_events( - origin, - event_infos, - outliers=True - ) + yield self._handle_new_events(origin, event_infos) try: context, event_stream_id, max_stream_id = yield self._handle_new_event( origin, event, state=state, - backfilled=backfilled, - current_state=current_state, ) except AuthError as e: raise FederationError( @@ -216,32 +200,42 @@ class FederationHandler(BaseHandler): except StoreError: logger.exception("Failed to store room.") - if not backfilled: - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + with PreserveLoggingContext(): + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - prev_state = context.current_state.get((event.type, event.state_key)) - if not prev_state or prev_state.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally - # joined the room. Don't bother if the user is just - # changing their profile info. + # Only fire user_joined_room if the user has acutally + # joined the room. Don't bother if the user is just + # changing their profile info. + newly_joined = True + prev_state_id = context.prev_state_ids.get( + (event.type, event.state_key) + ) + if prev_state_id: + prev_state = yield self.store.get_event( + prev_state_id, allow_none=True, + ) + if prev_state and prev_state.membership == Membership.JOIN: + newly_joined = False + + if newly_joined: user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) + @measure_func("_filter_events_for_server") @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): - event_to_state = yield self.store.get_state_for_events( + event_to_state_ids = yield self.store.get_state_ids_for_events( frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), @@ -249,6 +243,30 @@ class FederationHandler(BaseHandler): ) ) + # We only want to pull out member events that correspond to the + # server's domain. + + def check_match(id): + try: + return server_name == get_domain_from_id(id) + except: + return False + + event_map = yield self.store.get_events([ + e_id for key_to_eid in event_to_state_ids.values() + for key, e_id in key_to_eid + if key[0] != EventTypes.Member or check_match(key[1]) + ]) + + event_to_state = { + e_id: { + key: event_map[inner_e_id] + for key, inner_e_id in key_to_eid.items() + if inner_e_id in event_map + } + for e_id, key_to_eid in event_to_state_ids.items() + } + def redact_disallowed(event, state): if not state: return event @@ -265,7 +283,7 @@ class FederationHandler(BaseHandler): if ev.type != EventTypes.Member: continue try: - domain = UserID.from_string(ev.state_key).domain + domain = get_domain_from_id(ev.state_key) except: continue @@ -290,11 +308,15 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities=[]): + def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` + + This will attempt to get more events from the remote. This may return + be successfull and still return no events if the other side has no new + events to offer. """ - if not extremities: - extremities = yield self.store.get_oldest_events_in_room(room_id) + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") events = yield self.replication_layer.backfill( dest, @@ -303,6 +325,16 @@ class FederationHandler(BaseHandler): extremities=extremities, ) + # Don't bother processing events we already have. + seen_events = yield self.store.have_events_in_timeline( + set(e.event_id for e in events) + ) + + events = [e for e in events if e.event_id not in seen_events] + + if not events: + defer.returnValue([]) + event_map = {e.event_id: e for e in events} event_ids = set(e.event_id for e in events) @@ -334,40 +366,73 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - seen_events = yield self.store.have_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - all_events = events + state_events.values() + auth_events.values() required_auth = set( - a_id for event in all_events for a_id, _ in event.auth_events + a_id + for event in events + state_events.values() + auth_events.values() + for a_id, _ in event.auth_events ) - + auth_events.update({ + e_id: event_map[e_id] for e_id in required_auth if e_id in event_map + }) missing_auth = required_auth - set(auth_events) - results = yield defer.gatherResults( - [ - self.replication_layer.get_pdu( - [dest], - event_id, - outlier=True, - timeout=10000, + failed_to_fetch = set() + + # Try and fetch any missing auth events from both DB and remote servers. + # We repeatedly do this until we stop finding new auth events. + while missing_auth - failed_to_fetch: + logger.info("Missing auth for backfill: %r", missing_auth) + ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + auth_events.update(ret_events) + + required_auth.update( + a_id for event in ret_events.values() for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + if missing_auth - failed_to_fetch: + logger.info( + "Fetching missing auth for backfill: %r", + missing_auth - failed_to_fetch ) - for event_id in missing_auth - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + + results = yield preserve_context_over_deferred(defer.gatherResults( + [ + preserve_fn(self.replication_layer.get_pdu)( + [dest], + event_id, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True + )).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results if a}) + required_auth.update( + a_id + for event in results if event + for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + failed_to_fetch = missing_auth - set(auth_events) + + seen_events = yield self.store.have_events( + set(auth_events.keys()) | set(state_events.keys()) + ) ev_infos = [] for a in auth_events.values(): if a.event_id in seen_events: continue + a.internal_metadata.outlier = True ev_infos.append({ "event": a, "auth_events": { (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in a.auth_events + if a_id in auth_events } }) @@ -379,23 +444,27 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in event_map[e_id].auth_events + if a_id in auth_events } }) + yield self._handle_new_events( + dest, ev_infos, + backfilled=True, + ) + events.sort(key=lambda e: e.depth) for event in events: if event in events_to_state: continue - ev_infos.append({ - "event": event, - }) - - yield self._handle_new_events( - dest, ev_infos, - backfilled=True, - ) + # We store these one at a time since each event depends on the + # previous to work out the state. + # TODO: We can probably do something more clever here. + yield self._handle_new_event( + dest, event, backfilled=True, + ) defer.returnValue(events) @@ -419,6 +488,10 @@ class FederationHandler(BaseHandler): ) max_depth = sorted_extremeties_tuple[0][1] + # We don't want to specify too many extremities as it causes the backfill + # request URI to be too long. + extremities = dict(sorted_extremeties_tuple[:5]) + if current_depth > max_depth: logger.debug( "Not backfilling as we don't need to. %d < %d", @@ -444,7 +517,7 @@ class FederationHandler(BaseHandler): joined_domains = {} for u, d in joined_users: try: - dom = UserID.from_string(u).domain + dom = get_domain_from_id(u) old_d = joined_domains.get(dom) if old_d: joined_domains[dom] = min(d, old_d) @@ -459,7 +532,7 @@ class FederationHandler(BaseHandler): likely_domains = [ domain for domain, depth in curr_domains - if domain is not self.server_name + if domain != self.server_name ] @defer.inlineCallbacks @@ -467,11 +540,15 @@ class FederationHandler(BaseHandler): # TODO: Should we try multiple of these at a time? for dom in domains: try: - events = yield self.backfill( + yield self.backfill( dom, room_id, limit=100, extremities=[e for e in extremities.keys()] ) + # If this succeeded then we probably already have the + # appropriate stuff. + # TODO: We can probably do something more intelligent here. + defer.returnValue(True) except SynapseError as e: logger.info( "Failed to backfill from %s because %s", @@ -497,8 +574,6 @@ class FederationHandler(BaseHandler): ) continue - if events: - defer.returnValue(True) defer.returnValue(False) success = yield try_backfill(likely_domains) @@ -513,12 +588,24 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) - states = yield defer.gatherResults([ - self.state_handler.resolve_state_groups(room_id, [e]) + states = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids - ]) + ])) states = dict(zip(event_ids, [s[1] for s in states])) + state_map = yield self.store.get_events( + [e_id for ids in states.values() for e_id in ids], + get_prev_content=False + ) + states = { + key: { + k: state_map[e_id] + for k, e_id in state_dict.items() + if e_id in state_map + } for key, state_dict in states.items() + } + for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) @@ -628,7 +715,7 @@ class FederationHandler(BaseHandler): pass event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) with PreserveLoggingContext(): @@ -647,7 +734,7 @@ class FederationHandler(BaseHandler): continue try: - self.on_receive_pdu(origin, p, backfilled=False) + self.on_receive_pdu(origin, p) except: logger.exception("Couldn't handle pdu") @@ -670,11 +757,18 @@ class FederationHandler(BaseHandler): "state_key": user_id, }) - event, context = yield self._create_new_client_event( - builder=builder, - ) + try: + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( + builder=builder, + ) + except AuthError as e: + logger.warn("Failed to create join %r because %s", event, e) + raise e - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_join_request` + yield self.auth.check_from_context(event, context, do_sig_check=False) defer.returnValue(event) @@ -720,39 +814,15 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - new_pdu = event - - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add( - UserID.from_string(s.state_key).domain - ) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - - destinations.discard(origin) - - logger.debug( - "on_send_join_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - - state_ids = [e.event_id for e in context.current_state.values()] + state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) + state = yield self.store.get_events(context.prev_state_ids.values()) + defer.returnValue({ - "state": context.current_state.values(), + "state": state.values(), "auth_chain": auth_chain, }) @@ -765,6 +835,7 @@ class FederationHandler(BaseHandler): event = pdu event.internal_metadata.outlier = True + event.internal_metadata.invite_from_remote = True event.signatures.update( compute_event_signature( @@ -779,7 +850,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, - backfilled=False, ) target_user = UserID.from_string(event.state_key) @@ -793,13 +863,19 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - origin, event = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" - ) - signed_event = self._sign_event(event) + try: + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave" + ) + signed_event = self._sign_event(event) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") # Try the host we successfully got a response to /make_join/ # request first. @@ -809,17 +885,22 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.replication_layer.send_leave( - target_hosts, - signed_event - ) + try: + yield self.replication_layer.send_leave( + target_hosts, + signed_event + ) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") context = yield self.state_handler.compute_event_context(event) event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, - backfilled=False, ) target_user = UserID.from_string(event.state_key) @@ -889,11 +970,18 @@ class FederationHandler(BaseHandler): "state_key": user_id, }) - event, context = yield self._create_new_client_event( + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( builder=builder, ) - self.auth.check(event, auth_events=context.current_state) + try: + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_leave_request` + yield self.auth.check_from_context(event, context, do_sig_check=False) + except AuthError as e: + logger.warn("Failed to create new leave %r because %s", event, e) + raise e defer.returnValue(event) @@ -932,43 +1020,14 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - new_pdu = event - - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.LEAVE: - destinations.add( - UserID.from_string(s.state_key).domain - ) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - - destinations.discard(origin) - - logger.debug( - "on_send_leave_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - defer.returnValue(None) @defer.inlineCallbacks - def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): + def get_state_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ yield run_on_reactor() - if do_auth: - in_room = yield self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -992,19 +1051,50 @@ class FederationHandler(BaseHandler): res = results.values() for event in res: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) defer.returnValue(res) else: defer.returnValue([]) @defer.inlineCallbacks + def get_state_ids_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ + yield run_on_reactor() + + state_groups = yield self.store.get_state_groups_ids( + room_id, [event_id] + ) + + if state_groups: + _, state = state_groups.items().pop() + results = state + + event = yield self.store.get_event(event_id) + if event and event.is_state(): + # Get previous state + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + if prev_id != event.event_id: + results[(event.type, event.state_key)] = prev_id + else: + del results[(event.type, event.state_key)] + + defer.returnValue(results.values()) + else: + defer.returnValue([]) + + @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, pdu_list, limit): in_room = yield self.auth.check_host_in_room(room_id, origin) @@ -1036,16 +1126,17 @@ class FederationHandler(BaseHandler): ) if event: - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + if self.hs.is_mine_id(event.event_id): + # FIXME: This is a temporary work around where we occasionally + # return events slightly differently than when they were + # originally signed + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) if do_auth: in_room = yield self.auth.check_host_in_room( @@ -1055,6 +1146,12 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") + events = yield self._filter_events_for_server( + origin, event.room_id, [event] + ) + + event = events[0] + defer.returnValue(event) else: defer.returnValue(None) @@ -1063,50 +1160,47 @@ class FederationHandler(BaseHandler): def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) - @log_function - def user_joined_room(self, user, room_id): - waiters = self.waiting_for_join_list.get( - (user.to_string(), room_id), - [] - ) - while waiters: - waiters.pop().callback(None) - @defer.inlineCallbacks @log_function - def _handle_new_event(self, origin, event, state=None, backfilled=False, - current_state=None, auth_events=None): - - outlier = event.internal_metadata.is_outlier() - + def _handle_new_event(self, origin, event, state=None, auth_events=None, + backfilled=False): context = yield self._prep_event( origin, event, state=state, auth_events=auth_events, ) - if not backfilled and not event.internal_metadata.is_outlier(): + if not event.internal_metadata.is_outlier(): action_generator = ActionGenerator(self.hs) yield action_generator.handle_push_actions_for_event( - event, context, self + event, context ) event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, backfilled=backfilled, - is_new_state=(not outlier and not backfilled), - current_state=current_state, ) + if not backfilled: + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) + defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False, - outliers=False): - contexts = yield defer.gatherResults( + def _handle_new_events(self, origin, event_infos, backfilled=False): + """Creates the appropriate contexts and persists events. The events + should not depend on one another, e.g. this should be used to persist + a bunch of outliers, but not a chunk of individual events that depend + on each other for state calculations. + """ + contexts = yield preserve_context_over_deferred(defer.gatherResults( [ - self._prep_event( + preserve_fn(self._prep_event)( origin, ev_info["event"], state=ev_info.get("state"), @@ -1114,7 +1208,7 @@ class FederationHandler(BaseHandler): ) for ev_info in event_infos ] - ) + )) yield self.store.persist_events( [ @@ -1122,30 +1216,35 @@ class FederationHandler(BaseHandler): for ev_info, context in itertools.izip(event_infos, contexts) ], backfilled=backfilled, - is_new_state=(not outliers and not backfilled), ) @defer.inlineCallbacks - def _persist_auth_tree(self, auth_events, state, event): + def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event seperately. + Will attempt to fetch missing auth events. + + Args: + origin (str): Where the events came from + auth_events (list) + state (list) + event (Event) + Returns: 2-tuple of (event_stream_id, max_stream_id) from the persist_event call for `event` """ events_to_context = {} for e in itertools.chain(auth_events, state): - ctx = yield self.state_handler.compute_event_context( - e, outlier=True, - ) - events_to_context[e.event_id] = ctx e.internal_metadata.outlier = True + ctx = yield self.state_handler.compute_event_context(e) + events_to_context[e.event_id] = ctx event_map = { e.event_id: e - for e in auth_events + for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1154,10 +1253,29 @@ class FederationHandler(BaseHandler): create_event = e break + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id, _ in e.auth_events: + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = yield self.replication_layer.get_pdu( + [origin], + e_id, + outlier=True, + timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] for e_id, _ in e.auth_events + if e_id in event_map } if create_event: auth_for_e[(EventTypes.Create, "")] = create_event @@ -1185,17 +1303,14 @@ class FederationHandler(BaseHandler): (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) ], - is_new_state=False, ) new_event_context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=False, + event, old_state=state ) event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - backfilled=False, - is_new_state=True, current_state=state, ) @@ -1203,14 +1318,19 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, auth_events=None): - outlier = event.internal_metadata.is_outlier() context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=outlier, + event, old_state=state, ) if not auth_events: - auth_events = context.current_state + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -1236,8 +1356,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR if event.type == EventTypes.GuestAccess: - full_context = yield self.store.get_current_state(room_id=event.room_id) - yield self.maybe_kick_guest_users(event, full_context) + yield self.maybe_kick_guest_users(event) defer.returnValue(context) @@ -1305,6 +1424,11 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None + if event_auth_events - current_state: have_events = yield self.store.have_events( event_auth_events - current_state @@ -1378,9 +1502,9 @@ class FederationHandler(BaseHandler): # Do auth conflict res. logger.info("Different auth: %s", different_auth) - different_events = yield defer.gatherResults( + different_events = yield preserve_context_over_deferred(defer.gatherResults( [ - self.store.get_event( + preserve_fn(self.store.get_event)( d, allow_none=True, allow_rejected=False, @@ -1389,13 +1513,13 @@ class FederationHandler(BaseHandler): if d in have_events and not have_events[d] ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) remote_view = dict(auth_events) remote_view.update({ - (d.type, d.state_key): d for d in different_events + (d.type, d.state_key): d for d in different_events if d }) new_state, prev_state = self.state_handler.resolve_events( @@ -1408,8 +1532,16 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids = dict(context.current_state_ids) + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids = dict(context.prev_state_ids) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1430,8 +1562,8 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. - auth_ids = self.auth.compute_auth_events( - event, context.current_state + auth_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1487,13 +1619,22 @@ class FederationHandler(BaseHandler): # 4. Look at rejects and their proofs. # TODO. - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids = dict(context.current_state_ids) + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids = dict(context.prev_state_ids) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() try: self.auth.check(event, auth_events=auth_events) - except AuthError: - raise + except AuthError as e: + logger.warn("Failed auth resolution for %r because %s", event, e) + raise e @defer.inlineCallbacks def construct_auth_difference(self, local_auth, remote_auth): @@ -1663,14 +1804,22 @@ class FederationHandler(BaseHandler): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - event, context = yield self._create_new_client_event(builder=builder) + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( + builder=builder + ) event, context = yield self.add_display_name_to_third_party_invite( event_dict, event, context ) - self.auth.check(event, context.current_state) - yield self._check_signature(event, auth_events=context.current_state) + try: + yield self.auth.check_from_context(event, context) + except AuthError as e: + logger.warn("Denying new third party invite %r because %s", event, e) + raise e + + yield self._check_signature(event, context) member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event(None, event, context) else: @@ -1686,7 +1835,8 @@ class FederationHandler(BaseHandler): def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): builder = self.event_builder_factory.new(event_dict) - event, context = yield self._create_new_client_event( + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( builder=builder, ) @@ -1694,8 +1844,12 @@ class FederationHandler(BaseHandler): event_dict, event, context ) - self.auth.check(event, auth_events=context.current_state) - yield self._check_signature(event, auth_events=context.current_state) + try: + self.auth.check_from_context(event, context) + except AuthError as e: + logger.warn("Denying third party invite %r because %s", event, e) + raise e + yield self._check_signature(event, context) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. @@ -1709,41 +1863,56 @@ class FederationHandler(BaseHandler): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - original_invite = context.current_state.get(key) - if not original_invite: + original_invite = None + original_invite_id = context.prev_state_ids.get(key) + if original_invite_id: + original_invite = yield self.store.get_event( + original_invite_id, allow_none=True + ) + if original_invite: + display_name = original_invite.content["display_name"] + event_dict["content"]["third_party_invite"]["display_name"] = display_name + else: logger.info( - "Could not find invite event for third_party_invite - " - "discarding: %s" % (event_dict,) + "Could not find invite event for third_party_invite: %r", + event_dict ) - return + # We don't discard here as this is not the appropriate place to do + # auth checks. If we need the invite and don't have it then the + # auth check code will explode appropriately. - display_name = original_invite.content["display_name"] - event_dict["content"]["third_party_invite"]["display_name"] = display_name builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - event, context = yield self._create_new_client_event(builder=builder) + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event(builder=builder) defer.returnValue((event, context)) @defer.inlineCallbacks - def _check_signature(self, event, auth_events): + def _check_signature(self, event, context): """ Checks that the signature in the event is consistent with its invite. - :param event (Event): The m.room.member event to check - :param auth_events (dict<(event type, state_key), event>) - :raises - AuthError if signature didn't match any keys, or key has been + Args: + event (Event): The m.room.member event to check + context (EventContext): + + Raises: + AuthError: if signature didn't match any keys, or key has been revoked, - SynapseError if a transient error meant a key couldn't be checked + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event = auth_events.get( + invite_event_id = context.prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) + invite_event = None + if invite_event_id: + invite_event = yield self.store.get_event(invite_event_id, allow_none=True) + if not invite_event: raise AuthError(403, "Could not find invite") @@ -1776,12 +1945,13 @@ class FederationHandler(BaseHandler): """ Checks whether public_key has been revoked. - :param public_key (str): base-64 encoded public key. - :param url (str): Key revocation URL. + Args: + public_key (str): base-64 encoded public key. + url (str): Key revocation URL. - :raises - AuthError if they key has been revoked. - SynapseError if a transient error meant a key couldn't be checked + Raises: + AuthError: if they key has been revoked. + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ try: diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 656ce124f9..559e5d5a71 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,7 +21,7 @@ from synapse.api.errors import ( ) from ._base import BaseHandler from synapse.util.async import run_on_reactor -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes import json import logging @@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler): hs.config.use_insecure_ssl_client_just_for_testing_do_not_use ) + def _should_trust_id_server(self, id_server): + if id_server not in self.trusted_id_servers: + if self.trust_any_id_server_just_for_testing_do_not_use: + logger.warn( + "Trusting untrustworthy ID server %r even though it isn't" + " in the trusted id list for testing because" + " 'use_insecure_ssl_client_just_for_testing_do_not_use'" + " is set in the config", + id_server, + ) + else: + return False + return True + @defer.inlineCallbacks def threepid_from_creds(self, creds): yield run_on_reactor() @@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler): else: raise SynapseError(400, "No client_secret in creds") - if id_server not in self.trusted_id_servers: - if self.trust_any_id_server_just_for_testing_do_not_use: - logger.warn( - "Trusting untrustworthy ID server %r even though it isn't" - " in the trusted id list for testing because" - " 'use_insecure_ssl_client_just_for_testing_do_not_use'" - " is set in the config", - id_server, - ) - else: - logger.warn('%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server) - defer.returnValue(None) + if not self._should_trust_id_server(id_server): + logger.warn( + '%s is not a trusted ID server: rejecting 3pid ' + + 'credentials', id_server + ) + defer.returnValue(None) data = {} try: @@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler): def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): yield run_on_reactor() + if not self._should_trust_id_server(id_server): + raise SynapseError( + 400, "Untrusted ID server '%s'" % id_server, + Codes.SERVER_NOT_TRUSTED + ) + params = { 'email': email, 'client_secret': client_secret, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py new file mode 100644 index 0000000000..e0ade4c164 --- /dev/null +++ b/synapse/handlers/initial_sync.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError, Codes +from synapse.events.utils import serialize_event +from synapse.events.validator import EventValidator +from synapse.streams.config import PaginationConfig +from synapse.types import ( + UserID, StreamToken, +) +from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute +from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.visibility import filter_events_for_client + +from ._base import BaseHandler + +import logging + + +logger = logging.getLogger(__name__) + + +class InitialSyncHandler(BaseHandler): + def __init__(self, hs): + super(InitialSyncHandler, self).__init__(hs) + self.hs = hs + self.state = hs.get_state_handler() + self.clock = hs.get_clock() + self.validator = EventValidator() + self.snapshot_cache = SnapshotCache() + + def snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + """Retrieve a snapshot of all rooms the user is invited or has joined. + + This snapshot may include messages for all rooms where the user is + joined, depending on the pagination config. + + Args: + user_id (str): The ID of the user making the request. + pagin_config (synapse.api.streams.PaginationConfig): The pagination + config used to determine how many messages *PER ROOM* to return. + as_client_event (bool): True to get events in client-server format. + include_archived (bool): True to get rooms that the user has left + Returns: + A list of dicts with "room_id" and "membership" keys for all rooms + the user is currently invited or joined in on. Rooms where the user + is joined on, may return a "messages" key with messages, depending + on the specified PaginationConfig. + """ + key = ( + user_id, + pagin_config.from_token, + pagin_config.to_token, + pagin_config.direction, + pagin_config.limit, + as_client_event, + include_archived, + ) + now_ms = self.clock.time_msec() + result = self.snapshot_cache.get(now_ms, key) + if result is not None: + return result + + return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + )) + + @defer.inlineCallbacks + def _snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + + memberships = [Membership.INVITE, Membership.JOIN] + if include_archived: + memberships.append(Membership.LEAVE) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, membership_list=memberships + ) + + user = UserID.from_string(user_id) + + rooms_ret = [] + + now_token = yield self.hs.get_event_sources().get_current_token() + + presence_stream = self.hs.get_event_sources().sources["presence"] + pagination_config = PaginationConfig(from_token=now_token) + presence, _ = yield presence_stream.get_pagination_rows( + user, pagination_config.get_source_config("presence"), None + ) + + receipt_stream = self.hs.get_event_sources().sources["receipt"] + receipt, _ = yield receipt_stream.get_pagination_rows( + user, pagination_config.get_source_config("receipt"), None + ) + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user(user_id) + ) + + public_room_ids = yield self.store.get_public_room_ids() + + limit = pagin_config.limit + if limit is None: + limit = 10 + + @defer.inlineCallbacks + def handle_room(event): + d = { + "room_id": event.room_id, + "membership": event.membership, + "visibility": ( + "public" if event.room_id in public_room_ids + else "private" + ), + } + + if event.membership == Membership.INVITE: + time_now = self.clock.time_msec() + d["inviter"] = event.sender + + invite_event = yield self.store.get_event(event.event_id) + d["invite"] = serialize_event(invite_event, time_now, as_client_event) + + rooms_ret.append(d) + + if event.membership not in (Membership.JOIN, Membership.LEAVE): + return + + try: + if event.membership == Membership.JOIN: + room_end_token = now_token.room_key + deferred_room_state = self.state_handler.get_current_state( + event.room_id + ) + elif event.membership == Membership.LEAVE: + room_end_token = "s%d" % (event.stream_ordering,) + deferred_room_state = self.store.get_state_for_events( + [event.event_id], None + ) + deferred_room_state.addCallback( + lambda states: states[event.event_id] + ) + + (messages, token), current_state = yield preserve_context_over_deferred( + defer.gatherResults( + [ + preserve_fn(self.store.get_recent_events_for_room)( + event.room_id, + limit=limit, + end_token=room_end_token, + ), + deferred_room_state, + ] + ) + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + time_now = self.clock.time_msec() + + d["messages"] = { + "chunk": [ + serialize_event(m, time_now, as_client_event) + for m in messages + ], + "start": start_token.to_string(), + "end": end_token.to_string(), + } + + d["state"] = [ + serialize_event(c, time_now, as_client_event) + for c in current_state.values() + ] + + account_data_events = [] + tags = tags_by_room.get(event.room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = account_data_by_room.get(event.room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + d["account_data"] = account_data_events + except: + logger.exception("Failed to get snapshot") + + yield concurrently_execute(handle_room, room_list, 10) + + account_data_events = [] + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + ret = { + "rooms": rooms_ret, + "presence": presence, + "account_data": account_data_events, + "receipts": receipt, + "end": now_token.to_string(), + } + + defer.returnValue(ret) + + @defer.inlineCallbacks + def room_initial_sync(self, requester, room_id, pagin_config=None): + """Capture the a snapshot of a room. If user is currently a member of + the room this will be what is currently in the room. If the user left + the room this will be what was in the room when they left. + + Args: + requester(Requester): The user to get a snapshot for. + room_id(str): The room to get a snapshot of. + pagin_config(synapse.streams.config.PaginationConfig): + The pagination config used to determine how many messages to + return. + Raises: + AuthError if the user wasn't in the room. + Returns: + A JSON serialisable dict with the snapshot of the room. + """ + + user_id = requester.user.to_string() + + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, + ) + is_peeking = member_event_id is None + + if membership == Membership.JOIN: + result = yield self._room_initial_sync_joined( + user_id, room_id, pagin_config, membership, is_peeking + ) + elif membership == Membership.LEAVE: + result = yield self._room_initial_sync_parted( + user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ) + + account_data_events = [] + tags = yield self.store.get_tags_for_room(user_id, room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + result["account_data"] = account_data_events + + defer.returnValue(result) + + @defer.inlineCallbacks + def _room_initial_sync_parted(self, user_id, room_id, pagin_config, + membership, member_event_id, is_peeking): + room_state = yield self.store.get_state_for_events( + [member_event_id], None + ) + + room_state = room_state[member_event_id] + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + stream_token = yield self.store.get_stream_token_for_event( + member_event_id + ) + + messages, token = yield self.store.get_recent_events_for_room( + room_id, + limit=limit, + end_token=stream_token + ) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking + ) + + start_token = StreamToken.START.copy_and_replace("room_key", token[0]) + end_token = StreamToken.START.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + defer.returnValue({ + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": [serialize_event(s, time_now) for s in room_state.values()], + "presence": [], + "receipts": [], + }) + + @defer.inlineCallbacks + def _room_initial_sync_joined(self, user_id, room_id, pagin_config, + membership, is_peeking): + current_state = yield self.state.get_current_state( + room_id=room_id, + ) + + # TODO: These concurrently + time_now = self.clock.time_msec() + state = [ + serialize_event(x, time_now) + for x in current_state.values() + ] + + now_token = yield self.hs.get_event_sources().get_current_token() + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + room_members = [ + m for m in current_state.values() + if m.type == EventTypes.Member + and m.content["membership"] == Membership.JOIN + ] + + presence_handler = self.hs.get_presence_handler() + + @defer.inlineCallbacks + def get_presence(): + states = yield presence_handler.get_states( + [m.user_id for m in room_members], + as_event=True, + ) + + defer.returnValue(states) + + @defer.inlineCallbacks + def get_receipts(): + receipts = yield self.store.get_linearized_receipts_for_room( + room_id, + to_key=now_token.receipt_key, + ) + if not receipts: + receipts = [] + defer.returnValue(receipts) + + presence, receipts, (messages, token) = yield defer.gatherResults( + [ + preserve_fn(get_presence)(), + preserve_fn(get_receipts)(), + preserve_fn(self.store.get_recent_events_for_room)( + room_id, + limit=limit, + end_token=now_token.room_key, + ) + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking, + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + ret = { + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": state, + "presence": presence, + "receipts": receipts, + } + if not is_peeking: + ret["membership"] = membership + + defer.returnValue(ret) + + @defer.inlineCallbacks + def _check_in_room_or_world_readable(self, room_id, user_id): + try: + # check_user_was_in_room will return the most recent membership + # event for the user if: + # * The user is a non-guest user, and was ever in the room + # * The user is a guest user, and has joined the room + # else it will throw. + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + return + except AuthError: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if ( + visibility and + visibility.content["history_visibility"] == "world_readable" + ): + defer.returnValue((Membership.JOIN, None)) + return + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5c50c611ba..fd09397226 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,27 +16,29 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.streams.config import PaginationConfig +from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator -from synapse.util import unwrapFirstError -from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.types import UserID, RoomStreamToken, StreamToken +from synapse.push.action_generator import ActionGenerator +from synapse.types import ( + UserID, RoomAlias, RoomStreamToken, +) +from synapse.util.async import run_on_reactor, ReadWriteLock +from synapse.util.logcontext import preserve_fn +from synapse.util.metrics import measure_func +from synapse.visibility import filter_events_for_client from ._base import BaseHandler from canonicaljson import encode_canonical_json import logging +import random logger = logging.getLogger(__name__) -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class MessageHandler(BaseHandler): def __init__(self, hs): @@ -45,40 +47,24 @@ class MessageHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() - @defer.inlineCallbacks - def get_message(self, msg_id=None, room_id=None, sender_id=None, - user_id=None): - """ Retrieve a message. + self.pagination_lock = ReadWriteLock() - Args: - msg_id (str): The message ID to obtain. - room_id (str): The room where the message resides. - sender_id (str): The user ID of the user who sent the message. - user_id (str): The user ID of the user making this request. - Returns: - The message, or None if no message exists. - Raises: - SynapseError if something went wrong. - """ - yield self.auth.check_joined_room(room_id, user_id) + @defer.inlineCallbacks + def purge_history(self, room_id, event_id): + event = yield self.store.get_event(event_id) - # Pull out the message from the db -# msg = yield self.store.get_message( -# room_id=room_id, -# msg_id=msg_id, -# user_id=sender_id -# ) + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") - # TODO (erikj): Once we work out the correct c-s api we need to think - # on how to do this. + depth = event.depth - defer.returnValue(None) + with (yield self.pagination_lock.write(room_id)): + yield self.store.delete_old_state(room_id, depth) @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, event_filter=None): """Get messages in a room. Args: @@ -87,18 +73,18 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None Returns: dict: Pagination API results """ user_id = requester.user.to_string() - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token( - direction='b' + yield self.hs.get_event_sources().get_current_token_for_room( + room_id=room_id ) ) room_token = pagin_config.from_token.room_key @@ -111,42 +97,48 @@ class MessageHandler(BaseHandler): source_config = pagin_config.get_source_config("room") - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) + with (yield self.pagination_lock.read(room_id)): + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id + ) - if source_config.direction == 'b': - # if we're going backwards, we might need to backfill. This - # requires that we have a topo token. - if room_token.topological: - max_topo = room_token.topological - else: - max_topo = yield self.store.get_max_topological_token_for_stream_and_room( - room_id, room_token.stream - ) + if source_config.direction == 'b': + # if we're going backwards, we might need to backfill. This + # requires that we have a topo token. + if room_token.topological: + max_topo = room_token.topological + else: + max_topo = yield self.store.get_max_topological_token( + room_id, room_token.stream + ) - if membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room, to save the effort of loading from the - # database. - leave_token = yield self.store.get_topological_token_for_event( - member_event_id + if membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + leave_token = yield self.store.get_topological_token_for_event( + member_event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < max_topo: + source_config.from_key = str(leave_token) + + yield self.hs.get_handlers().federation_handler.maybe_backfill( + room_id, max_topo ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < max_topo: - source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, max_topo + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id - ) - - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace( + "room_key", next_key + ) if not events: defer.returnValue({ @@ -155,7 +147,11 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - events = yield self._filter_events_for_client( + if event_filter: + events = event_filter.filter(events) + + events = yield filter_events_for_client( + self.store, user_id, events, is_peeking=(member_event_id is None), @@ -175,7 +171,7 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_event(self, event_dict, token_id=None, txn_id=None): + def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None): """ Given a dict from a client, create a new event. @@ -186,6 +182,9 @@ class MessageHandler(BaseHandler): Args: event_dict (dict): An entire event + token_id (str) + txn_id (str) + prev_event_ids (list): The prev event ids to use when creating the event Returns: Tuple of created event (FrozenEvent), Context @@ -198,12 +197,8 @@ class MessageHandler(BaseHandler): membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership == Membership.JOIN: + if membership in {Membership.JOIN, Membership.INVITE}: # If event doesn't include a display name, add one. - yield collect_presencelike_data( - self.distributor, target, builder.content - ) - elif membership == Membership.INVITE: profile = self.hs.get_handlers().profile_handler content = builder.content @@ -224,6 +219,7 @@ class MessageHandler(BaseHandler): event, context = yield self._create_new_client_event( builder=builder, + prev_event_ids=prev_event_ids, ) defer.returnValue((event, context)) @@ -244,12 +240,27 @@ class MessageHandler(BaseHandler): "Tried to send member event through non-member codepath" ) + # We check here if we are currently being rate limited, so that we + # don't do unnecessary work. We check again just before we actually + # send the event. + time_now = self.clock.time() + allowed, time_allowed = self.ratelimiter.send_message( + event.sender, time_now, + msg_rate_hz=self.hs.config.rc_messages_per_second, + burst_count=self.hs.config.rc_message_burst_count, + update=False, + ) + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = self.deduplicate_state_event(event, context) + prev_state = yield self.deduplicate_state_event(event, context) if prev_state is not None: defer.returnValue(prev_state) @@ -261,9 +272,10 @@ class MessageHandler(BaseHandler): ) if event.type == EventTypes.Message: - presence = self.hs.get_handlers().presence_handler + presence = self.hs.get_presence_handler() yield presence.bump_presence_active_time(user) + @defer.inlineCallbacks def deduplicate_state_event(self, event, context): """ Checks whether event is in the latest resolved state in context. @@ -271,13 +283,17 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event = context.current_state.get((event.type, event.state_key)) + prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + if not prev_event: + return + if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: - return prev_event - return None + defer.returnValue(prev_event) + return @defer.inlineCallbacks def create_and_send_nonmember_event( @@ -388,378 +404,210 @@ class MessageHandler(BaseHandler): [serialize_event(c, now) for c in room_state.values()] ) - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - """Retrieve a snapshot of all rooms the user is invited or has joined. - - This snapshot may include messages for all rooms where the user is - joined, depending on the pagination config. - - Args: - user_id (str): The ID of the user making the request. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config used to determine how many messages *PER ROOM* to return. - as_client_event (bool): True to get events in client-server format. - include_archived (bool): True to get rooms that the user has left - Returns: - A list of dicts with "room_id" and "membership" keys for all rooms - the user is currently invited or joined in on. Rooms where the user - is joined on, may return a "messages" key with messages, depending - on the specified PaginationConfig. - """ - key = ( - user_id, - pagin_config.from_token, - pagin_config.to_token, - pagin_config.direction, - pagin_config.limit, - as_client_event, - include_archived, - ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) - + @measure_func("_create_new_client_event") @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - - memberships = [Membership.INVITE, Membership.JOIN] - if include_archived: - memberships.append(Membership.LEAVE) + def _create_new_client_event(self, builder, prev_event_ids=None): + if prev_event_ids: + prev_events = yield self.store.add_event_hashes(prev_event_ids) + prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) + depth = prev_max_depth + 1 + else: + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( + builder.room_id, + ) - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, membership_list=memberships - ) + # We want to limit the max number of prev events we point to in our + # new event + if len(latest_ret) > 10: + # Sort by reverse depth, so we point to the most recent. + latest_ret.sort(key=lambda a: -a[2]) + new_latest_ret = latest_ret[:5] + + # We also randomly point to some of the older events, to make + # sure that we don't completely ignore the older events. + if latest_ret[5:]: + sample_size = min(5, len(latest_ret[5:])) + new_latest_ret.extend(random.sample(latest_ret[5:], sample_size)) + latest_ret = new_latest_ret + + if latest_ret: + depth = max([d for _, _, d in latest_ret]) + 1 + else: + depth = 1 - user = UserID.from_string(user_id) + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] - rooms_ret = [] + builder.prev_events = prev_events + builder.depth = depth - now_token = yield self.hs.get_event_sources().get_current_token() + state_handler = self.state_handler - presence_stream = self.hs.get_event_sources().sources["presence"] - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user, pagination_config.get_source_config("presence"), None - ) + context = yield state_handler.compute_event_context(builder) - receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( - user, pagination_config.get_source_config("receipt"), None - ) + if builder.is_state(): + builder.prev_state = yield self.store.add_event_hashes( + context.prev_state_events + ) - tags_by_room = yield self.store.get_tags_for_user(user_id) + yield self.auth.add_auth_events(builder, context) - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) + signing_key = self.hs.config.signing_key[0] + add_hashes_and_signatures( + builder, self.server_name, signing_key ) - public_room_ids = yield self.store.get_public_room_ids() + event = builder.build() - limit = pagin_config.limit - if limit is None: - limit = 10 + logger.debug( + "Created event %s with state: %s", + event.event_id, context.prev_state_ids, + ) - @defer.inlineCallbacks - def handle_room(event): - d = { - "room_id": event.room_id, - "membership": event.membership, - "visibility": ( - "public" if event.room_id in public_room_ids - else "private" - ), - } + defer.returnValue( + (event, context,) + ) - if event.membership == Membership.INVITE: - time_now = self.clock.time_msec() - d["inviter"] = event.sender + @measure_func("handle_new_client_event") + @defer.inlineCallbacks + def handle_new_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[] + ): + # We now need to go and hit out to wherever we need to hit out to. - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = serialize_event(invite_event, time_now, as_client_event) + if ratelimit: + self.ratelimit(requester) - rooms_ret.append(d) + try: + yield self.auth.check_from_context(event, context) + except AuthError as err: + logger.warn("Denying new event %r because %s", event, err) + raise err + + yield self.maybe_kick_guest_users(event, context) + + if event.type == EventTypes.CanonicalAlias: + # Check the alias is acually valid (at this time at least) + room_alias_str = event.content.get("alias", None) + if room_alias_str: + room_alias = RoomAlias.from_string(room_alias_str) + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) + + if mapping["room_id"] != event.room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % ( + room_alias_str, + ) + ) - if event.membership not in (Membership.JOIN, Membership.LEAVE): - return + federation_handler = self.hs.get_handlers().federation_handler - try: - if event.membership == Membership.JOIN: - room_end_token = now_token.room_key - deferred_room_state = self.state_handler.get_current_state( - event.room_id - ) - elif event.membership == Membership.LEAVE: - room_end_token = "s%d" % (event.stream_ordering,) - deferred_room_state = self.store.get_state_for_events( - [event.event_id], None - ) - deferred_room_state.addCallback( - lambda states: states[event.event_id] + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): + return ( + e.type == EventTypes.Member and + e.sender == event.sender ) - (messages, token), current_state = yield defer.gatherResults( - [ - self.store.get_recent_events_for_room( - event.room_id, - limit=limit, - end_token=room_end_token, - ), - deferred_room_state, - ] - ).addErrback(unwrapFirstError) - - messages = yield self._filter_events_for_client( - user_id, messages - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - time_now = self.clock.time_msec() - - d["messages"] = { - "chunk": [ - serialize_event(m, time_now, as_client_event) - for m in messages - ], - "start": start_token.to_string(), - "end": end_token.to_string(), - } - - d["state"] = [ - serialize_event(c, time_now, as_client_event) - for c in current_state.values() + state_to_include_ids = [ + e_id + for k, e_id in context.current_state_ids.items() + if k[0] in self.hs.config.room_invite_state_types + or k[0] == EventTypes.Member and k[1] == event.sender ] - account_data_events = [] - tags = tags_by_room.get(event.room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(event.room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - d["account_data"] = account_data_events - except: - logger.exception("Failed to get snapshot") - - # Only do N rooms at once - n = 5 - d_list = [handle_room(e) for e in room_list] - for i in range(0, len(d_list), n): - yield defer.gatherResults( - d_list[i:i + n], - consumeErrors=True - ).addErrback(unwrapFirstError) - - account_data_events = [] - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - ret = { - "rooms": rooms_ret, - "presence": presence, - "account_data": account_data_events, - "receipts": receipt, - "end": now_token.to_string(), - } + state_to_include = yield self.store.get_events(state_to_include_ids) - defer.returnValue(ret) + event.unsigned["invite_room_state"] = [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "sender": e.sender, + } + for e in state_to_include.values() + ] - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): - """Capture the a snapshot of a room. If user is currently a member of - the room this will be what is currently in the room. If the user left - the room this will be what was in the room when they left. + invitee = UserID.from_string(event.state_key) + if not self.hs.is_mine(invitee): + # TODO: Can we add signature from remote server in a nicer + # way? If we have been invited by a remote server, we need + # to get them to sign the event. - Args: - requester(Requester): The user to get a snapshot for. - room_id(str): The room to get a snapshot of. - pagin_config(synapse.streams.config.PaginationConfig): - The pagination config used to determine how many messages to - return. - Raises: - AuthError if the user wasn't in the room. - Returns: - A JSON serialisable dict with the snapshot of the room. - """ + returned_invite = yield federation_handler.send_invite( + invitee.domain, + event, + ) - user_id = requester.user.to_string() + event.unsigned.pop("room_state", None) - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, - ) - is_peeking = member_event_id is None + # TODO: Make sure the signatures actually are correct. + event.signatures.update( + returned_invite.signatures + ) - if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( - user_id, room_id, pagin_config, membership, is_peeking - ) - elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( - user_id, room_id, pagin_config, membership, member_event_id, is_peeking + if event.type == EventTypes.Redaction: + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + if self.auth.check_redaction(event, auth_events=auth_events): + original_event = yield self.store.get_event( + event.redacts, + check_redacted=False, + get_prev_content=False, + allow_rejected=False, + allow_none=False + ) + if event.user_id != original_event.user_id: + raise AuthError( + 403, + "You don't have permission to redact events" + ) - account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = yield self.store.get_account_data_for_room(user_id, room_id) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - result["account_data"] = account_data_events - - defer.returnValue(result) - - @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], None - ) - - room_state = room_state[member_event_id] - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 + if event.type == EventTypes.Create and context.prev_state_ids: + raise AuthError( + 403, + "Changing the room create event is forbidden", + ) - stream_token = yield self.store.get_stream_token_for_event( - member_event_id + action_generator = ActionGenerator(self.hs) + yield action_generator.handle_push_actions_for_event( + event, context ) - messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token + (event_stream_id, max_stream_id) = yield self.store.persist_event( + event, context=context ) - messages = yield self._filter_events_for_client( - user_id, messages, is_peeking=is_peeking + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id ) - start_token = StreamToken.START.copy_and_replace("room_key", token[0]) - end_token = StreamToken.START.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": [serialize_event(s, time_now) for s in room_state.values()], - "presence": [], - "receipts": [], - }) - - @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) - - # TODO: These concurrently - time_now = self.clock.time_msec() - state = [ - serialize_event(x, time_now) - for x in current_state.values() - ] - - now_token = yield self.hs.get_event_sources().get_current_token() - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - - room_members = [ - m for m in current_state.values() - if m.type == EventTypes.Member - and m.content["membership"] == Membership.JOIN - ] - - presence_handler = self.hs.get_handlers().presence_handler - @defer.inlineCallbacks - def get_presence(): - states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, + def _notify(): + yield run_on_reactor() + yield self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users ) - defer.returnValue(states) - - @defer.inlineCallbacks - def get_receipts(): - receipts_handler = self.hs.get_handlers().receipts_handler - receipts = yield receipts_handler.get_receipts_for_room( - room_id, - now_token.receipt_key - ) - defer.returnValue(receipts) - - presence, receipts, (messages, token) = yield defer.gatherResults( - [ - get_presence(), - get_receipts(), - self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=now_token.room_key, - ) - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - messages = yield self._filter_events_for_client( - user_id, messages, is_peeking=is_peeking, - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - ret = { - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": state, - "presence": presence, - "receipts": receipts, - } - if not is_peeking: - ret["membership"] = membership + preserve_fn(_notify)() - defer.returnValue(ret) + # If invite, remove room_state from unsigned before sending. + event.unsigned.pop("invite_room_state", None) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d0c8f1328b..1b89dc6274 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -33,11 +33,9 @@ from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id import synapse.metrics -from ._base import BaseHandler - import logging @@ -52,6 +50,13 @@ timers_fired_counter = metrics.register_counter("timers_fired") federation_presence_counter = metrics.register_counter("federation_presence") bump_active_time_counter = metrics.register_counter("bump_active_time") +get_updates_counter = metrics.register_counter("get_updates", labels=["type"]) + +notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"]) +state_transition_counter = metrics.register_counter( + "state_transition", labels=["from", "to"] +) + # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" @@ -70,38 +75,45 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000 # How often to resend presence to remote servers FEDERATION_PING_INTERVAL = 25 * 60 * 1000 +# How long we will wait before assuming that the syncs from an external process +# are dead. +EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 + assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -class PresenceHandler(BaseHandler): +class PresenceHandler(object): def __init__(self, hs): - super(PresenceHandler, self).__init__(hs) - self.hs = hs + self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.federation = hs.get_replication_layer() + self.replication = hs.get_replication_layer() + self.federation = hs.get_federation_sender() + + self.state = hs.get_state_handler() - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence", self.incoming_presence ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( observed_user=UserID.from_string(content["observed_user"]), @@ -138,7 +150,7 @@ class PresenceHandler(BaseHandler): obj=state.user_id, then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) - if self.hs.is_mine_id(state.user_id): + if self.is_mine_id(state.user_id): self.wheel_timer.insert( now=now, obj=state.user_id, @@ -160,20 +172,38 @@ class PresenceHandler(BaseHandler): self.serial_to_user = {} self._next_serial = 1 - # Keeps track of the number of *ongoing* syncs. While this is non zero - # a user will never go offline. + # Keeps track of the number of *ongoing* syncs on this process. While + # this is non zero a user will never go offline. self.user_to_num_current_syncs = {} + # Keeps track of the number of *ongoing* syncs on other processes. + # While any sync is ongoing on another process the user will never + # go offline. + # Each process has a unique identifier and an update frequency. If + # no update is received from that process within the update period then + # we assume that all the sync requests on that process have stopped. + # Stored as a dict from process_id to set of user_id, and a dict of + # process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs = {} + self.external_process_last_updated_ms = {} + # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. self.clock.call_later( - 0 * 1000, + 30, self.clock.looping_call, self._handle_timeouts, 5000, ) + self.clock.call_later( + 60, + self.clock.looping_call, + self._persist_unpersisted_changes, + 60 * 1000, + ) + metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) @defer.inlineCallbacks @@ -188,7 +218,7 @@ class PresenceHandler(BaseHandler): is some spurious presence changes that will self-correct. """ logger.info( - "Performing _on_shutdown. Persiting %d unpersisted changes", + "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) ) @@ -200,6 +230,27 @@ class PresenceHandler(BaseHandler): logger.info("Finished _on_shutdown") @defer.inlineCallbacks + def _persist_unpersisted_changes(self): + """We periodically persist the unpersisted changes, as otherwise they + may stack up and slow down shutdown times. + """ + logger.info( + "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes", + len(self.unpersisted_users_changes) + ) + + unpersisted = self.unpersisted_users_changes + self.unpersisted_users_changes = set() + + if unpersisted: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in unpersisted + ]) + + logger.info("Finished _persist_unpersisted_changes") + + @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state @@ -215,6 +266,12 @@ class PresenceHandler(BaseHandler): to_notify = {} # Changes we want to notify everyone about to_federation_ping = {} # These need sending keep-alives + # Only bother handling the last presence change for each user + new_states_dict = {} + for new_state in new_states: + new_states_dict[new_state.user_id] = new_state + new_state = new_states_dict.values() + for new_state in new_states: user_id = new_state.user_id @@ -228,7 +285,7 @@ class PresenceHandler(BaseHandler): new_state, should_notify, should_ping = handle_update( prev_state, new_state, - is_mine=self.hs.is_mine_id(user_id), + is_mine=self.is_mine_id(user_id), wheel_timer=self.wheel_timer, now=now ) @@ -268,31 +325,48 @@ class PresenceHandler(BaseHandler): """Checks the presence of users that have timed out and updates as appropriate. """ + logger.info("Handling presence timeouts") now = self.clock.time_msec() - with Measure(self.clock, "presence_handle_timeouts"): - # Fetch the list of users that *may* have timed out. Things may have - # changed since the timeout was set, so we won't necessarily have to - # take any action. - users_to_check = self.wheel_timer.fetch(now) + try: + with Measure(self.clock, "presence_handle_timeouts"): + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = set(self.wheel_timer.fetch(now)) + + # Check whether the lists of syncing processes from an external + # process have expired. + expired_process_ids = [ + process_id for process_id, last_update + in self.external_process_last_updated_ms.items() + if now - last_update > EXTERNAL_PROCESS_EXPIRY + ] + for process_id in expired_process_ids: + users_to_check.update( + self.external_process_last_updated_ms.pop(process_id, ()) + ) + self.external_process_last_update.pop(process_id) - states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) - for user_id in set(users_to_check) - ] + states = [ + self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) + for user_id in users_to_check + ] - timers_fired_counter.inc_by(len(states)) + timers_fired_counter.inc_by(len(states)) - changes = handle_timeouts( - states, - is_mine_fn=self.hs.is_mine_id, - user_to_num_current_syncs=self.user_to_num_current_syncs, - now=now, - ) + changes = handle_timeouts( + states, + is_mine_fn=self.is_mine_id, + syncing_user_ids=self.get_currently_syncing_users(), + now=now, + ) - preserve_fn(self._update_states)(changes) + preserve_fn(self._update_states)(changes) + except: + logger.exception("Exception in _handle_timeouts loop") @defer.inlineCallbacks def bump_presence_active_time(self, user): @@ -365,6 +439,74 @@ class PresenceHandler(BaseHandler): defer.returnValue(_user_syncing()) + def get_currently_syncing_users(self): + """Get the set of user ids that are currently syncing on this HS. + Returns: + set(str): A set of user_id strings. + """ + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + return syncing_user_ids + + @defer.inlineCallbacks + def update_external_syncs(self, process_id, syncing_user_ids): + """Update the syncing users for an external process + + Args: + process_id(str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. + syncing_user_ids(set(str)): The set of user_ids that are + currently syncing on that server. + """ + + # Grab the previous list of user_ids that were syncing on that process + prev_syncing_user_ids = ( + self.external_process_to_current_syncs.get(process_id, set()) + ) + # Grab the current presence state for both the users that are syncing + # now and the users that were syncing before this update. + prev_states = yield self.current_state_for_users( + syncing_user_ids | prev_syncing_user_ids + ) + updates = [] + time_now_ms = self.clock.time_msec() + + # For each new user that is syncing check if we need to mark them as + # being online. + for new_user_id in syncing_user_ids - prev_syncing_user_ids: + prev_state = prev_states[new_user_id] + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=time_now_ms, + last_user_sync_ts=time_now_ms, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + # For each user that is still syncing or stopped syncing update the + # last sync time so that we will correctly apply the grace period when + # they stop syncing. + for old_user_id in prev_syncing_user_ids: + prev_state = prev_states[old_user_id] + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + yield self._update_states(updates) + + # Update the last updated time for the process. We expire the entries + # if we don't receive an update in the given timeframe. + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + self.external_process_to_current_syncs[process_id] = syncing_user_ids + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. @@ -403,7 +545,7 @@ class PresenceHandler(BaseHandler): defer.returnValue(states) @defer.inlineCallbacks - def _get_interested_parties(self, states): + def _get_interested_parties(self, states, calculate_remote_hosts=True): """Given a list of states return which entities (rooms, users, servers) are interested in the given states. @@ -426,21 +568,24 @@ class PresenceHandler(BaseHandler): users_to_states.setdefault(state.user_id, []).append(state) hosts_to_states = {} - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) - if not local_states: - continue + if calculate_remote_hosts: + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) + if not local_states: + continue + + users = yield self.state.get_current_user_in_room(room_id) + hosts = set(get_domain_from_id(u) for u in users) - hosts = yield self.store.get_joined_hosts_for_room(room_id) - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) for user_id, states in users_to_states.items(): - local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) if not local_states: continue - host = UserID.from_string(user_id).domain + host = get_domain_from_id(user_id) hosts_to_states.setdefault(host, []).extend(local_states) # TODO: de-dup hosts_to_states, as a single host might have multiple @@ -465,24 +610,24 @@ class PresenceHandler(BaseHandler): self._push_to_remotes(hosts_to_states) + @defer.inlineCallbacks + def notify_for_states(self, state, stream_id): + parties = yield self._get_interested_parties([state]) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] + ) + def _push_to_remotes(self, hosts_to_states): """Sends state updates to remote servers. Args: hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` """ - now = self.clock.time_msec() for host, states in hosts_to_states.items(): - self.federation.send_edu( - destination=host, - edu_type="m.presence", - content={ - "push": [ - _format_user_presence_state(state, now) - for state in states - ] - } - ) + self.federation.send_presence(host, states) @defer.inlineCallbacks def incoming_presence(self, origin, content): @@ -503,6 +648,13 @@ class PresenceHandler(BaseHandler): ) continue + if get_domain_from_id(user_id) != origin: + logger.info( + "Got presence update from %r with bad 'user_id': %r", + origin, user_id, + ) + continue + presence_state = push.get("presence", None) if not presence_state: logger.info( @@ -562,17 +714,17 @@ class PresenceHandler(BaseHandler): defer.returnValue([ { "type": "m.presence", - "content": _format_user_presence_state(state, now), + "content": format_user_presence_state(state, now), } for state in updates ]) else: defer.returnValue([ - _format_user_presence_state(state, now) for state in updates + format_user_presence_state(state, now) for state in updates ]) @defer.inlineCallbacks - def set_state(self, target_user, state): + def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -589,10 +741,13 @@ class PresenceHandler(BaseHandler): prev_state = yield self.current_state_for_user(user_id) new_fields = { - "state": presence, - "status_msg": status_msg if presence != PresenceState.OFFLINE else None + "state": presence } + if not ignore_status_msg: + msg = status_msg if presence != PresenceState.OFFLINE else None + new_fields["status_msg"] = msg + if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() @@ -611,14 +766,14 @@ class PresenceHandler(BaseHandler): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. - if self.hs.is_mine(user): + user_ids = yield self.state.get_current_user_in_room(room_id) + if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) - hosts = yield self.store.get_joined_hosts_for_room(room_id) + hosts = set(get_domain_from_id(u) for u in user_ids) self._push_to_remotes({host: (state,) for host in hosts}) else: - user_ids = yield self.store.get_users_in_room(room_id) - user_ids = filter(self.hs.is_mine_id, user_ids) + user_ids = filter(self.is_mine_id, user_ids) states = yield self.current_state_for_users(user_ids) @@ -628,7 +783,7 @@ class PresenceHandler(BaseHandler): def get_presence_list(self, observer_user, accepted=None): """Returns the presence for all users in their presence list. """ - if not self.hs.is_mine(observer_user): + if not self.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") presence_list = yield self.store.get_presence_list( @@ -659,7 +814,7 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - if self.hs.is_mine(observed_user): + if self.is_mine(observed_user): yield self.invite_presence(observed_user, observer_user) else: yield self.federation.send_edu( @@ -675,11 +830,11 @@ class PresenceHandler(BaseHandler): def invite_presence(self, observed_user, observer_user): """Handles new presence invites. """ - if not self.hs.is_mine(observed_user): + if not self.is_mine(observed_user): raise SynapseError(400, "User is not hosted on this Home Server") # TODO: Don't auto accept - if self.hs.is_mine(observer_user): + if self.is_mine(observer_user): yield self.accept_presence(observed_user, observer_user) else: self.federation.send_edu( @@ -742,7 +897,7 @@ class PresenceHandler(BaseHandler): Returns: A Deferred. """ - if not self.hs.is_mine(observer_user): + if not self.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") yield self.store.del_presence_list( @@ -793,28 +948,38 @@ class PresenceHandler(BaseHandler): def should_notify(old_state, new_state): """Decides if a presence state change should be sent to interested parties. """ + if old_state == new_state: + return False + if old_state.status_msg != new_state.status_msg: + notify_reason_counter.inc("status_msg_change") return True - if old_state.state == PresenceState.ONLINE: - if new_state.state != PresenceState.ONLINE: - # Always notify for online -> anything - return True + if old_state.state != new_state.state: + notify_reason_counter.inc("state_change") + state_transition_counter.inc(old_state.state, new_state.state) + return True + if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: + notify_reason_counter.inc("current_active_change") return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: - # Always notify for a transition where last active gets bumped. - return True + if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Only notify about last active bumps if we're not currently acive + if not new_state.currently_active: + notify_reason_counter.inc("last_active_change_online") + return True - if old_state.state != new_state.state: + elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Always notify for a transition where last active gets bumped. + notify_reason_counter.inc("last_active_change_not_online") return True return False -def _format_user_presence_state(state, now): +def format_user_presence_state(state, now): """Convert UserPresenceState to a format that can be sent down to clients and to other servers. """ @@ -834,9 +999,14 @@ def _format_user_presence_state(state, now): class PresenceEventSource(object): def __init__(self, hs): - self.hs = hs + # We can't call get_presence_handler here because there's a cycle: + # + # Presence -> Notifier -> PresenceEventSource -> Presence + # + self.get_presence_handler = hs.get_presence_handler self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -860,7 +1030,7 @@ class PresenceEventSource(object): from_key = int(from_key) room_ids = room_ids or [] - presence = self.hs.get_handlers().presence_handler + presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache if not room_ids: @@ -877,13 +1047,13 @@ class PresenceEventSource(object): user_ids_changed = set() changed = None - if from_key and max_token - from_key < 100: - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list + if from_key: changed = stream_change_cache.get_all_entities_changed(from_key) - # get_all_entities_changed can return None - if changed is not None: + if changed is not None and len(changed) < 500: + # For small deltas, its quicker to get all changes and then + # work out if we share a room or they're in our presence list + get_updates_counter.inc("stream") for other_user_id in changed: if other_user_id in friends: user_ids_changed.add(other_user_id) @@ -895,9 +1065,11 @@ class PresenceEventSource(object): else: # Too many possible updates. Find all users we can see and check # if any of them have changed. + get_updates_counter.inc("full") + user_ids_to_check = set() for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) user_ids_to_check.update(users) user_ids_to_check.update(friends) @@ -920,7 +1092,7 @@ class PresenceEventSource(object): defer.returnValue(([ { "type": "m.presence", - "content": _format_user_presence_state(s, now), + "content": format_user_presence_state(s, now), } for s in updates.values() if include_offline or s.state != PresenceState.OFFLINE @@ -933,15 +1105,14 @@ class PresenceEventSource(object): return self.get_new_events(user, from_key=None, include_offline=False) -def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): +def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): """Checks the presence of users that have timed out and updates as appropriate. Args: user_states(list): List of UserPresenceState's to check. is_mine_fn (fn): Function that returns if a user_id is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -952,21 +1123,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): for state in user_states: is_mine = is_mine_fn(state.user_id) - new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) + new_state = handle_timeout(state, is_mine, syncing_user_ids, now) if new_state: changes[state.user_id] = new_state return changes.values() -def handle_timeout(state, is_mine, user_to_num_current_syncs, now): +def handle_timeout(state, is_mine, syncing_user_ids, now): """Checks the presence of the user to see if any of the timers have elapsed Args: state (UserPresenceState) is_mine (bool): Whether the user is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -1000,7 +1170,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now): # If there are have been no sync for a while (and none ongoing), # set presence to offline - if not user_to_num_current_syncs.get(user_id, 0): + if user_id not in syncing_user_ids: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( state=PresenceState.OFFLINE, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b45eafbb49..87f74dfb8e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,28 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer +import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID, Requester -from synapse.util import unwrapFirstError - +from synapse.types import UserID from ._base import BaseHandler -import logging - logger = logging.getLogger(__name__) -def changed_presencelike_data(distributor, user, state): - return distributor.fire("changed_presencelike_data", user, state) - - -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class ProfileHandler(BaseHandler): def __init__(self, hs): @@ -45,21 +36,6 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) - distributor = hs.get_distributor() - self.distributor = distributor - - distributor.declare("collect_presencelike_data") - distributor.declare("changed_presencelike_data") - - distributor.observe("registered_user", self.registered_user) - - distributor.observe( - "collect_presencelike_data", self.collect_presencelike_data - ) - - def registered_user(self, user): - return self.store.create_profile(user.localpart) - @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): @@ -89,13 +65,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["displayname"]) @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname): + def set_displayname(self, target_user, requester, new_displayname, by_admin=False): """target_user is the user whose displayname is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if new_displayname == '': @@ -105,10 +81,6 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield changed_presencelike_data(self.distributor, target_user, { - "displayname": new_displayname, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks @@ -139,44 +111,22 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["avatar_url"]) @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url): + def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( target_user.localpart, new_avatar_url ) - yield changed_presencelike_data(self.distributor, target_user, { - "avatar_url": new_avatar_url, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks - def collect_presencelike_data(self, user, state): - if not self.hs.is_mine(user): - defer.returnValue(None) - - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - - state["displayname"] = displayname - state["avatar_url"] = avatar_url - - defer.returnValue(None) - - @defer.inlineCallbacks def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): @@ -215,7 +165,9 @@ class ProfileHandler(BaseHandler): try: # Assume the user isn't a guest because we don't let guests set # profile or avatar data. - requester = Requester(user, "", False) + # XXX why are we recreating `requester` here for each room? + # what was wrong with the `requester` we were passed? + requester = synapse.types.create_requester(user) yield handler.update_membership( requester, user, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 935c339707..916e80a48e 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from ._base import BaseHandler from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import get_domain_from_id import logging @@ -29,12 +30,15 @@ class ReceiptsHandler(BaseHandler): def __init__(self, hs): super(ReceiptsHandler, self).__init__(hs) + self.server_name = hs.config.server_name + self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_replication_layer() - self.federation.register_edu_handler( + self.federation = hs.get_federation_sender() + hs.get_replication_layer().register_edu_handler( "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() + self.state = hs.get_state_handler() @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, @@ -80,6 +84,9 @@ class ReceiptsHandler(BaseHandler): def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ + min_batch_id = None + max_batch_id = None + for receipt in receipts: room_id = receipt["room_id"] receipt_type = receipt["receipt_type"] @@ -97,10 +104,21 @@ class ReceiptsHandler(BaseHandler): stream_id, max_persisted_id = res - with PreserveLoggingContext(): - self.notifier.on_new_event( - "receipt_key", max_persisted_id, rooms=[room_id] - ) + if min_batch_id is None or stream_id < min_batch_id: + min_batch_id = stream_id + if max_batch_id is None or max_persisted_id > max_batch_id: + max_batch_id = max_persisted_id + + affected_room_ids = list(set([r["room_id"] for r in receipts])) + + with PreserveLoggingContext(): + self.notifier.on_new_event( + "receipt_key", max_batch_id, rooms=affected_room_ids + ) + # Note that the min here shouldn't be relied upon to be accurate. + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) defer.returnValue(True) @@ -117,12 +135,10 @@ class ReceiptsHandler(BaseHandler): event_ids = receipt["event_ids"] data = receipt["data"] - remotedomains = set() - - rm_handler = self.hs.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=None, remotedomains=remotedomains - ) + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) + remotedomains = remotedomains.copy() + remotedomains.discard(self.server_name) logger.debug("Sending receipt to: %r", remotedomains) @@ -140,6 +156,7 @@ class ReceiptsHandler(BaseHandler): } }, }, + key=(room_id, receipt_type, user_id), ) @defer.inlineCallbacks diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f287ee247b..886fec8701 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,34 +14,28 @@ # limitations under the License. """Contains functions for registering clients.""" +import logging +import urllib + from twisted.internet import defer -from synapse.types import UserID from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) -from ._base import BaseHandler -from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient - -import logging -import urllib +from synapse.types import UserID +from synapse.util.async import run_on_reactor +from ._base import BaseHandler logger = logging.getLogger(__name__) -def registered_user(distributor, user): - return distributor.fire("registered_user", user) - - class RegistrationHandler(BaseHandler): def __init__(self, hs): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() - self.distributor = hs.get_distributor() - self.distributor.declare("registered_user") self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -58,6 +52,13 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME ) + if localpart[0] == '_': + raise SynapseError( + 400, + "User ID may not begin with _", + Codes.INVALID_USERNAME + ) + user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -96,7 +97,8 @@ class RegistrationHandler(BaseHandler): password=None, generate_token=True, guest_access_token=None, - make_guest=False + make_guest=False, + admin=False, ): """Registers a new client on the server. @@ -104,8 +106,13 @@ class RegistrationHandler(BaseHandler): localpart : The local part of the user ID to register. If None, one will be generated. password (str) : The password to assign to this user so they can - login again. This can be None which means they cannot login again - via a password (e.g. the user is an application service user). + login again. This can be None which means they cannot login again + via a password (e.g. the user is an application service user). + generate_token (bool): Whether a new access token should be + generated. Having this be True should be considered deprecated, + since it offers no means of associating a device_id with the + access_token. Instead you should call auth_handler.issue_access_token + after registration. Returns: A tuple of (user_id, access_token). Raises: @@ -143,9 +150,12 @@ class RegistrationHandler(BaseHandler): password_hash=password_hash, was_guest=was_guest, make_guest=make_guest, + create_profile_with_localpart=( + # If the user was a guest then they already have a profile + None if was_guest else user.localpart + ), + admin=admin, ) - - yield registered_user(self.distributor, user) else: # autogen a sequential user ID attempts = 0 @@ -163,7 +173,8 @@ class RegistrationHandler(BaseHandler): user_id=user_id, token=token, password_hash=password_hash, - make_guest=make_guest + make_guest=make_guest, + create_profile_with_localpart=user.localpart, ) except SynapseError: # if user id is taken, just generate another @@ -171,7 +182,6 @@ class RegistrationHandler(BaseHandler): user_id = None token = None attempts += 1 - yield registered_user(self.distributor, user) # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding @@ -183,7 +193,7 @@ class RegistrationHandler(BaseHandler): def appservice_register(self, user_localpart, as_token): user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() - service = yield self.store.get_app_service_by_token(as_token) + service = self.store.get_app_service_by_token(as_token) if not service: raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): @@ -198,15 +208,13 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, - token=token, password_hash="", appservice_id=service_id, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): @@ -251,9 +259,9 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=None + password_hash=None, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) except Exception as e: yield self.store.add_access_token_to_user(user_id, token) # Ignore Registration errors @@ -296,11 +304,10 @@ class RegistrationHandler(BaseHandler): # XXX: This should be a deferred list, shouldn't it? yield identity_handler.bind_threepid(c, user_id) - @defer.inlineCallbacks def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): # valid user IDs must not clash with any user ID namespaces claimed by # application services. - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_user(user_id) @@ -361,8 +368,61 @@ class RegistrationHandler(BaseHandler): ) defer.returnValue(data) + @defer.inlineCallbacks + def get_or_create_user(self, requester, localpart, displayname, + password_hash=None): + """Creates a new user if the user does not exist, + else revokes all previous access tokens and generates a new one. + + Args: + localpart : The local part of the user ID to register. If None, + one will be randomly generated. + Returns: + A tuple of (user_id, access_token). + Raises: + RegistrationError if there was a problem registering. + """ + yield run_on_reactor() + + if localpart is None: + raise SynapseError(400, "Request must include user id") + + need_register = True + + try: + yield self.check_username(localpart) + except SynapseError as e: + if e.errcode == Codes.USER_IN_USE: + need_register = False + else: + raise + + user = UserID(localpart, self.hs.hostname) + user_id = user.to_string() + token = self.auth_handler().generate_access_token(user_id) + + if need_register: + yield self.store.register( + user_id=user_id, + token=token, + password_hash=password_hash, + create_profile_with_localpart=user.localpart, + ) + else: + yield self.store.user_delete_access_tokens(user_id=user_id) + yield self.store.add_access_token_to_user(user_id=user_id, token=token) + + if displayname is not None: + logger.info("setting user display name: %s -> %s", user_id, displayname) + profile_handler = self.hs.get_handlers().profile_handler + yield profile_handler.set_displayname( + user, requester, displayname, by_admin=True, + ) + + defer.returnValue((user_id, token)) + def auth_handler(self): - return self.hs.get_handlers().auth_handler + return self.hs.get_auth_handler() @defer.inlineCallbacks def guest_access_token_for(self, medium, address, inviter_user_id): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f7163470a9..5f18007e90 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -18,19 +18,15 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester +from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, Membership, JoinRules, RoomCreationPreset, + EventTypes, JoinRules, RoomCreationPreset ) -from synapse.api.errors import AuthError, StoreError, SynapseError, Codes -from synapse.util import stringutils, unwrapFirstError -from synapse.util.logcontext import preserve_context_over_fn - -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes +from synapse.api.errors import AuthError, StoreError, SynapseError +from synapse.util import stringutils +from synapse.visibility import filter_events_for_client from collections import OrderedDict -from unpaddedbase64 import decode_base64 import logging import math @@ -41,20 +37,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomCreationHandler(BaseHandler): PRESETS_DICT = { @@ -122,7 +104,8 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = config.get("invite_3pid", []) - is_public = config.get("visibility", None) == "public" + visibility = config.get("visibility", None) + is_public = visibility == "public" # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -158,9 +141,9 @@ class RoomCreationHandler(BaseHandler): preset_config = config.get( "preset", - RoomCreationPreset.PUBLIC_CHAT - if is_public - else RoomCreationPreset.PRIVATE_CHAT + RoomCreationPreset.PRIVATE_CHAT + if visibility == "private" + else RoomCreationPreset.PUBLIC_CHAT ) raw_initial_state = config.get("initial_state", []) @@ -212,6 +195,11 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + for invitee in invite_list: yield room_member_handler.update_membership( requester, @@ -219,6 +207,7 @@ class RoomCreationHandler(BaseHandler): room_id, "invite", ratelimit=False, + content=content, ) for invite_3pid in invite_3pid_list: @@ -365,659 +354,6 @@ class RoomCreationHandler(BaseHandler): ) -class RoomMemberHandler(BaseHandler): - # TODO(paul): This handler currently contains a messy conflation of - # low-level API that works on UserID objects and so on, and REST-level - # API that takes ID strings and returns pagination chunks. These concerns - # ought to be separated out a lot better. - - def __init__(self, hs): - super(RoomMemberHandler, self).__init__(hs) - - self.clock = hs.get_clock() - - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - self.distributor.declare("user_left_room") - - @defer.inlineCallbacks - def get_room_members(self, room_id): - users = yield self.store.get_users_in_room(room_id) - - defer.returnValue([UserID.from_string(u) for u in users]) - - @defer.inlineCallbacks - def fetch_room_distributions_into(self, room_id, localusers=None, - remotedomains=None, ignore_user=None): - """Fetch the distribution of a room, adding elements to either - 'localusers' or 'remotedomains', which should be a set() if supplied. - If ignore_user is set, ignore that user. - - This function returns nothing; its result is performed by the - side-effect on the two passed sets. This allows easy accumulation of - member lists of multiple rooms at once if required. - """ - members = yield self.get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if self.hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - - @defer.inlineCallbacks - def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - ): - effective_membership_state = action - if action in ["kick", "unban"]: - effective_membership_state = "leave" - elif action == "forget": - effective_membership_state = "leave" - - if third_party_signed is not None: - replication = self.hs.get_replication_layer() - yield replication.exchange_third_party_invite( - third_party_signed["sender"], - target.to_string(), - room_id, - third_party_signed, - ) - - msg_handler = self.hs.get_handlers().message_handler - - content = {"membership": effective_membership_state} - if requester.is_guest: - content["kind"] = "guest" - - event, context = yield msg_handler.create_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": target.to_string(), - - # For backwards compatibility: - "membership": effective_membership_state, - }, - token_id=requester.access_token_id, - txn_id=txn_id, - ) - - old_state = context.current_state.get((EventTypes.Member, event.state_key)) - old_membership = old_state.content.get("membership") if old_state else None - if action == "unban" and old_membership != "ban": - raise SynapseError( - 403, - "Cannot unban user who was not banned (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE - ) - if old_membership == "ban" and action != "unban": - raise SynapseError( - 403, - "Cannot %s user who was is banned" % (action,), - errcode=Codes.BAD_STATE - ) - - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event( - requester, - event, - context, - ratelimit=ratelimit, - remote_room_hosts=remote_room_hosts, - ) - - if action == "forget": - yield self.forget(requester.user, room_id) - - @defer.inlineCallbacks - def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, - ): - """ - Change the membership status of a user in a room. - - Args: - requester (Requester): The local user who requested the membership - event. If None, certain checks, like whether this homeserver can - act as the sender, will be skipped. - event (SynapseEvent): The membership event. - context: The context of the event. - is_guest (bool): Whether the sender is a guest. - room_hosts ([str]): Homeservers which are likely to already be in - the room, and could be danced with in order to join this - homeserver for the first time. - ratelimit (bool): Whether to rate limit this request. - Raises: - SynapseError if there was a problem changing the membership. - """ - remote_room_hosts = remote_room_hosts or [] - - target_user = UserID.from_string(event.state_key) - room_id = event.room_id - - if requester is not None: - sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) - ) - assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) - else: - requester = Requester(target_user, None, False) - - message_handler = self.hs.get_handlers().message_handler - prev_event = message_handler.deduplicate_state_event(event, context) - if prev_event is not None: - return - - action = "send" - - if event.membership == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(context.current_state): - # This should be an auth check, but guests are a local concept, - # so don't really fit into the general auth process. - raise AuthError(403, "Guest access not allowed") - do_remote_join_dance, remote_room_hosts = self._should_do_dance( - context, - (self.get_inviter(event.state_key, context.current_state)), - remote_room_hosts, - ) - if do_remote_join_dance: - action = "remote_join" - elif event.membership == Membership.LEAVE: - is_host_in_room = self.is_host_in_room(context.current_state) - - if not is_host_in_room: - # perhaps we've been invited - inviter = self.get_inviter(target_user.to_string(), context.current_state) - if not inviter: - raise SynapseError(404, "Not a known room") - - if self.hs.is_mine(inviter): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: - # send the rejection to the inviter's HS. - remote_room_hosts = remote_room_hosts + [inviter.domain] - action = "remote_reject" - - federation_handler = self.hs.get_handlers().federation_handler - - if action == "remote_join": - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield federation_handler.do_invite_join( - remote_room_hosts, - event.room_id, - event.user_id, - event.content, - ) - elif action == "remote_reject": - yield federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - event.user_id - ) - else: - yield self.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, - ) - - prev_member_event = context.current_state.get( - (EventTypes.Member, target_user.to_string()), - None - ) - - if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. - yield user_joined_room(self.distributor, target_user, room_id) - elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) - - def _can_guest_join(self, current_state): - """ - Returns whether a guest can join a room based on its current state. - """ - guest_access = current_state.get((EventTypes.GuestAccess, ""), None) - return ( - guest_access - and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" - ) - - def _should_do_dance(self, context, inviter, room_hosts=None): - # TODO: Shouldn't this be remote_room_host? - room_hosts = room_hosts or [] - - is_host_in_room = self.is_host_in_room(context.current_state) - if is_host_in_room: - return False, room_hosts - - if inviter and not self.hs.is_mine(inviter): - room_hosts.append(inviter.domain) - - return True, room_hosts - - @defer.inlineCallbacks - def lookup_room_alias(self, room_alias): - """ - Get the room ID associated with a room alias. - - Args: - room_alias (RoomAlias): The alias to look up. - Returns: - A tuple of: - The room ID as a RoomID object. - Hosts likely to be participating in the room ([str]). - Raises: - SynapseError if room alias could not be found. - """ - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) - - if not mapping: - raise SynapseError(404, "No such room alias") - - room_id = mapping["room_id"] - servers = mapping["servers"] - - defer.returnValue((RoomID.from_string(room_id), servers)) - - def get_inviter(self, user_id, current_state): - prev_state = current_state.get((EventTypes.Member, user_id)) - if prev_state and prev_state.membership == Membership.INVITE: - return UserID.from_string(prev_state.user_id) - return None - - @defer.inlineCallbacks - def get_joined_rooms_for_user(self, user): - """Returns a list of roomids that the user has any of the given - membership states in.""" - - rooms = yield self.store.get_rooms_for_user( - user.to_string(), - ) - - # For some reason the list of events contains duplicates - # TODO(paul): work out why because I really don't think it should - room_ids = set(r.room_id for r in rooms) - - defer.returnValue(room_ids) - - @defer.inlineCallbacks - def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id - ): - invitee = yield self._lookup_3pid( - id_server, medium, address - ) - - if invitee: - handler = self.hs.get_handlers().room_member_handler - yield handler.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, - ) - else: - yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id - ) - - @defer.inlineCallbacks - def _lookup_3pid(self, id_server, medium, address): - """Looks up a 3pid in the passed identity server. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - - Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. - """ - try: - data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), - { - "medium": medium, - "address": address, - } - ) - - if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - self.verify_any_signature(data, id_server) - defer.returnValue(data["mxid"]) - - except IOError as e: - logger.warn("Error from identity server lookup: %s" % (e,)) - defer.returnValue(None) - - @defer.inlineCallbacks - def verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" % - (id_server_scheme, server_hostname, key_name,), - ) - if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) - verify_signed_json( - data, - server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) - ) - return - - @defer.inlineCallbacks - def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id - ): - room_state = yield self.hs.get_state_handler().get_current_state(room_id) - - inviter_display_name = "" - inviter_avatar_url = "" - member_event = room_state.get((EventTypes.Member, user.to_string())) - if member_event: - inviter_display_name = member_event.content.get("displayname", "") - inviter_avatar_url = member_event.content.get("avatar_url", "") - - canonical_room_alias = "" - canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) - if canonical_alias_event: - canonical_room_alias = canonical_alias_event.content.get("alias", "") - - room_name = "" - room_name_event = room_state.get((EventTypes.Name, "")) - if room_name_event: - room_name = room_name_event.content.get("name", "") - - room_join_rules = "" - join_rules_event = room_state.get((EventTypes.JoinRules, "")) - if join_rules_event: - room_join_rules = join_rules_event.content.get("join_rule", "") - - room_avatar_url = "" - room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) - if room_avatar_event: - room_avatar_url = room_avatar_event.content.get("url", "") - - token, public_keys, fallback_public_key, display_name = ( - yield self._ask_id_server_for_third_party_invite( - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url - ) - ) - - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.ThirdPartyInvite, - "content": { - "display_name": display_name, - "public_keys": public_keys, - - # For backwards compatibility: - "key_validity_url": fallback_public_key["key_validity_url"], - "public_key": fallback_public_key["public_key"], - }, - "room_id": room_id, - "sender": user.to_string(), - "state_key": token, - }, - txn_id=txn_id, - ) - - @defer.inlineCallbacks - def _ask_id_server_for_third_party_invite( - self, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url - ): - """ - Asks an identity server for a third party invite. - - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. - - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. - """ - - is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, id_server, - ) - - invite_config = { - "medium": medium, - "address": address, - "room_id": room_id, - "room_alias": room_alias, - "room_avatar_url": room_avatar_url, - "room_join_rules": room_join_rules, - "room_name": room_name, - "sender": inviter_user_id, - "sender_display_name": inviter_display_name, - "sender_avatar_url": inviter_avatar_url, - } - - if self.hs.config.invite_3pid_guest: - registration_handler = self.hs.get_handlers().registration_handler - guest_access_token = yield registration_handler.guest_access_token_for( - medium=medium, - address=address, - inviter_user_id=inviter_user_id, - ) - - guest_user_info = yield self.hs.get_auth().get_user_by_access_token( - guest_access_token - ) - - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_info["user"].to_string(), - }) - - data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( - is_url, - invite_config - ) - # TODO: Check for success - token = data["token"] - public_keys = data.get("public_keys", []) - if "public_key" in data: - fallback_public_key = { - "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ), - } - else: - fallback_public_key = public_keys[0] - - if not public_keys: - public_keys.append(fallback_public_key) - display_name = data["display_name"] - defer.returnValue((token, public_keys, fallback_public_key, display_name)) - - def forget(self, user, room_id): - return self.store.forget(user.to_string(), room_id) - - -class RoomListHandler(BaseHandler): - - @defer.inlineCallbacks - def get_public_room_list(self): - room_ids = yield self.store.get_public_room_ids() - - @defer.inlineCallbacks - def handle_room(room_id): - aliases = yield self.store.get_aliases_for_room(room_id) - if not aliases: - defer.returnValue(None) - - state = yield self.state_handler.get_current_state(room_id) - - result = {"aliases": aliases, "room_id": room_id} - - name_event = state.get((EventTypes.Name, ""), None) - if name_event: - name = name_event.content.get("name", None) - if name: - result["name"] = name - - topic_event = state.get((EventTypes.Topic, ""), None) - if topic_event: - topic = topic_event.content.get("topic", None) - if topic: - result["topic"] = topic - - canonical_event = state.get((EventTypes.CanonicalAlias, ""), None) - if canonical_event: - canonical_alias = canonical_event.content.get("alias", None) - if canonical_alias: - result["canonical_alias"] = canonical_alias - - visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) - visibility = None - if visibility_event: - visibility = visibility_event.content.get("history_visibility", None) - result["world_readable"] = visibility == "world_readable" - - guest_event = state.get((EventTypes.GuestAccess, ""), None) - guest = None - if guest_event: - guest = guest_event.content.get("guest_access", None) - result["guest_can_join"] = guest == "can_join" - - avatar_event = state.get(("m.room.avatar", ""), None) - if avatar_event: - avatar_url = avatar_event.content.get("url", None) - if avatar_url: - result["avatar_url"] = avatar_url - - result["num_joined_members"] = sum( - 1 for (event_type, _), ev in state.items() - if event_type == EventTypes.Member and ev.membership == Membership.JOIN - ) - - defer.returnValue(result) - - result = [] - for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)): - chunk_result = yield defer.gatherResults([ - handle_room(room_id) - for room_id in chunk - ], consumeErrors=True).addErrback(unwrapFirstError) - result.extend(v for v in chunk_result if v) - - # FIXME (erikj): START is no longer a valid value - defer.returnValue({"start": "START", "end": "END", "chunk": result}) - - class RoomContextHandler(BaseHandler): @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, is_guest): @@ -1040,10 +376,12 @@ class RoomContextHandler(BaseHandler): now_token = yield self.hs.get_event_sources().get_current_token() def filter_evts(events): - return self._filter_events_for_client( + return filter_events_for_client( + self.store, user.to_string(), events, - is_peeking=is_guest) + is_peeking=is_guest + ) event = yield self.store.get_event(event_id, get_prev_content=True, allow_none=True) @@ -1109,7 +447,7 @@ class RoomEventSource(object): logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = yield self.store.get_app_service_by_user_id( + app_service = self.store.get_app_service_by_user_id( user.to_string() ) if app_service: @@ -1147,8 +485,11 @@ class RoomEventSource(object): defer.returnValue((events, end_key)) - def get_current_key(self, direction='f'): - return self.store.get_room_events_max_id(direction) + def get_current_key(self): + return self.store.get_room_events_max_id() + + def get_current_key_for_room(self, room_id): + return self.store.get_room_events_max_id(room_id) @defer.inlineCallbacks def get_pagination_rows(self, user, config, key): diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py new file mode 100644 index 0000000000..b04aea0110 --- /dev/null +++ b/synapse/handlers/room_list.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import BaseHandler + +from synapse.api.constants import ( + EventTypes, JoinRules, +) +from synapse.util.async import concurrently_execute +from synapse.util.caches.response_cache import ResponseCache + +from collections import namedtuple +from unpaddedbase64 import encode_base64, decode_base64 + +import logging +import msgpack + +logger = logging.getLogger(__name__) + +REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 + + +class RoomListHandler(BaseHandler): + def __init__(self, hs): + super(RoomListHandler, self).__init__(hs) + self.response_cache = ResponseCache(hs) + self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) + + def get_local_public_room_list(self, limit=None, since_token=None, + search_filter=None): + if search_filter: + # We explicitly don't bother caching searches. + return self._get_public_room_list(limit, since_token, search_filter) + + result = self.response_cache.get((limit, since_token)) + if not result: + result = self.response_cache.set( + (limit, since_token), + self._get_public_room_list(limit, since_token) + ) + return result + + @defer.inlineCallbacks + def _get_public_room_list(self, limit=None, since_token=None, + search_filter=None): + if since_token and since_token != "END": + since_token = RoomListNextBatch.from_token(since_token) + else: + since_token = None + + rooms_to_order_value = {} + rooms_to_num_joined = {} + rooms_to_latest_event_ids = {} + + newly_visible = [] + newly_unpublished = [] + if since_token: + stream_token = since_token.stream_ordering + current_public_id = yield self.store.get_current_public_room_stream_id() + public_room_stream_id = since_token.public_room_stream_id + newly_visible, newly_unpublished = yield self.store.get_public_room_changes( + public_room_stream_id, current_public_id + ) + else: + stream_token = yield self.store.get_room_max_stream_ordering() + public_room_stream_id = yield self.store.get_current_public_room_stream_id() + + room_ids = yield self.store.get_public_room_ids_at_stream_id( + public_room_stream_id + ) + + # We want to return rooms in a particular order: the number of joined + # users. We then arbitrarily use the room_id as a tie breaker. + + @defer.inlineCallbacks + def get_order_for_room(room_id): + latest_event_ids = rooms_to_latest_event_ids.get(room_id, None) + if not latest_event_ids: + latest_event_ids = yield self.store.get_forward_extremeties_for_room( + room_id, stream_token + ) + rooms_to_latest_event_ids[room_id] = latest_event_ids + + if not latest_event_ids: + return + + joined_users = yield self.state_handler.get_current_user_in_room( + room_id, latest_event_ids, + ) + num_joined_users = len(joined_users) + rooms_to_num_joined[room_id] = num_joined_users + + if num_joined_users == 0: + return + + # We want larger rooms to be first, hence negating num_joined_users + rooms_to_order_value[room_id] = (-num_joined_users, room_id) + + yield concurrently_execute(get_order_for_room, room_ids, 10) + + sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) + sorted_rooms = [room_id for room_id, _ in sorted_entries] + + # `sorted_rooms` should now be a list of all public room ids that is + # stable across pagination. Therefore, we can use indices into this + # list as our pagination tokens. + + # Filter out rooms that we don't want to return + rooms_to_scan = [ + r for r in sorted_rooms + if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 + ] + + total_room_count = len(rooms_to_scan) + + if since_token: + # Filter out rooms we've already returned previously + # `since_token.current_limit` is the index of the last room we + # sent down, so we exclude it and everything before/after it. + if since_token.direction_is_forward: + rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:] + else: + rooms_to_scan = rooms_to_scan[:since_token.current_limit] + rooms_to_scan.reverse() + + # Actually generate the entries. _generate_room_entry will append to + # chunk but will stop if len(chunk) > limit + chunk = [] + if limit and not search_filter: + step = limit + 1 + for i in xrange(0, len(rooms_to_scan), step): + # We iterate here because the vast majority of cases we'll stop + # at first iteration, but occaisonally _generate_room_entry + # won't append to the chunk and so we need to loop again. + # We don't want to scan over the entire range either as that + # would potentially waste a lot of work. + yield concurrently_execute( + lambda r: self._generate_room_entry( + r, rooms_to_num_joined[r], + chunk, limit, search_filter + ), + rooms_to_scan[i:i + step], 10 + ) + if len(chunk) >= limit + 1: + break + else: + yield concurrently_execute( + lambda r: self._generate_room_entry( + r, rooms_to_num_joined[r], + chunk, limit, search_filter + ), + rooms_to_scan, 5 + ) + + chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) + + # Work out the new limit of the batch for pagination, or None if we + # know there are no more results that would be returned. + # i.e., [since_token.current_limit..new_limit] is the batch of rooms + # we've returned (or the reverse if we paginated backwards) + # We tried to pull out limit + 1 rooms above, so if we have <= limit + # then we know there are no more results to return + new_limit = None + if chunk and (not limit or len(chunk) > limit): + + if not since_token or since_token.direction_is_forward: + if limit: + chunk = chunk[:limit] + last_room_id = chunk[-1]["room_id"] + else: + if limit: + chunk = chunk[-limit:] + last_room_id = chunk[0]["room_id"] + + new_limit = sorted_rooms.index(last_room_id) + + results = { + "chunk": chunk, + "total_room_count_estimate": total_room_count, + } + + if since_token: + results["new_rooms"] = bool(newly_visible) + + if not since_token or since_token.direction_is_forward: + if new_limit is not None: + results["next_batch"] = RoomListNextBatch( + stream_ordering=stream_token, + public_room_stream_id=public_room_stream_id, + current_limit=new_limit, + direction_is_forward=True, + ).to_token() + + if since_token: + results["prev_batch"] = since_token.copy_and_replace( + direction_is_forward=False, + current_limit=since_token.current_limit + 1, + ).to_token() + else: + if new_limit is not None: + results["prev_batch"] = RoomListNextBatch( + stream_ordering=stream_token, + public_room_stream_id=public_room_stream_id, + current_limit=new_limit, + direction_is_forward=False, + ).to_token() + + if since_token: + results["next_batch"] = since_token.copy_and_replace( + direction_is_forward=True, + current_limit=since_token.current_limit - 1, + ).to_token() + + defer.returnValue(results) + + @defer.inlineCallbacks + def _generate_room_entry(self, room_id, num_joined_users, chunk, limit, + search_filter): + if limit and len(chunk) > limit + 1: + # We've already got enough, so lets just drop it. + return + + result = { + "room_id": room_id, + "num_joined_members": num_joined_users, + } + + current_state_ids = yield self.state_handler.get_current_state_ids(room_id) + + event_map = yield self.store.get_events([ + event_id for key, event_id in current_state_ids.items() + if key[0] in ( + EventTypes.JoinRules, + EventTypes.Name, + EventTypes.Topic, + EventTypes.CanonicalAlias, + EventTypes.RoomHistoryVisibility, + EventTypes.GuestAccess, + "m.room.avatar", + ) + ]) + + current_state = { + (ev.type, ev.state_key): ev + for ev in event_map.values() + } + + # Double check that this is actually a public room. + join_rules_event = current_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + if join_rule and join_rule != JoinRules.PUBLIC: + defer.returnValue(None) + + aliases = yield self.store.get_aliases_for_room(room_id) + if aliases: + result["aliases"] = aliases + + name_event = yield current_state.get((EventTypes.Name, "")) + if name_event: + name = name_event.content.get("name", None) + if name: + result["name"] = name + + topic_event = current_state.get((EventTypes.Topic, "")) + if topic_event: + topic = topic_event.content.get("topic", None) + if topic: + result["topic"] = topic + + canonical_event = current_state.get((EventTypes.CanonicalAlias, "")) + if canonical_event: + canonical_alias = canonical_event.content.get("alias", None) + if canonical_alias: + result["canonical_alias"] = canonical_alias + + visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, "")) + visibility = None + if visibility_event: + visibility = visibility_event.content.get("history_visibility", None) + result["world_readable"] = visibility == "world_readable" + + guest_event = current_state.get((EventTypes.GuestAccess, "")) + guest = None + if guest_event: + guest = guest_event.content.get("guest_access", None) + result["guest_can_join"] = guest == "can_join" + + avatar_event = current_state.get(("m.room.avatar", "")) + if avatar_event: + avatar_url = avatar_event.content.get("url", None) + if avatar_url: + result["avatar_url"] = avatar_url + + if _matches_room_entry(result, search_filter): + chunk.append(result) + + @defer.inlineCallbacks + def get_remote_public_room_list(self, server_name, limit=None, since_token=None, + search_filter=None): + if search_filter: + # We currently don't support searching across federation, so we have + # to do it manually without pagination + limit = None + since_token = None + + res = yield self._get_remote_list_cached( + server_name, limit=limit, since_token=since_token, + ) + + if search_filter: + res = {"chunk": [ + entry + for entry in list(res.get("chunk", [])) + if _matches_room_entry(entry, search_filter) + ]} + + defer.returnValue(res) + + def _get_remote_list_cached(self, server_name, limit=None, since_token=None, + search_filter=None): + repl_layer = self.hs.get_replication_layer() + if search_filter: + # We can't cache when asking for search + return repl_layer.get_public_rooms( + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + ) + + result = self.remote_response_cache.get((server_name, limit, since_token)) + if not result: + result = self.remote_response_cache.set( + (server_name, limit, since_token), + repl_layer.get_public_rooms( + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + ) + ) + return result + + +class RoomListNextBatch(namedtuple("RoomListNextBatch", ( + "stream_ordering", # stream_ordering of the first public room list + "public_room_stream_id", # public room stream id for first public room list + "current_limit", # The number of previous rooms returned + "direction_is_forward", # Bool if this is a next_batch, false if prev_batch +))): + + KEY_DICT = { + "stream_ordering": "s", + "public_room_stream_id": "p", + "current_limit": "n", + "direction_is_forward": "d", + } + + REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} + + @classmethod + def from_token(cls, token): + return RoomListNextBatch(**{ + cls.REVERSE_KEY_DICT[key]: val + for key, val in msgpack.loads(decode_base64(token)).items() + }) + + def to_token(self): + return encode_base64(msgpack.dumps({ + self.KEY_DICT[key]: val + for key, val in self._asdict().items() + })) + + def copy_and_replace(self, **kwds): + return self._replace( + **kwds + ) + + +def _matches_room_entry(room_entry, search_filter): + if search_filter and search_filter.get("generic_search_term", None): + generic_search_term = search_filter["generic_search_term"].upper() + if generic_search_term in room_entry.get("name", "").upper(): + return True + elif generic_search_term in room_entry.get("topic", "").upper(): + return True + elif generic_search_term in room_entry.get("canonical_alias", "").upper(): + return True + else: + return True + + return False diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py new file mode 100644 index 0000000000..ba49075a20 --- /dev/null +++ b/synapse/handlers/room_member.py @@ -0,0 +1,735 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from twisted.internet import defer +from unpaddedbase64 import decode_base64 + +import synapse.types +from synapse.api.constants import ( + EventTypes, Membership, +) +from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.types import UserID, RoomID +from synapse.util.async import Linearizer +from synapse.util.distributor import user_left_room, user_joined_room +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + +id_server_scheme = "https://" + + +class RoomMemberHandler(BaseHandler): + # TODO(paul): This handler currently contains a messy conflation of + # low-level API that works on UserID objects and so on, and REST-level + # API that takes ID strings and returns pagination chunks. These concerns + # ought to be separated out a lot better. + + def __init__(self, hs): + super(RoomMemberHandler, self).__init__(hs) + + self.member_linearizer = Linearizer() + + self.clock = hs.get_clock() + + self.distributor = hs.get_distributor() + self.distributor.declare("user_joined_room") + self.distributor.declare("user_left_room") + + @defer.inlineCallbacks + def _local_membership_update( + self, requester, target, room_id, membership, + prev_event_ids, + txn_id=None, + ratelimit=True, + content=None, + ): + if content is None: + content = {} + msg_handler = self.hs.get_handlers().message_handler + + content["membership"] = membership + if requester.is_guest: + content["kind"] = "guest" + + event, context = yield msg_handler.create_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": target.to_string(), + + # For backwards compatibility: + "membership": membership, + }, + token_id=requester.access_token_id, + txn_id=txn_id, + prev_event_ids=prev_event_ids, + ) + + # Check if this event matches the previous membership event for the user. + duplicate = yield msg_handler.deduplicate_state_event(event, context) + if duplicate is not None: + # Discard the new event since this membership change is a no-op. + return + + yield msg_handler.handle_new_client_event( + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, + ) + + prev_member_event_id = context.prev_state_ids.get( + (EventTypes.Member, target.to_string()), + None + ) + + if event.membership == Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: + yield user_joined_room(self.distributor, target, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target, room_id) + + @defer.inlineCallbacks + def remote_join(self, remote_room_hosts, room_id, user, content): + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield self.hs.get_handlers().federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user.to_string(), + content, + ) + yield user_joined_room(self.distributor, user, room_id) + + def reject_remote_invite(self, user_id, room_id, remote_room_hosts): + return self.hs.get_handlers().federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + user_id + ) + + @defer.inlineCallbacks + def update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + ): + key = (room_id,) + + with (yield self.member_linearizer.queue(key)): + result = yield self._update_membership( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + ) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + ): + if content is None: + content = {} + + effective_membership_state = action + if action in ["kick", "unban"]: + effective_membership_state = "leave" + + if third_party_signed is not None: + replication = self.hs.get_replication_layer() + yield replication.exchange_third_party_invite( + third_party_signed["sender"], + target.to_string(), + room_id, + third_party_signed, + ) + + if not remote_room_hosts: + remote_room_hosts = [] + + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + current_state_ids = yield self.state_handler.get_current_state_ids( + room_id, latest_event_ids=latest_event_ids, + ) + + old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + if old_state_id: + old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned" + " (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was banned" % (action,), + errcode=Codes.BAD_STATE + ) + + is_host_in_room = yield self._is_host_in_room(current_state_ids) + + if effective_membership_state == Membership.JOIN: + if requester.is_guest and not self._can_guest_join(current_state_ids): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + + if not is_host_in_room: + inviter = yield self.get_inviter(target.to_string(), room_id) + if inviter and not self.hs.is_mine(inviter): + remote_room_hosts.append(inviter.domain) + + content["membership"] = Membership.JOIN + + profile = self.hs.get_handlers().profile_handler + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + + if requester.is_guest: + content["kind"] = "guest" + + ret = yield self.remote_join( + remote_room_hosts, room_id, target, content + ) + defer.returnValue(ret) + + elif effective_membership_state == Membership.LEAVE: + if not is_host_in_room: + # perhaps we've been invited + inviter = yield self.get_inviter(target.to_string(), room_id) + if not inviter: + raise SynapseError(404, "Not a known room") + + if self.hs.is_mine(inviter): + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath. + # + # This is a bit of a hack, because the room might still be + # active on other servers. + pass + else: + # send the rejection to the inviter's HS. + remote_room_hosts = remote_room_hosts + [inviter.domain] + + try: + ret = yield self.reject_remote_invite( + target.to_string(), room_id, remote_room_hosts + ) + defer.returnValue(ret) + except SynapseError as e: + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + + defer.returnValue({}) + + yield self._local_membership_update( + requester=requester, + target=target, + room_id=room_id, + membership=effective_membership_state, + txn_id=txn_id, + ratelimit=ratelimit, + prev_event_ids=latest_event_ids, + content=content, + ) + + @defer.inlineCallbacks + def send_membership_event( + self, + requester, + event, + context, + remote_room_hosts=None, + ratelimit=True, + ): + """ + Change the membership status of a user in a room. + + Args: + requester (Requester): The local user who requested the membership + event. If None, certain checks, like whether this homeserver can + act as the sender, will be skipped. + event (SynapseEvent): The membership event. + context: The context of the event. + is_guest (bool): Whether the sender is a guest. + room_hosts ([str]): Homeservers which are likely to already be in + the room, and could be danced with in order to join this + homeserver for the first time. + ratelimit (bool): Whether to rate limit this request. + Raises: + SynapseError if there was a problem changing the membership. + """ + remote_room_hosts = remote_room_hosts or [] + + target_user = UserID.from_string(event.state_key) + room_id = event.room_id + + if requester is not None: + sender = UserID.from_string(event.sender) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % + (sender, requester.user) + ) + assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) + else: + requester = synapse.types.create_requester(target_user) + + message_handler = self.hs.get_handlers().message_handler + prev_event = yield message_handler.deduplicate_state_event(event, context) + if prev_event is not None: + return + + if event.membership == Membership.JOIN: + if requester.is_guest: + guest_can_join = yield self._can_guest_join(context.prev_state_ids) + if not guest_can_join: + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + + yield message_handler.handle_new_client_event( + requester, + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) + + prev_member_event_id = context.prev_state_ids.get( + (EventTypes.Member, event.state_key), + None + ) + + if event.membership == Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: + yield user_joined_room(self.distributor, target_user, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) + + @defer.inlineCallbacks + def _can_guest_join(self, current_state_ids): + """ + Returns whether a guest can join a room based on its current state. + """ + guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + if not guest_access_id: + defer.returnValue(False) + + guest_access = yield self.store.get_event(guest_access_id) + + defer.returnValue( + guest_access + and guest_access.content + and "guest_access" in guest_access.content + and guest_access.content["guest_access"] == "can_join" + ) + + @defer.inlineCallbacks + def lookup_room_alias(self, room_alias): + """ + Get the room ID associated with a room alias. + + Args: + room_alias (RoomAlias): The alias to look up. + Returns: + A tuple of: + The room ID as a RoomID object. + Hosts likely to be participating in the room ([str]). + Raises: + SynapseError if room alias could not be found. + """ + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) + + if not mapping: + raise SynapseError(404, "No such room alias") + + room_id = mapping["room_id"] + servers = mapping["servers"] + + defer.returnValue((RoomID.from_string(room_id), servers)) + + @defer.inlineCallbacks + def get_inviter(self, user_id, room_id): + invite = yield self.store.get_invite_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + if invite: + defer.returnValue(UserID.from_string(invite.sender)) + + @defer.inlineCallbacks + def do_3pid_invite( + self, + room_id, + inviter, + medium, + address, + id_server, + requester, + txn_id + ): + invitee = yield self._lookup_3pid( + id_server, medium, address + ) + + if invitee: + yield self.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", + txn_id=txn_id, + ) + else: + yield self._make_and_store_3pid_invite( + requester, + id_server, + medium, + address, + room_id, + inviter, + txn_id=txn_id + ) + + @defer.inlineCallbacks + def _lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + str: the matrix ID of the 3pid, or None if it is not recognized. + """ + try: + data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), + { + "medium": medium, + "address": address, + } + ) + + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + self.verify_any_signature(data, id_server) + defer.returnValue(data["mxid"]) + + except IOError as e: + logger.warn("Error from identity server lookup: %s" % (e,)) + defer.returnValue(None) + + @defer.inlineCallbacks + def verify_any_signature(self, data, server_hostname): + if server_hostname not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (server_hostname,)) + for key_name, signature in data["signatures"][server_hostname].items(): + key_data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/pubkey/%s" % + (id_server_scheme, server_hostname, key_name,), + ) + if "public_key" not in key_data: + raise AuthError(401, "No public key named %s from %s" % + (key_name, server_hostname,)) + verify_signed_json( + data, + server_hostname, + decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + ) + return + + @defer.inlineCallbacks + def _make_and_store_3pid_invite( + self, + requester, + id_server, + medium, + address, + room_id, + user, + txn_id + ): + room_state = yield self.hs.get_state_handler().get_current_state(room_id) + + inviter_display_name = "" + inviter_avatar_url = "" + member_event = room_state.get((EventTypes.Member, user.to_string())) + if member_event: + inviter_display_name = member_event.content.get("displayname", "") + inviter_avatar_url = member_event.content.get("avatar_url", "") + + canonical_room_alias = "" + canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) + if canonical_alias_event: + canonical_room_alias = canonical_alias_event.content.get("alias", "") + + room_name = "" + room_name_event = room_state.get((EventTypes.Name, "")) + if room_name_event: + room_name = room_name_event.content.get("name", "") + + room_join_rules = "" + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + room_join_rules = join_rules_event.content.get("join_rule", "") + + room_avatar_url = "" + room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if room_avatar_event: + room_avatar_url = room_avatar_event.content.get("url", "") + + token, public_keys, fallback_public_key, display_name = ( + yield self._ask_id_server_for_third_party_invite( + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url + ) + ) + + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.ThirdPartyInvite, + "content": { + "display_name": display_name, + "public_keys": public_keys, + + # For backwards compatibility: + "key_validity_url": fallback_public_key["key_validity_url"], + "public_key": fallback_public_key["public_key"], + }, + "room_id": room_id, + "sender": user.to_string(), + "state_key": token, + }, + txn_id=txn_id, + ) + + @defer.inlineCallbacks + def _ask_id_server_for_third_party_invite( + self, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url + ): + """ + Asks an identity server for a third party invite. + + Args: + id_server (str): hostname + optional port for the identity server. + medium (str): The literal string "email". + address (str): The third party address being invited. + room_id (str): The ID of the room to which the user is invited. + inviter_user_id (str): The user ID of the inviter. + room_alias (str): An alias for the room, for cosmetic notifications. + room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + room_join_rules (str): The join rules of the email (e.g. "public"). + room_name (str): The m.room.name of the room. + inviter_display_name (str): The current display name of the + inviter. + inviter_avatar_url (str): The URL of the inviter's avatar. + + Returns: + A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. + """ + + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( + id_server_scheme, id_server, + ) + + invite_config = { + "medium": medium, + "address": address, + "room_id": room_id, + "room_alias": room_alias, + "room_avatar_url": room_avatar_url, + "room_join_rules": room_join_rules, + "room_name": room_name, + "sender": inviter_user_id, + "sender_display_name": inviter_display_name, + "sender_avatar_url": inviter_avatar_url, + } + + if self.hs.config.invite_3pid_guest: + registration_handler = self.hs.get_handlers().registration_handler + guest_access_token = yield registration_handler.guest_access_token_for( + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) + + guest_user_info = yield self.hs.get_auth().get_user_by_access_token( + guest_access_token + ) + + invite_config.update({ + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_info["user"].to_string(), + }) + + data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + is_url, + invite_config + ) + # TODO: Check for success + token = data["token"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, id_server, + ), + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) + display_name = data["display_name"] + defer.returnValue((token, public_keys, fallback_public_key, display_name)) + + @defer.inlineCallbacks + def forget(self, user, room_id): + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership != Membership.LEAVE: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, room_id) + + @defer.inlineCallbacks + def _is_host_in_room(self, current_state_ids): + # Have we just created the room, and is this about to be the very + # first member event? + create_event_id = current_state_ids.get(("m.room.create", "")) + if len(current_state_ids) == 1 and create_event_id: + defer.returnValue(self.hs.is_mine_id(create_event_id)) + + for (etype, state_key), event_id in current_state_ids.items(): + if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if not event: + continue + + if event.membership == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9937d8dd7f..df75d70fac 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -21,6 +21,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.api.filtering import Filter from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event +from synapse.visibility import filter_events_for_client from unpaddedbase64 import decode_base64, encode_base64 @@ -172,8 +173,8 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) - events = yield self._filter_events_for_client( - user.to_string(), filtered_events + events = yield filter_events_for_client( + self.store, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -223,8 +224,8 @@ class SearchHandler(BaseHandler): r["event"] for r in results ]) - events = yield self._filter_events_for_client( - user.to_string(), filtered_events + events = yield filter_events_for_client( + self.store, user.to_string(), filtered_events ) room_events.extend(events) @@ -281,12 +282,12 @@ class SearchHandler(BaseHandler): event.room_id, event.event_id, before_limit, after_limit ) - res["events_before"] = yield self._filter_events_for_client( - user.to_string(), res["events_before"] + res["events_before"] = yield filter_events_for_client( + self.store, user.to_string(), res["events_before"] ) - res["events_after"] = yield self._filter_events_for_client( - user.to_string(), res["events_after"] + res["events_after"] = yield filter_events_for_client( + self.store, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1f6fde8e8a..a86996689c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseHandler - -from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes -from synapse.util import unwrapFirstError -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.async import concurrently_execute +from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure +from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user +from synapse.visibility import filter_events_for_client from twisted.internet import defer @@ -35,6 +34,8 @@ SyncConfig = collections.namedtuple("SyncConfig", [ "user", "filter_collection", "is_guest", + "request_key", + "device_id", ]) @@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. ])): __slots__ = [] @@ -126,18 +128,22 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.joined or self.invited or self.archived or - self.account_data + self.account_data or + self.to_device ) -class SyncHandler(BaseHandler): +class SyncHandler(object): def __init__(self, hs): - super(SyncHandler, self).__init__(hs) + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() + self.response_cache = ResponseCache(hs) + self.state = hs.get_state_handler() - @defer.inlineCallbacks def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): """Get the sync for a client if we have new data for it now. Otherwise @@ -146,7 +152,19 @@ class SyncHandler(BaseHandler): Returns: A Deferred SyncResult. """ + result = self.response_cache.get(sync_config.request_key) + if not result: + result = self.response_cache.set( + sync_config.request_key, + self._wait_for_sync_for_user( + sync_config, since_token, timeout, full_state + ) + ) + return result + @defer.inlineCallbacks + def _wait_for_sync_for_user(self, sync_config, since_token, timeout, + full_state): context = LoggingContext.current_context() if context: if since_token is None: @@ -179,197 +197,15 @@ class SyncHandler(BaseHandler): Returns: A Deferred SyncResult. """ - if since_token is None or full_state: - return self.full_state_sync(sync_config, since_token) - else: - return self.incremental_sync_with_gap(sync_config, since_token) - - @defer.inlineCallbacks - def full_state_sync(self, sync_config, timeline_since_token): - """Get a sync for a client which is starting without any state. - - If a 'message_since_token' is given, only timeline events which have - happened since that token will be returned. - - Returns: - A Deferred SyncResult. - """ - now_token = yield self.event_sources.get_current_token() - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token - ) - - presence_stream = self.event_sources.sources["presence"] - # TODO (mjark): This looks wrong, shouldn't we be getting the presence - # UP to the present rather than after the present? - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user=sync_config.user, - pagination_config=pagination_config.get_source_config("presence"), - key=None - ) - - membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN - ) - - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=sync_config.user.to_string(), - membership_list=membership_list - ) - - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) - ) - - account_data['m.push_rules'] = yield self.push_rules_for_user( - sync_config.user - ) - - tags_by_room = yield self.store.get_tags_for_user( - sync_config.user.to_string() - ) - - joined = [] - invited = [] - archived = [] - deferreds = [] - - room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)] - for room_list_chunk in room_list_chunks: - for event in room_list_chunk: - if event.membership == Membership.JOIN: - room_sync_deferred = preserve_fn( - self.full_state_sync_for_joined_room - )( - room_id=event.room_id, - sync_config=sync_config, - now_token=now_token, - timeline_since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(joined.append) - deferreds.append(room_sync_deferred) - elif event.membership == Membership.INVITE: - invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) - elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if sync_config.user.to_string() == event.sender: - continue - - leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) - ) - room_sync_deferred = preserve_fn( - self.full_state_sync_for_archived_room - )( - sync_config=sync_config, - room_id=event.room_id, - leave_event_id=event.event_id, - leave_token=leave_token, - timeline_since_token=timeline_since_token, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(archived.append) - deferreds.append(room_sync_deferred) - - yield defer.gatherResults( - deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - - @defer.inlineCallbacks - def full_state_sync_for_joined_room(self, room_id, sync_config, - now_token, timeline_since_token, - ephemeral_by_room, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred JoinedSyncResult. - """ - - batch = yield self.load_filtered_recents( - room_id, sync_config, now_token, since_token=timeline_since_token - ) - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id, sync_config, - now_token=now_token, - since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=True, - ) - - defer.returnValue(room_sync) + return self.generate_sync_result(sync_config, since_token, full_state) @defer.inlineCallbacks def push_rules_for_user(self, user): user_id = user.to_string() - rawrules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - rules = format_push_rules_for_user(user, rawrules, enabled_map) + rules = yield self.store.get_push_rules_for_user(user_id) + rules = format_push_rules_for_user(user, rules) defer.returnValue(rules) - def account_data_for_user(self, account_data): - account_data_events = [] - - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - - def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): - account_data_events = [] - tags = tags_by_room.get(room_id) - if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): """Get the ephemeral events for each room the user is in @@ -432,255 +268,45 @@ class SyncHandler(BaseHandler): defer.returnValue((now_token, ephemeral_by_room)) - def full_state_sync_for_archived_room(self, room_id, sync_config, - leave_event_id, leave_token, - timeline_since_token, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred ArchivedSyncResult. - """ - - return self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room, - account_data_by_room, full_state=True, leave_token=leave_token, - ) - @defer.inlineCallbacks - def incremental_sync_with_gap(self, sync_config, since_token): - """ Get the incremental delta needed to bring the client up to - date with the server. - Returns: - A Deferred SyncResult. + def _load_filtered_recents(self, room_id, sync_config, now_token, + since_token=None, recents=None, newly_joined_room=False): """ - now_token = yield self.event_sources.get_current_token() - - rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) - room_ids = [room.room_id for room in rooms] - - presence_source = self.event_sources.sources["presence"] - presence, presence_key = yield presence_source.get_new_events( - user=sync_config.user, - from_key=since_token.presence_key, - limit=sync_config.filter_collection.presence_limit(), - room_ids=room_ids, - is_guest=sync_config.is_guest, - ) - now_token = now_token.copy_and_replace("presence_key", presence_key) - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token, since_token - ) - - rm_handler = self.hs.get_handlers().room_member_handler - app_service = yield self.store.get_app_service_by_user_id( - sync_config.user.to_string() - ) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) - else: - joined_room_ids = yield rm_handler.get_joined_rooms_for_user( - sync_config.user - ) - - user_id = sync_config.user.to_string() - - timeline_limit = sync_config.filter_collection.timeline_limit() - - tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, - ) - - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, - ) - ) - - push_rules_changed = yield self.store.have_push_rules_changed_for_user( - user_id, int(since_token.push_rules_key) - ) - - if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( - sync_config.user - ) - - # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key - ) - - mem_change_events_by_room_id = {} - for event in rooms_changed: - mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - - newly_joined_rooms = [] - archived = [] - invited = [] - for room_id, events in mem_change_events_by_room_id.items(): - non_joins = [e for e in events if e.membership != Membership.JOIN] - has_join = len(non_joins) != len(events) - - # We want to figure out if we joined the room at some point since - # the last sync (even if we have since left). This is to make sure - # we do send down the room, and with full state, where necessary - if room_id in joined_room_ids or has_join: - old_state = yield self.get_state_at(room_id, since_token) - old_mem_ev = old_state.get((EventTypes.Member, user_id), None) - if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: - newly_joined_rooms.append(room_id) - - if room_id in joined_room_ids: - continue - - if not non_joins: - continue - - # Only bother if we're still currently invited - should_invite = non_joins[-1].membership == Membership.INVITE - if should_invite: - room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) - if room_sync: - invited.append(room_sync) - - # Always include leave/ban events. Just take the last one. - # TODO: How do we handle ban -> leave in same batch? - leave_events = [ - e for e in non_joins - if e.membership in (Membership.LEAVE, Membership.BAN) - ] - - if leave_events: - leave_event = leave_events[-1] - room_sync = yield self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event.event_id, since_token, - tags_by_room, account_data_by_room, - full_state=room_id in newly_joined_rooms - ) - if room_sync: - archived.append(room_sync) - - # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( - room_ids=joined_room_ids, - from_key=since_token.room_key, - to_key=now_token.room_key, - limit=timeline_limit + 1, - ) - - joined = [] - # We loop through all room ids, even if there are no new events, in case - # there are non room events taht we need to notify about. - for room_id in joined_room_ids: - room_entry = room_to_events.get(room_id, None) - - if room_entry: - events, start_key = room_entry - - prev_batch_token = now_token.copy_and_replace("room_key", start_key) - - newly_joined_room = room_id in newly_joined_rooms - full_state = newly_joined_room - - batch = yield self.load_filtered_recents( - room_id, sync_config, prev_batch_token, - since_token=since_token, - recents=events, - newly_joined_room=newly_joined_room, - ) - else: - batch = TimelineBatch( - events=[], - prev_batch=since_token, - limited=False, - ) - full_state = False - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id=room_id, - sync_config=sync_config, - since_token=since_token, - now_token=now_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=full_state, - ) - if room_sync: - joined.append(room_sync) - - # For each newly joined room, we want to send down presence of - # existing users. - presence_handler = self.hs.get_handlers().presence_handler - extra_presence_users = set() - for room_id in newly_joined_rooms: - users = yield self.store.get_users_in_room(event.room_id) - extra_presence_users.update(users) - - # For each new member, send down presence. - for joined_sync in joined: - it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) - for event in it: - if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - extra_presence_users.add(event.state_key) - - states = yield presence_handler.get_states( - [u for u in extra_presence_users if u != user_id], - as_event=True, - ) - presence.extend(states) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - - @defer.inlineCallbacks - def load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): - """ - :returns a Deferred TimelineBatch + Returns: + a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): - filtering_factor = 2 timeline_limit = sync_config.filter_collection.timeline_limit() - load_limit = max(timeline_limit * filtering_factor, 10) - max_repeat = 5 # Only try a few times per room, otherwise - room_key = now_token.room_key - end_key = room_key + block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True else: limited = False - if recents is not None: + if recents: recents = sync_config.filter_collection.filter_room_timeline(recents) - recents = yield self._filter_events_for_client( + recents = yield filter_events_for_client( + self.store, sync_config.user.to_string(), recents, ) else: recents = [] + if not limited or block_all_timeline: + defer.returnValue(TimelineBatch( + events=recents, + prev_batch=now_token, + limited=False + )) + + filtering_factor = 2 + load_limit = max(timeline_limit * filtering_factor, 10) + max_repeat = 5 # Only try a few times per room, otherwise + room_key = now_token.room_key + end_key = room_key + since_key = None if since_token and not newly_joined_room: since_key = since_token.room_key @@ -695,7 +321,8 @@ class SyncHandler(BaseHandler): loaded_recents = sync_config.filter_collection.filter_room_timeline( events ) - loaded_recents = yield self._filter_events_for_client( + loaded_recents = yield filter_events_for_client( + self.store, sync_config.user.to_string(), loaded_recents, ) @@ -723,122 +350,32 @@ class SyncHandler(BaseHandler): )) @defer.inlineCallbacks - def incremental_sync_with_gap_for_room(self, room_id, sync_config, - since_token, now_token, - ephemeral_by_room, tags_by_room, - account_data_by_room, - batch, full_state=False): - state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - ephemeral = sync_config.filter_collection.filter_room_ephemeral( - ephemeral_by_room.get(room_id, []) - ) - - unread_notifications = {} - room_sync = JoinedSyncResult( - room_id=room_id, - timeline=batch, - state=state, - ephemeral=ephemeral, - account_data=account_data, - unread_notifications=unread_notifications, - ) - - if room_sync: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) - - if notifs is not None: - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks - def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id, - since_token, tags_by_room, - account_data_by_room, full_state, - leave_token=None): - """ Get the incremental delta needed to bring the client up to date for - the archived room. - Returns: - A Deferred ArchivedSyncResult - """ - - if not leave_token: - stream_token = yield self.store.get_stream_token_for_event( - leave_event_id - ) - - leave_token = since_token.copy_and_replace("room_key", stream_token) - - if since_token and since_token.is_after(leave_token): - defer.returnValue(None) - - batch = yield self.load_filtered_recents( - room_id, sync_config, leave_token, since_token, - ) - - logger.debug("Recents %r", batch) - - state_events_delta = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, leave_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - room_sync = ArchivedSyncResult( - room_id=room_id, - timeline=batch, - state=state_events_delta, - account_data=account_data, - ) - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks def get_state_after_event(self, event): """ Get the room state after the given event - :param synapse.events.EventBase event: event of interest - :return: A Deferred map from ((type, state_key)->Event) + Args: + event(synapse.events.EventBase): event of interest + + Returns: + A Deferred map from ((type, state_key)->Event) """ - state = yield self.store.get_state_for_event(event.event_id) + state_ids = yield self.store.get_state_ids_for_event(event.event_id) if event.is_state(): - state = state.copy() - state[(event.type, event.state_key)] = event - defer.returnValue(state) + state_ids = state_ids.copy() + state_ids[(event.type, event.state_key)] = event.event_id + defer.returnValue(state_ids) @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): """ Get the room state at a particular stream position - :param str room_id: room for which to get state - :param StreamToken stream_position: point at which to get state - :returns: A Deferred map from ((type, state_key)->Event) + + Args: + room_id(str): room for which to get state + stream_position(StreamToken): point at which to get state + + Returns: + A Deferred map from ((type, state_key)->Event) """ last_events, token = yield self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1, @@ -859,15 +396,18 @@ class SyncHandler(BaseHandler): """ Works out the differnce in state between the start of the timeline and the previous sync. - :param str room_id - :param TimelineBatch batch: The timeline batch for the room that will - be sent to the user. - :param sync_config - :param str since_token: Token of the end of the previous batch. May be None. - :param str now_token: Token of the end of the current batch. - :param bool full_state: Whether to force returning the full state. + Args: + room_id(str): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + sync_config(synapse.handlers.sync.SyncConfig): + since_token(str|None): Token of the end of the previous batch. May + be None. + now_token(str): Token of the end of the current batch. + full_state(bool): Whether to force returning the full state. - :returns A new event dictionary + Returns: + A deferred new event dictionary """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -877,80 +417,66 @@ class SyncHandler(BaseHandler): with Measure(self.clock, "compute_state_delta"): if full_state: if batch: - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state = yield self.store.get_state_for_event( + state_ids = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) else: - current_state = yield self.get_state_at( + current_state_ids = yield self.get_state_at( room_id, stream_position=now_token ) - state = current_state + state_ids = current_state_ids timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, - timeline_start=state, + timeline_start=state_ids, previous={}, - current=current_state, + current=current_state_ids, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token ) - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state_at_timeline_start = yield self.store.get_state_for_event( + state_at_timeline_start = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, - current=current_state, + current=current_state_ids, ) else: - state = {} + state_ids = {} - defer.returnValue({ - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(state.values()) - }) - - def check_joined_room(self, sync_config, state_delta): - """ - Check if the user has just joined the given room (so should - be given the full state) + state = {} + if state_ids: + state = yield self.store.get_events(state_ids.values()) - :param sync_config: - :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the - difference in state since the last sync - - :returns A deferred Tuple (state_delta, limited) - """ - join_event = state_delta.get(( - EventTypes.Member, sync_config.user.to_string()), None) - if join_event is not None: - if join_event.content["membership"] == Membership.JOIN: - return True - return False + defer.returnValue({ + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state(state.values()) + }) @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -968,9 +494,613 @@ class SyncHandler(BaseHandler): ) defer.returnValue(notifs) - # There is no new information in this period, so your notification - # count is whatever it was last time. - defer.returnValue(None) + # There is no new information in this period, so your notification + # count is whatever it was last time. + defer.returnValue(None) + + @defer.inlineCallbacks + def generate_sync_result(self, sync_config, since_token=None, full_state=False): + """Generates a sync result. + + Args: + sync_config (SyncConfig) + since_token (StreamToken) + full_state (bool) + + Returns: + Deferred(SyncResult) + """ + + # NB: The now_token gets changed by some of the generate_sync_* methods, + # this is due to some of the underlying streams not supporting the ability + # to query up to a given point. + # Always use the `now_token` in `SyncResultBuilder` + now_token = yield self.event_sources.get_current_token() + + sync_result_builder = SyncResultBuilder( + sync_config, full_state, + since_token=since_token, + now_token=now_token, + ) + + account_data_by_room = yield self._generate_sync_entry_for_account_data( + sync_result_builder + ) + + res = yield self._generate_sync_entry_for_rooms( + sync_result_builder, account_data_by_room + ) + newly_joined_rooms, newly_joined_users = res + + block_all_presence_data = ( + since_token is None and + sync_config.filter_collection.blocks_all_presence() + ) + if not block_all_presence_data: + yield self._generate_sync_entry_for_presence( + sync_result_builder, newly_joined_rooms, newly_joined_users + ) + + yield self._generate_sync_entry_for_to_device(sync_result_builder) + + defer.returnValue(SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, + next_batch=sync_result_builder.now_token, + )) + + @defer.inlineCallbacks + def _generate_sync_entry_for_to_device(self, sync_result_builder): + """Generates the portion of the sync response. Populates + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ + user_id = sync_result_builder.sync_config.user.to_string() + device_id = sync_result_builder.sync_config.device_id + now_token = sync_result_builder.now_token + since_stream_id = 0 + if sync_result_builder.since_token is not None: + since_stream_id = int(sync_result_builder.since_token.to_device_key) + + if since_stream_id != int(now_token.to_device_key): + # We only delete messages when a new message comes in, but that's + # fine so long as we delete them at some point. + + logger.debug("Deleting messages up to %d", since_stream_id) + yield self.store.delete_messages_for_device( + user_id, device_id, since_stream_id + ) + + logger.debug("Getting messages up to %d", now_token.to_device_key) + messages, stream_id = yield self.store.get_new_messages_for_device( + user_id, device_id, since_stream_id, now_token.to_device_key + ) + logger.debug("Got messages up to %d: %r", stream_id, messages) + sync_result_builder.now_token = now_token.copy_and_replace( + "to_device_key", stream_id + ) + sync_result_builder.to_device = messages + else: + sync_result_builder.to_device = [] + + @defer.inlineCallbacks + def _generate_sync_entry_for_account_data(self, sync_result_builder): + """Generates the account data portion of the sync response. Populates + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ + sync_config = sync_result_builder.sync_config + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + + if since_token and not sync_result_builder.full_state: + account_data, account_data_by_room = ( + yield self.store.get_updated_account_data_for_user( + user_id, + since_token.account_data_key, + ) + ) + + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + else: + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user( + sync_config.user.to_string() + ) + ) + + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + + account_data_for_user = sync_config.filter_collection.filter_account_data([ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ]) + + sync_result_builder.account_data = account_data_for_user + + defer.returnValue(account_data_by_room) + + @defer.inlineCallbacks + def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, + newly_joined_users): + """Generates the presence portion of the sync response. Populates the + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + newly_joined_rooms(list): List of rooms that the user has joined + since the last sync (or empty if an initial sync) + newly_joined_users(list): List of users that have joined rooms + since the last sync (or empty if an initial sync) + """ + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + user = sync_result_builder.sync_config.user + + presence_source = self.event_sources.sources["presence"] + + since_token = sync_result_builder.since_token + if since_token and not sync_result_builder.full_state: + presence_key = since_token.presence_key + include_offline = True + else: + presence_key = None + include_offline = False + + presence, presence_key = yield presence_source.get_new_events( + user=user, + from_key=presence_key, + is_guest=sync_config.is_guest, + include_offline=include_offline, + ) + sync_result_builder.now_token = now_token.copy_and_replace( + "presence_key", presence_key + ) + + extra_users_ids = set(newly_joined_users) + for room_id in newly_joined_rooms: + users = yield self.state.get_current_user_in_room(room_id) + extra_users_ids.update(users) + extra_users_ids.discard(user.to_string()) + + states = yield self.presence_handler.get_states( + extra_users_ids, + as_event=True, + ) + presence.extend(states) + + # Deduplicate the presence entries so that there's at most one per user + presence = {p["content"]["user_id"]: p for p in presence}.values() + + presence = sync_config.filter_collection.filter_presence( + presence + ) + + sync_result_builder.presence = presence + + @defer.inlineCallbacks + def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + """Generates the rooms portion of the sync response. Populates the + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + account_data_by_room(dict): Dictionary of per room account data + + Returns: + Deferred(tuple): Returns a 2-tuple of + `(newly_joined_rooms, newly_joined_users)` + """ + user_id = sync_result_builder.sync_config.user.to_string() + block_all_room_ephemeral = ( + sync_result_builder.since_token is None and + sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + ) + + if block_all_room_ephemeral: + ephemeral_by_room = {} + else: + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_result_builder.sync_config, + now_token=sync_result_builder.now_token, + since_token=sync_result_builder.since_token, + ) + sync_result_builder.now_token = now_token + + ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + "m.ignored_user_list", user_id=user_id, + ) + + if ignored_account_data: + ignored_users = ignored_account_data.get("ignored_users", {}).keys() + else: + ignored_users = frozenset() + + if sync_result_builder.since_token: + res = yield self._get_rooms_changed(sync_result_builder, ignored_users) + room_entries, invited, newly_joined_rooms = res + + tags_by_room = yield self.store.get_updated_tags( + user_id, + sync_result_builder.since_token.account_data_key, + ) + else: + res = yield self._get_all_rooms(sync_result_builder, ignored_users) + room_entries, invited, newly_joined_rooms = res + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + def handle_room_entries(room_entry): + return self._generate_room_entry( + sync_result_builder, + ignored_users, + room_entry, + ephemeral=ephemeral_by_room.get(room_entry.room_id, []), + tags=tags_by_room.get(room_entry.room_id), + account_data=account_data_by_room.get(room_entry.room_id, {}), + always_include=sync_result_builder.full_state, + ) + + yield concurrently_execute(handle_room_entries, room_entries, 10) + + sync_result_builder.invited.extend(invited) + + # Now we want to get any newly joined users + newly_joined_users = set() + if sync_result_builder.since_token: + for joined_sync in sync_result_builder.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + newly_joined_users.add(event.state_key) + + defer.returnValue((newly_joined_rooms, newly_joined_users)) + + @defer.inlineCallbacks + def _get_rooms_changed(self, sync_result_builder, ignored_users): + """Gets the the changes that have happened since the last sync. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)` + """ + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + assert since_token + + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service: + rooms = yield self.store.get_app_service_rooms(app_service) + joined_room_ids = set(r.room_id for r in rooms) + else: + rooms = yield self.store.get_rooms_for_user(user_id) + joined_room_ids = set(r.room_id for r in rooms) + + # Get a list of membership change events that have happened. + rooms_changed = yield self.store.get_membership_changes_for_user( + user_id, since_token.room_key, now_token.room_key + ) + + mem_change_events_by_room_id = {} + for event in rooms_changed: + mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) + + newly_joined_rooms = [] + room_entries = [] + invited = [] + for room_id, events in mem_change_events_by_room_id.items(): + non_joins = [e for e in events if e.membership != Membership.JOIN] + has_join = len(non_joins) != len(events) + + # We want to figure out if we joined the room at some point since + # the last sync (even if we have since left). This is to make sure + # we do send down the room, and with full state, where necessary + if room_id in joined_room_ids or has_join: + old_state_ids = yield self.get_state_at(room_id, since_token) + old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) + old_mem_ev = None + if old_mem_ev_id: + old_mem_ev = yield self.store.get_event( + old_mem_ev_id, allow_none=True + ) + if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: + newly_joined_rooms.append(room_id) + + if room_id in joined_room_ids: + continue + + if not non_joins: + continue + + # Only bother if we're still currently invited + should_invite = non_joins[-1].membership == Membership.INVITE + if should_invite: + if event.sender not in ignored_users: + room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if room_sync: + invited.append(room_sync) + + # Always include leave/ban events. Just take the last one. + # TODO: How do we handle ban -> leave in same batch? + leave_events = [ + e for e in non_joins + if e.membership in (Membership.LEAVE, Membership.BAN) + ] + + if leave_events: + leave_event = leave_events[-1] + leave_stream_token = yield self.store.get_stream_token_for_event( + leave_event.event_id + ) + leave_token = since_token.copy_and_replace( + "room_key", leave_stream_token + ) + + if since_token and since_token.is_after(leave_token): + continue + + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="archived", + events=None, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + )) + + timeline_limit = sync_config.filter_collection.timeline_limit() + + # Get all events for rooms we're currently joined to. + room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_ids=joined_room_ids, + from_key=since_token.room_key, + to_key=now_token.room_key, + limit=timeline_limit + 1, + ) + + # We loop through all room ids, even if there are no new events, in case + # there are non room events taht we need to notify about. + for room_id in joined_room_ids: + room_entry = room_to_events.get(room_id, None) + + if room_entry: + events, start_key = room_entry + + prev_batch_token = now_token.copy_and_replace("room_key", start_key) + + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="joined", + events=events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=None if room_id in newly_joined_rooms else since_token, + upto_token=prev_batch_token, + )) + else: + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="joined", + events=[], + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=since_token, + )) + + defer.returnValue((room_entries, invited, newly_joined_rooms)) + + @defer.inlineCallbacks + def _get_all_rooms(self, sync_result_builder, ignored_users): + """Returns entries for all rooms for the user. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], [])` + """ + + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + membership_list = ( + Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + ) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, + membership_list=membership_list + ) + + room_entries = [] + invited = [] + + for event in room_list: + if event.membership == Membership.JOIN: + room_entries.append(RoomSyncResultBuilder( + room_id=event.room_id, + rtype="joined", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + )) + elif event.membership == Membership.INVITE: + if event.sender in ignored_users: + continue + invite = yield self.store.get_event(event.event_id) + invited.append(InvitedSyncResult( + room_id=event.room_id, + invite=invite, + )) + elif event.membership in (Membership.LEAVE, Membership.BAN): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + continue + + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + room_entries.append(RoomSyncResultBuilder( + room_id=event.room_id, + rtype="archived", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + )) + + defer.returnValue((room_entries, invited, [])) + + @defer.inlineCallbacks + def _generate_room_entry(self, sync_result_builder, ignored_users, + room_builder, ephemeral, tags, account_data, + always_include=False): + """Populates the `joined` and `archived` section of `sync_result_builder` + based on the `room_builder`. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + room_builder(RoomSyncResultBuilder) + ephemeral(list): List of new ephemeral events for room + tags(list): List of *all* tags for room, or None if there has been + no change. + account_data(list): List of new account data for room + always_include(bool): Always include this room in the sync response, + even if empty. + """ + newly_joined = room_builder.newly_joined + full_state = ( + room_builder.full_state + or newly_joined + or sync_result_builder.full_state + ) + events = room_builder.events + + # We want to shortcut out as early as possible. + if not (always_include or account_data or ephemeral or full_state): + if events == [] and tags is None: + return + + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + room_id = room_builder.room_id + since_token = room_builder.since_token + upto_token = room_builder.upto_token + + batch = yield self._load_filtered_recents( + room_id, sync_config, + now_token=upto_token, + since_token=since_token, + recents=events, + newly_joined_room=newly_joined, + ) + + account_data_events = [] + if tags is not None: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data_events + ) + + ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) + + if not (always_include or batch or account_data or ephemeral or full_state): + return + + state = yield self.compute_state_delta( + room_id, batch, sync_config, since_token, now_token, + full_state=full_state + ) + + if room_builder.rtype == "joined": + unread_notifications = {} + room_sync = JoinedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + ephemeral=ephemeral, + account_data=account_data_events, + unread_notifications=unread_notifications, + ) + + if room_sync or always_include: + notifs = yield self.unread_notifs_for_room_id( + room_id, sync_config + ) + + if notifs is not None: + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = notifs["highlight_count"] + + sync_result_builder.joined.append(room_sync) + elif room_builder.rtype == "archived": + room_sync = ArchivedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + account_data=account_data, + ) + if room_sync or always_include: + sync_result_builder.archived.append(room_sync) + else: + raise Exception("Unrecognized rtype: %r", room_builder.rtype) def _action_has_highlight(actions): @@ -997,25 +1127,72 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): Returns: dict """ - event_id_to_state = { - e.event_id: e - for e in itertools.chain( - timeline_contains.values(), - previous.values(), - timeline_start.values(), - current.values(), + event_id_to_key = { + e: key + for key, e in itertools.chain( + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(e.event_id for e in current.values()) - tc_ids = set(e.event_id for e in timeline_contains.values()) - p_ids = set(e.event_id for e in previous.values()) - ts_ids = set(e.event_id for e in timeline_start.values()) + c_ids = set(e for e in current.values()) + tc_ids = set(e for e in timeline_contains.values()) + p_ids = set(e for e in previous.values()) + ts_ids = set(e for e in timeline_start.values()) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - evs = (event_id_to_state[e] for e in state_ids) return { - (e.type, e.state_key): e - for e in evs + event_id_to_key[e]: e for e in state_ids } + + +class SyncResultBuilder(object): + "Used to help build up a new SyncResult for a user" + def __init__(self, sync_config, full_state, since_token, now_token): + """ + Args: + sync_config(SyncConfig) + full_state(bool): The full_state flag as specified by user + since_token(StreamToken): The token supplied by user, or None. + now_token(StreamToken): The token to sync up to. + """ + self.sync_config = sync_config + self.full_state = full_state + self.since_token = since_token + self.now_token = now_token + + self.presence = [] + self.account_data = [] + self.joined = [] + self.invited = [] + self.archived = [] + self.device = [] + + +class RoomSyncResultBuilder(object): + """Stores information needed to create either a `JoinedSyncResult` or + `ArchivedSyncResult`. + """ + def __init__(self, room_id, rtype, events, newly_joined, full_state, + since_token, upto_token): + """ + Args: + room_id(str) + rtype(str): One of `"joined"` or `"archived"` + events(list): List of events to include in the room, (more events + may be added when generating result). + newly_joined(bool): If the user has newly joined the room + full_state(bool): Whether the full state should be sent in result + since_token(StreamToken): Earliest point to return events from, or None + upto_token(StreamToken): Latest point to return events from. + """ + self.room_id = room_id + self.rtype = rtype + self.events = events + self.newly_joined = newly_joined + self.full_state = full_state + self.since_token = since_token + self.upto_token = upto_token diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 8ce27f49ec..0eea7f8f9c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,12 +15,11 @@ from twisted.internet import defer -from ._base import BaseHandler - from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_fn from synapse.util.metrics import Measure -from synapse.types import UserID +from synapse.util.wheel_timer import WheelTimer +from synapse.types import UserID, get_domain_from_id import logging @@ -32,25 +31,38 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user")) +RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) + + +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 60 * 1000 + +# How often to resend typing across federation. +FEDERATION_PING_INTERVAL = 40 * 1000 -class TypingNotificationHandler(BaseHandler): +class TypingHandler(object): def __init__(self, hs): - super(TypingNotificationHandler, self).__init__(hs) + self.store = hs.get_datastore() + self.server_name = hs.config.server_name + self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id + self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() - self.homeserver = hs + self.hs = hs self.clock = hs.get_clock() + self.wheel_timer = WheelTimer(bucket_size=5000) - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_sender() - self.federation.register_edu_handler("m.typing", self._recv_edu) + hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_distributor().observe("user_left_room", self.user_left_room) self._member_typing_until = {} # clock time we expect to stop - self._member_typing_timer = {} # deferreds to manage theabove + self._member_last_federation_poke = {} # map room IDs to serial numbers self._room_serials = {} @@ -58,44 +70,78 @@ class TypingNotificationHandler(BaseHandler): # map room IDs to sets of users currently typing self._room_typing = {} - def tearDown(self): - """Cancels all the pending timers. - Normally this shouldn't be needed, but it's required from unit tests - to avoid a "Reactor was unclean" warning.""" - for t in self._member_typing_timer.values(): - self.clock.cancel_call_later(t) + self.clock.looping_call( + self._handle_timeouts, + 5000, + ) + + def _handle_timeouts(self): + logger.info("Checking for typing timeouts") + + now = self.clock.time_msec() + + members = set(self.wheel_timer.fetch(now)) + + for member in members: + if not self.is_typing(member): + # Nothing to do if they're no longer typing + continue + + until = self._member_typing_until.get(member, None) + if not until or until <= now: + logger.info("Timing out typing for: %s", member.user_id) + preserve_fn(self._stopped_typing)(member) + continue + + # Check if we need to resend a keep alive over federation for this + # user. + if self.hs.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: + preserve_fn(self._push_remote)( + member=member, + typing=True + ) + + # Add a paranoia timer to ensure that we always have a timer for + # each person typing. + self.wheel_timer.insert( + now=now, + obj=member, + then=now + 60 * 1000, + ) + + def is_typing(self, member): + return member.user_id in self._room_typing.get(member.room_id, []) @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): - if not self.hs.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has started typing in %s", target_user.to_string(), room_id + "%s has started typing in %s", target_user_id, room_id ) - until = self.clock.time_msec() + timeout - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) - was_present = member in self._member_typing_until + was_present = member.user_id in self._room_typing.get(room_id, set()) - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) - - def _cb(): - logger.debug( - "%s has timed out in %s", target_user.to_string(), room_id - ) - self._stopped_typing(member) + now = self.clock.time_msec() + self._member_typing_until[member] = now + timeout - self._member_typing_until[member] = until - self._member_typing_timer[member] = self.clock.call_later( - timeout / 1000.0, _cb + self.wheel_timer.insert( + now=now, + obj=member, + then=now + timeout, ) if was_present: @@ -103,132 +149,146 @@ class TypingNotificationHandler(BaseHandler): defer.returnValue(None) yield self._push_update( - room_id=room_id, - user=target_user, + member=member, typing=True, ) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): - if not self.hs.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has stopped typing in %s", target_user.to_string(), room_id + "%s has stopped typing in %s", target_user_id, room_id ) - member = RoomMember(room_id=room_id, user=target_user) - - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) - del self._member_typing_timer[member] + member = RoomMember(room_id=room_id, user_id=target_user_id) yield self._stopped_typing(member) @defer.inlineCallbacks def user_left_room(self, user, room_id): - if self.hs.is_mine(user): - member = RoomMember(room_id=room_id, user=user) + user_id = user.to_string() + if self.is_mine_id(user_id): + member = RoomMember(room_id=room_id, user_id=user_id) yield self._stopped_typing(member) @defer.inlineCallbacks def _stopped_typing(self, member): - if member not in self._member_typing_until: + if member.user_id not in self._room_typing.get(member.room_id, set()): # No point defer.returnValue(None) + self._member_typing_until.pop(member, None) + self._member_last_federation_poke.pop(member, None) + yield self._push_update( - room_id=member.room_id, - user=member.user, + member=member, typing=False, ) - del self._member_typing_until[member] - - if member in self._member_typing_timer: - # Don't cancel it - either it already expired, or the real - # stopped_typing() will cancel it - del self._member_typing_timer[member] - @defer.inlineCallbacks - def _push_update(self, room_id, user, typing): - localusers = set() - remotedomains = set() - - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers, remotedomains=remotedomains + def _push_update(self, member, typing): + if self.hs.is_mine_id(member.user_id): + # Only send updates for changes to our own users. + yield self._push_remote(member, typing) + + self._push_update_local( + member=member, + typing=typing ) - if localusers: - self._push_update_local( - room_id=room_id, - user=user, - typing=typing - ) - - deferreds = [] - for domain in remotedomains: - deferreds.append(self.federation.send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": room_id, - "user_id": user.to_string(), - "typing": typing, - }, - )) + @defer.inlineCallbacks + def _push_remote(self, member, typing): + users = yield self.state.get_current_user_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, + ) - yield defer.DeferredList(deferreds, consumeErrors=True) + for domain in set(get_domain_from_id(u) for u in users): + if domain != self.server_name: + self.federation.send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] - user = UserID.from_string(content["user_id"]) + user_id = content["user_id"] - localusers = set() + member = RoomMember(user_id=user_id, room_id=room_id) - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers - ) + # Check that the string is a valid user id + user = UserID.from_string(user_id) - if localusers: + if user.domain != origin: + logger.info( + "Got typing update from %r with bad 'user_id': %r", + origin, user_id, + ) + return + + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) + + if self.server_name in domains: + logger.info("Got typing update from %s: %r", user_id, content) + now = self.clock.time_msec() + self._member_typing_until[member] = now + FEDERATION_TIMEOUT + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_TIMEOUT, + ) self._push_update_local( - room_id=room_id, - user=user, + member=member, typing=content["typing"] ) - def _push_update_local(self, room_id, user, typing): - room_set = self._room_typing.setdefault(room_id, set()) + def _push_update_local(self, member, typing): + room_set = self._room_typing.setdefault(member.room_id, set()) if typing: - room_set.add(user) + room_set.add(member.user_id) else: - room_set.discard(user) + room_set.discard(member.user_id) self._latest_room_serial += 1 - self._room_serials[room_id] = self._latest_room_serial + self._room_serials[member.room_id] = self._latest_room_serial - with PreserveLoggingContext(): - self.notifier.on_new_event( - "typing_key", self._latest_room_serial, rooms=[room_id] - ) + self.notifier.on_new_event( + "typing_key", self._latest_room_serial, rooms=[member.room_id] + ) def get_all_typing_updates(self, last_id, current_id): # TODO: Work out a way to do this without scanning the entire state. + if last_id == current_id: + return [] + rows = [] for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps([ - u.to_string() for u in typing - ], ensure_ascii=False) + typing_bytes = json.dumps(list(typing), ensure_ascii=False) rows.append((serial, room_id, typing_bytes)) rows.sort() return rows @@ -238,34 +298,26 @@ class TypingNotificationEventSource(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() - self._handler = None - self._room_member_handler = None - - def handler(self): - # Avoid cyclic dependency in handler setup - if not self._handler: - self._handler = self.hs.get_handlers().typing_notification_handler - return self._handler - - def room_member_handler(self): - if not self._room_member_handler: - self._room_member_handler = self.hs.get_handlers().room_member_handler - return self._room_member_handler + # We can't call get_typing_handler here because there's a cycle: + # + # Typing -> Notifier -> TypingNotificationEventSource -> Typing + # + self.get_typing_handler = hs.get_typing_handler def _make_event_for(self, room_id): - typing = self.handler()._room_typing[room_id] + typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", "room_id": room_id, "content": { - "user_ids": [u.to_string() for u in typing], + "user_ids": list(typing), }, } def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) - handler = self.handler() + handler = self.get_typing_handler() events = [] for room_id in room_ids: @@ -279,7 +331,7 @@ class TypingNotificationEventSource(object): return events, handler._latest_room_serial def get_current_key(self): - return self.handler()._latest_room_serial + return self.get_typing_handler()._latest_room_serial def get_pagination_rows(self, user, pagination_config, key): return ([], pagination_config.from_key) diff --git a/synapse/http/client.py b/synapse/http/client.py index cbd45b2bbe..3ec9bc7faf 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -15,17 +15,25 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from synapse.api.errors import CodeMessageException +from synapse.api.errors import ( + CodeMessageException, SynapseError, Codes, +) from synapse.util.logcontext import preserve_context_over_fn import synapse.metrics +from synapse.http.endpoint import SpiderEndpoint from canonicaljson import encode_canonical_json -from twisted.internet import defer, reactor, ssl +from twisted.internet import defer, reactor, ssl, protocol, task +from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.web.client import ( - Agent, readBody, FileBodyProducer, PartialDownloadError, + BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, + readBody, PartialDownloadError, ) +from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer +from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers +from twisted.web._newclient import ResponseDone from StringIO import StringIO @@ -238,6 +246,107 @@ class SimpleHttpClient(object): else: raise CodeMessageException(response.code, body) + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. + # The two should be factored out. + + @defer.inlineCallbacks + def get_file(self, url, output_stream, max_size=None): + """GETs a file from a given URL + Args: + url (str): The URL to GET + output_stream (file): File to write the response body to. + Returns: + A (int,dict,string,int) tuple of the file length, dict of the response + headers, absolute URI of the response and HTTP response code. + """ + + response = yield self.request( + "GET", + url.encode("ascii"), + headers=Headers({ + b"User-Agent": [self.user_agent], + }) + ) + + headers = dict(response.headers.getAllRawHeaders()) + + if 'Content-Length' in headers and headers['Content-Length'] > max_size: + logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + raise SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + + if response.code > 299: + logger.warn("Got %d when downloading %s" % (response.code, url)) + raise SynapseError( + 502, + "Got error %d" % (response.code,), + Codes.UNKNOWN, + ) + + # TODO: if our Content-Type is HTML or something, just read the first + # N bytes into RAM rather than saving it all to disk only to read it + # straight back in again + + try: + length = yield preserve_context_over_fn( + _readBodyToFile, + response, output_stream, max_size + ) + except Exception as e: + logger.exception("Failed to download body") + raise SynapseError( + 502, + ("Failed to download remote body: %s" % e), + Codes.UNKNOWN, + ) + + defer.returnValue((length, headers, response.request.absoluteURI, response.code)) + + +# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. +# The two should be factored out. + +class _ReadBodyToFileProtocol(protocol.Protocol): + def __init__(self, stream, deferred, max_size): + self.stream = stream + self.deferred = deferred + self.length = 0 + self.max_size = max_size + + def dataReceived(self, data): + self.stream.write(data) + self.length += len(data) + if self.max_size is not None and self.length >= self.max_size: + self.deferred.errback(SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + )) + self.deferred = defer.Deferred() + self.transport.loseConnection() + + def connectionLost(self, reason): + if reason.check(ResponseDone): + self.deferred.callback(self.length) + elif reason.check(PotentialDataLoss): + # stolen from https://github.com/twisted/treq/pull/49/files + # http://twistedmatrix.com/trac/ticket/4840 + self.deferred.callback(self.length) + else: + self.deferred.errback(reason) + + +# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. +# The two should be factored out. + +def _readBodyToFile(response, stream, max_size): + d = defer.Deferred() + response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) + return d + class CaptchaServerHttpClient(SimpleHttpClient): """ @@ -269,6 +378,60 @@ class CaptchaServerHttpClient(SimpleHttpClient): defer.returnValue(e.response) +class SpiderEndpointFactory(object): + def __init__(self, hs): + self.blacklist = hs.config.url_preview_ip_range_blacklist + self.whitelist = hs.config.url_preview_ip_range_whitelist + self.policyForHTTPS = hs.get_http_client_context_factory() + + def endpointForURI(self, uri): + logger.info("Getting endpoint for %s", uri.toBytes()) + if uri.scheme == "http": + return SpiderEndpoint( + reactor, uri.host, uri.port, self.blacklist, self.whitelist, + endpoint=TCP4ClientEndpoint, + endpoint_kw_args={ + 'timeout': 15 + }, + ) + elif uri.scheme == "https": + tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) + return SpiderEndpoint( + reactor, uri.host, uri.port, self.blacklist, self.whitelist, + endpoint=SSL4ClientEndpoint, + endpoint_kw_args={ + 'sslContextFactory': tlsPolicy, + 'timeout': 15 + }, + ) + else: + logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) + + +class SpiderHttpClient(SimpleHttpClient): + """ + Separate HTTP client for spidering arbitrary URLs. + Special in that it follows retries and has a UA that looks + like a browser. + + used by the preview_url endpoint in the content repo. + """ + def __init__(self, hs): + SimpleHttpClient.__init__(self, hs) + # clobber the base class's agent and UA: + self.agent = ContentDecoderAgent( + BrowserLikeRedirectAgent( + Agent.usingEndpointFactory( + reactor, + SpiderEndpointFactory(hs) + ) + ), [('gzip', GzipDecoder)] + ) + # We could look like Chrome: + # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) + # Chrome Safari" % hs.version_string) + + def encode_urlencode_args(args): return {k: encode_urlencode_arg(v) for k, v in args.items()} @@ -301,5 +464,31 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): self._context = SSL.Context(SSL.SSLv23_METHOD) self._context.set_verify(VERIFY_NONE, lambda *_: None) - def getContext(self, hostname, port): + def getContext(self, hostname=None, port=None): return self._context + + def creatorForNetloc(self, hostname, port): + return self + + +class FileBodyProducer(TwistedFileBodyProducer): + """Workaround for https://twistedmatrix.com/trac/ticket/8473 + + We override the pauseProducing and resumeProducing methods in twisted's + FileBodyProducer so that they do not raise exceptions if the task has + already completed. + """ + + def pauseProducing(self): + try: + super(FileBodyProducer, self).pauseProducing() + except task.TaskDone: + # task has already completed + pass + + def resumeProducing(self): + try: + super(FileBodyProducer, self).resumeProducing() + except task.NotPaused: + # task was not paused (probably because it had already completed) + pass diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4775f6707d..442696d393 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -74,6 +75,41 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, return transport_endpoint(reactor, domain, port, **endpoint_kw_args) +class SpiderEndpoint(object): + """An endpoint which refuses to connect to blacklisted IP addresses + Implements twisted.internet.interfaces.IStreamClientEndpoint. + """ + def __init__(self, reactor, host, port, blacklist, whitelist, + endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): + self.reactor = reactor + self.host = host + self.port = port + self.blacklist = blacklist + self.whitelist = whitelist + self.endpoint = endpoint + self.endpoint_kw_args = endpoint_kw_args + + @defer.inlineCallbacks + def connect(self, protocolFactory): + address = yield self.reactor.resolve(self.host) + + from netaddr import IPAddress + ip_address = IPAddress(address) + + if ip_address in self.blacklist: + if self.whitelist is None or ip_address not in self.whitelist: + raise ConnectError( + "Refusing to spider blacklisted IP address %s" % address + ) + + logger.info("Connecting to %s:%s", address, self.port) + endpoint = self.endpoint( + self.reactor, address, self.port, **self.endpoint_kw_args + ) + connection = yield endpoint.connect(protocolFactory) + defer.returnValue(connection) + + class SRVClientEndpoint(object): """An endpoint which looks up SRV records for a service. Cycles through the list of servers starting with each call to connect @@ -92,7 +128,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -118,7 +155,7 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s", self.service_name + "Not server available for %s" % self.service_name ) min_priority = self.servers[0].priority @@ -153,7 +190,13 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(clock.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -166,34 +209,33 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): and answers[0].type == dns.SRV and answers[0].payload and answers[0].payload.target == dns.Name('.')): - raise ConnectError("Service %s unavailable", service_name) + raise ConnectError("Service %s unavailable" % service_name) for answer in answers: if answer.type != dns.SRV or not answer.payload: continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] - - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) + + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c3589534f8..d5970c05a8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from signedjson.sign import sign_json +import cgi import simplejson as json import logging import random @@ -155,9 +156,7 @@ class MatrixFederationHttpClient(object): time_out=timeout / 1000. if timeout else 60, ) - response = yield preserve_context_over_fn( - send_request, - ) + response = yield preserve_context_over_fn(send_request) log_result = "%d %s" % (response.code, response.phrase,) break @@ -248,7 +247,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def put_json(self, destination, path, data={}, json_data_callback=None, - long_retries=False): + long_retries=False, timeout=None): """ Sends the specifed json data using PUT Args: @@ -261,6 +260,8 @@ class MatrixFederationHttpClient(object): use as the request body. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -287,22 +288,19 @@ class MatrixFederationHttpClient(object): body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=long_retries, + timeout=timeout, ) if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=True): + def post_json(self, destination, path, data={}, long_retries=True, + timeout=None): """ Sends the specifed json data using POST Args: @@ -313,6 +311,8 @@ class MatrixFederationHttpClient(object): the request body. This will be encoded as JSON. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -333,16 +333,12 @@ class MatrixFederationHttpClient(object): body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=True, + timeout=timeout, ) if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) @@ -395,12 +391,7 @@ class MatrixFederationHttpClient(object): if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) @@ -520,3 +511,29 @@ def _flatten_response_never_received(e): ) else: return "%s: %s" % (type(e).__name__, e.message,) + + +def check_content_type_is_json(headers): + """ + Check that a set of HTTP headers have a Content-Type header, and that it + is application/json. + + Args: + headers (twisted.web.http_headers.Headers): headers to check + + Raises: + RuntimeError if the + + """ + c_type = headers.getRawHeaders("Content-Type") + if c_type is None: + raise RuntimeError( + "No Content-Type header" + ) + + c_type = c_type[0] # only the first header + val, options = cgi.parse_header(c_type) + if val != "application/json": + raise RuntimeError( + "Content-Type not application/json: was '%s'" % c_type + ) diff --git a/synapse/http/server.py b/synapse/http/server.py index b17b190ee5..14715878c5 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -18,6 +18,8 @@ from synapse.api.errors import ( cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes ) from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.caches import intern_dict +from synapse.util.metrics import Measure import synapse.metrics import synapse.events @@ -73,7 +75,12 @@ response_db_txn_duration = metrics.register_distribution( _next_request_id = 0 -def request_handler(request_handler): +def request_handler(include_metrics=False): + """Decorator for ``wrap_request_handler``""" + return lambda request_handler: wrap_request_handler(request_handler, include_metrics) + + +def wrap_request_handler(request_handler, include_metrics=False): """Wraps a method that acts as a request handler with the necessary logging and exception handling. @@ -95,43 +102,58 @@ def request_handler(request_handler): global _next_request_id request_id = "%s-%s" % (request.method, _next_request_id) _next_request_id += 1 + with LoggingContext(request_id) as request_context: - request_context.request = request_id - with request.processing(): - try: - with PreserveLoggingContext(request_context): - yield request_handler(self, request) - except CodeMessageException as e: - code = e.code - if isinstance(e, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg + with Measure(self.clock, "wrapped_request_handler"): + request_metrics = RequestMetrics() + request_metrics.start(self.clock, name=self.__class__.__name__) + + request_context.request = request_id + with request.processing(): + try: + with PreserveLoggingContext(request_context): + if include_metrics: + yield request_handler(self, request, request_metrics) + else: + yield request_handler(self, request) + except CodeMessageException as e: + code = e.code + if isinstance(e, SynapseError): + logger.info( + "%s SynapseError: %s - %s", request, code, e.msg + ) + else: + logger.exception(e) + outgoing_responses_counter.inc(request.method, str(code)) + respond_with_json( + request, code, cs_exception(e), send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + version_string=self.version_string, + ) + except: + logger.exception( + "Failed handle request %s.%s on %r: %r", + request_handler.__module__, + request_handler.__name__, + self, + request ) - else: - logger.exception(e) - outgoing_responses_counter.inc(request.method, str(code)) - respond_with_json( - request, code, cs_exception(e), send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - version_string=self.version_string, - ) - except: - logger.exception( - "Failed handle request %s.%s on %r: %r", - request_handler.__module__, - request_handler.__name__, - self, - request - ) - respond_with_json( - request, - 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, - send_cors=True - ) + respond_with_json( + request, + 500, + { + "error": "Internal server error", + "errcode": Codes.UNKNOWN, + }, + send_cors=True + ) + finally: + try: + request_metrics.stop( + self.clock, request + ) + except Exception as e: + logger.warn("Failed to stop metrics: %r", e) return wrapped_request_handler @@ -186,6 +208,7 @@ class JsonResource(HttpServer, resource.Resource): def register_paths(self, method, path_patterns, callback): for path_pattern in path_patterns: + logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( self._PathEntry(path_pattern, callback) ) @@ -196,20 +219,21 @@ class JsonResource(HttpServer, resource.Resource): self._async_render(request) return server.NOT_DONE_YET - @request_handler + # Disable metric reporting because _async_render does its own metrics. + # It does its own metric reporting because _async_render dispatches to + # a callback and it's the class name of that callback we want to report + # against rather than the JsonResource itself. + @request_handler(include_metrics=True) @defer.inlineCallbacks - def _async_render(self, request): + def _async_render(self, request, request_metrics): """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. """ - start = self.clock.time_msec() if request.method == "OPTIONS": self._send_response(request, 200, {}) return - start_context = LoggingContext.current_context() - # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): @@ -223,58 +247,23 @@ class JsonResource(HttpServer, resource.Resource): callback = path_entry.callback - servlet_instance = getattr(callback, "__self__", None) - if servlet_instance is not None: - servlet_classname = servlet_instance.__class__.__name__ - else: - servlet_classname = "%r" % callback - - args = [ - urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups() - ] + kwargs = intern_dict({ + name: urllib.unquote(value).decode("UTF-8") if value else value + for name, value in m.groupdict().items() + }) - callback_return = yield callback(request, *args) + callback_return = yield callback(request, **kwargs) if callback_return is not None: code, response = callback_return self._send_response(request, code, response) - try: - context = LoggingContext.current_context() - - tag = "" - if context: - tag = context.tag - - if context != start_context: - logger.warn( - "Context have unexpectedly changed %r, %r", - context, self.start_context - ) - return - - incoming_requests_counter.inc(request.method, servlet_classname, tag) - - response_timer.inc_by( - self.clock.time_msec() - start, request.method, - servlet_classname, tag - ) - - ru_utime, ru_stime = context.get_resource_usage() + servlet_instance = getattr(callback, "__self__", None) + if servlet_instance is not None: + servlet_classname = servlet_instance.__class__.__name__ + else: + servlet_classname = "%r" % callback - response_ru_utime.inc_by( - ru_utime, request.method, servlet_classname, tag - ) - response_ru_stime.inc_by( - ru_stime, request.method, servlet_classname, tag - ) - response_db_txn_count.inc_by( - context.db_txn_count, request.method, servlet_classname, tag - ) - response_db_txn_duration.inc_by( - context.db_txn_duration, request.method, servlet_classname, tag - ) - except: - pass + request_metrics.name = servlet_classname return @@ -305,6 +294,49 @@ class JsonResource(HttpServer, resource.Resource): ) +class RequestMetrics(object): + def start(self, clock, name): + self.start = clock.time_msec() + self.start_context = LoggingContext.current_context() + self.name = name + + def stop(self, clock, request): + context = LoggingContext.current_context() + + tag = "" + if context: + tag = context.tag + + if context != self.start_context: + logger.warn( + "Context have unexpectedly changed %r, %r", + context, self.start_context + ) + return + + incoming_requests_counter.inc(request.method, self.name, tag) + + response_timer.inc_by( + clock.time_msec() - self.start, request.method, + self.name, tag + ) + + ru_utime, ru_stime = context.get_resource_usage() + + response_ru_utime.inc_by( + ru_utime, request.method, self.name, tag + ) + response_ru_stime.inc_by( + ru_stime, request.method, self.name, tag + ) + response_db_txn_count.inc_by( + context.db_txn_count, request.method, self.name, tag + ) + response_db_txn_duration.inc_by( + context.db_txn_duration, request.method, self.name, tag + ) + + class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" @@ -360,17 +392,30 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) if send_cors: - request.setHeader("Access-Control-Allow-Origin", "*") - request.setHeader("Access-Control-Allow-Methods", - "GET, POST, PUT, DELETE, OPTIONS") - request.setHeader("Access-Control-Allow-Headers", - "Origin, X-Requested-With, Content-Type, Accept") + set_cors_headers(request) request.write(json_bytes) finish_request(request) return NOT_DONE_YET +def set_cors_headers(request): + """Set the CORs headers so that javascript running in a web browsers can + use this API + + Args: + request (twisted.web.http.Request): The http request to add CORs to. + """ + request.setHeader("Access-Control-Allow-Origin", "*") + request.setHeader( + "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS" + ) + request.setHeader( + "Access-Control-Allow-Headers", + "Origin, X-Requested-With, Content-Type, Accept" + ) + + def finish_request(request): """ Finish writing the response to the request. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1c8bd8666f..9346386238 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -26,19 +26,28 @@ logger = logging.getLogger(__name__) def parse_integer(request, name, default=None, required=False): """Parse an integer parameter from the request string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: An int value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (int|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + int|None: An int value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ - if name in request.args: + return parse_integer_from_args(request.args, name, default, required) + + +def parse_integer_from_args(args, name, default=None, required=False): + if name in args: try: - return int(request.args[name][0]) + return int(args[name][0]) except: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message) @@ -53,14 +62,19 @@ def parse_integer(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False): """Parse a boolean parameter from the request query string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: A bool value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (bool|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + bool|None: A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ @@ -88,22 +102,33 @@ def parse_string(request, name, default=None, required=False, allowed_values=None, param_type="string"): """Parse a string parameter from the request query string. - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :param allowed_values (list): List of allowed values for the string, - or None if any value is allowed, defaults to None - :return: A string value or the default. - :raises + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (str|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + allowed_values (list[str]): List of allowed values for the string, + or None if any value is allowed, defaults to None + + Returns: + str|None: A string value or the default. + + Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ + return parse_string_from_args( + request.args, name, default, required, allowed_values, param_type, + ) - if name in request.args: - value = request.args[name][0] + +def parse_string_from_args(args, name, default=None, required=False, + allowed_values=None, param_type="string"): + if name in args: + value = args[name][0] if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values) @@ -122,9 +147,13 @@ def parse_string(request, name, default=None, required=False, def parse_json_value_from_request(request): """Parse a JSON value from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :returns: The JSON value. - :raises + Args: + request: the twisted HTTP request. + + Returns: + The JSON value. + + Raises: SynapseError if the request body couldn't be decoded as JSON. """ try: @@ -143,8 +172,10 @@ def parse_json_value_from_request(request): def parse_json_object_from_request(request): """Parse a JSON object from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :raises + Args: + request: the twisted HTTP request. + + Raises: SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ diff --git a/synapse/http/site.py b/synapse/http/site.py new file mode 100644 index 0000000000..4b09d7ee66 --- /dev/null +++ b/synapse/http/site.py @@ -0,0 +1,146 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.logcontext import LoggingContext +from twisted.web.server import Site, Request + +import contextlib +import logging +import re +import time + +ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') + + +class SynapseRequest(Request): + def __init__(self, site, *args, **kw): + Request.__init__(self, *args, **kw) + self.site = site + self.authenticated_entity = None + self.start_time = 0 + + def __repr__(self): + # We overwrite this so that we don't log ``access_token`` + return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % ( + self.__class__.__name__, + id(self), + self.method, + self.get_redacted_uri(), + self.clientproto, + self.site.site_tag, + ) + + def get_redacted_uri(self): + return ACCESS_TOKEN_RE.sub( + r'\1<redacted>\3', + self.uri + ) + + def get_user_agent(self): + return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1] + + def started_processing(self): + self.site.access_logger.info( + "%s - %s - Received request: %s %s", + self.getClientIP(), + self.site.site_tag, + self.method, + self.get_redacted_uri() + ) + self.start_time = int(time.time() * 1000) + + def finished_processing(self): + + try: + context = LoggingContext.current_context() + ru_utime, ru_stime = context.get_resource_usage() + db_txn_count = context.db_txn_count + db_txn_duration = context.db_txn_duration + except: + ru_utime, ru_stime = (0, 0) + db_txn_count, db_txn_duration = (0, 0) + + self.site.access_logger.info( + "%s - %s - {%s}" + " Processed request: %dms (%dms, %dms) (%dms/%d)" + " %sB %s \"%s %s %s\" \"%s\"", + self.getClientIP(), + self.site.site_tag, + self.authenticated_entity, + int(time.time() * 1000) - self.start_time, + int(ru_utime * 1000), + int(ru_stime * 1000), + int(db_txn_duration * 1000), + int(db_txn_count), + self.sentLength, + self.code, + self.method, + self.get_redacted_uri(), + self.clientproto, + self.get_user_agent(), + ) + + @contextlib.contextmanager + def processing(self): + self.started_processing() + yield + self.finished_processing() + + +class XForwardedForRequest(SynapseRequest): + def __init__(self, *args, **kw): + SynapseRequest.__init__(self, *args, **kw) + + """ + Add a layer on top of another request that only uses the value of an + X-Forwarded-For header as the result of C{getClientIP}. + """ + def getClientIP(self): + """ + @return: The client address (the first address) in the value of the + I{X-Forwarded-For header}. If the header is not present, return + C{b"-"}. + """ + return self.requestHeaders.getRawHeaders( + b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip() + + +class SynapseRequestFactory(object): + def __init__(self, site, x_forwarded_for): + self.site = site + self.x_forwarded_for = x_forwarded_for + + def __call__(self, *args, **kwargs): + if self.x_forwarded_for: + return XForwardedForRequest(self.site, *args, **kwargs) + else: + return SynapseRequest(self.site, *args, **kwargs) + + +class SynapseSite(Site): + """ + Subclass of a twisted http Site that does access logging with python's + standard logging + """ + def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs): + Site.__init__(self, resource, *args, **kwargs) + + self.site_tag = site_tag + + proxied = config.get("x_forwarded", False) + self.requestFactory = SynapseRequestFactory(self, proxied) + self.access_logger = logging.getLogger(logger_name) + + def log(self, request): + pass diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 5664d5a381..2265e6e8d6 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -13,31 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Because otherwise 'resource' collides with synapse.metrics.resource -from __future__ import absolute_import - import logging -from resource import getrusage, RUSAGE_SELF import functools -import os -import stat import time +import gc from twisted.internet import reactor from .metric import ( - CounterMetric, CallbackMetric, DistributionMetric, CacheMetric + CounterMetric, CallbackMetric, DistributionMetric, CacheMetric, + MemoryUsageMetric, ) +from .process_collector import register_process_collector logger = logging.getLogger(__name__) -# We'll keep all the available metrics in a single toplevel dict, one shared -# for the entire process. We don't currently support per-HomeServer instances -# of metrics, because in practice any one python VM will host only one -# HomeServer anyway. This makes a lot of implementation neater -all_metrics = {} +all_metrics = [] +all_collectors = [] class Metrics(object): @@ -48,12 +42,18 @@ class Metrics(object): def __init__(self, name): self.name_prefix = name + def make_subspace(self, name): + return Metrics("%s_%s" % (self.name_prefix, name)) + + def register_collector(self, func): + all_collectors.append(func) + def _register(self, metric_class, name, *args, **kwargs): full_name = "%s_%s" % (self.name_prefix, name) metric = metric_class(full_name, *args, **kwargs) - all_metrics[full_name] = metric + all_metrics.append(metric) return metric def register_counter(self, *args, **kwargs): @@ -69,6 +69,21 @@ class Metrics(object): return self._register(CacheMetric, *args, **kwargs) +def register_memory_metrics(hs): + try: + import psutil + process = psutil.Process() + process.memory_info().rss + except (ImportError, AttributeError): + logger.warn( + "psutil is not installed or incorrect version." + " Disabling memory metrics." + ) + return + metric = MemoryUsageMetric(hs, psutil) + all_metrics.append(metric) + + def get_metrics_for(pkg_name): """ Returns a Metrics instance for conveniently creating metrics namespaced with the given name prefix. """ @@ -81,78 +96,33 @@ def get_metrics_for(pkg_name): def render_all(): strs = [] - # TODO(paul): Internal hack - update_resource_metrics() + for collector in all_collectors: + collector() - for name in sorted(all_metrics.keys()): + for metric in all_metrics: try: - strs += all_metrics[name].render() + strs += metric.render() except Exception: - strs += ["# FAILED to render %s" % name] - logger.exception("Failed to render %s metric", name) + strs += ["# FAILED to render"] + logger.exception("Failed to render metric") strs.append("") # to generate a final CRLF return "\n".join(strs) -# Now register some standard process-wide state metrics, to give indications of -# process resource usage - -rusage = None - - -def update_resource_metrics(): - global rusage - rusage = getrusage(RUSAGE_SELF) - -resource_metrics = get_metrics_for("process.resource") - -# msecs -resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000) -resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000) - -# kilobytes -resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024) - -TYPES = { - stat.S_IFSOCK: "SOCK", - stat.S_IFLNK: "LNK", - stat.S_IFREG: "REG", - stat.S_IFBLK: "BLK", - stat.S_IFDIR: "DIR", - stat.S_IFCHR: "CHR", - stat.S_IFIFO: "FIFO", -} +register_process_collector(get_metrics_for("process")) -def _process_fds(): - counts = {(k,): 0 for k in TYPES.values()} - counts[("other",)] = 0 +python_metrics = get_metrics_for("python") - # Not every OS will have a /proc/self/fd directory - if not os.path.exists("/proc/self/fd"): - return counts - - for fd in os.listdir("/proc/self/fd"): - try: - s = os.stat("/proc/self/fd/%s" % (fd)) - fmt = stat.S_IFMT(s.st_mode) - if fmt in TYPES: - t = TYPES[fmt] - else: - t = "other" - - counts[(t,)] += 1 - except OSError: - # the dirh itself used by listdir() is usually missing by now - pass - - return counts - -get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) +gc_time = python_metrics.register_distribution("gc_time", labels=["gen"]) +gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"]) +python_metrics.register_callback( + "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"] +) -reactor_metrics = get_metrics_for("reactor") +reactor_metrics = get_metrics_for("python.twisted.reactor") tick_time = reactor_metrics.register_distribution("tick_time") pending_calls_metric = reactor_metrics.register_distribution("pending_calls") @@ -182,6 +152,22 @@ def runUntilCurrentTimer(func): end = time.time() * 1000 tick_time.inc_by(end - start) pending_calls_metric.inc_by(num_pending) + + # Check if we need to do a manual GC (since its been disabled), and do + # one if necessary. + threshold = gc.get_threshold() + counts = gc.get_count() + for i in (2, 1, 0): + if threshold[i] < counts[i]: + logger.info("Collecting gc %d", i) + + start = time.time() * 1000 + unreachable = gc.collect(i) + end = time.time() * 1000 + + gc_time.inc_by(end - start, i) + gc_unreachable.inc_by(unreachable, i) + return ret return f @@ -196,5 +182,9 @@ try: # runUntilCurrent is called when we have pending calls. It is called once # per iteratation after fd polling. reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) + + # We manually run the GC each reactor tick so that we can get some metrics + # about time spent doing GC, + gc.disable() except AttributeError: pass diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index 368fc24984..e87b2b80a7 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -47,9 +47,6 @@ class BaseMetric(object): for k, v in zip(self.labels, values)]) ) - def render(self): - return map_concat(self.render_item, sorted(self.counts.keys())) - class CounterMetric(BaseMetric): """The simplest kind of metric; one that stores a monotonically-increasing @@ -83,6 +80,9 @@ class CounterMetric(BaseMetric): def render_item(self, k): return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] + def render(self): + return map_concat(self.render_item, sorted(self.counts.keys())) + class CallbackMetric(BaseMetric): """A metric that returns the numeric value returned by a callback whenever @@ -98,9 +98,9 @@ class CallbackMetric(BaseMetric): value = self.callback() if self.is_scalar(): - return ["%s %d" % (self.name, value)] + return ["%s %.12g" % (self.name, value)] - return ["%s%s %d" % (self.name, self._render_key(k), value[k]) + return ["%s%s %.12g" % (self.name, self._render_key(k), value[k]) for k in sorted(value.keys())] @@ -126,30 +126,70 @@ class DistributionMetric(object): class CacheMetric(object): - """A combination of two CounterMetrics, one to count cache hits and one to - count a total, and a callback metric to yield the current size. - - This metric generates standard metric name pairs, so that monitoring rules - can easily be applied to measure hit ratio.""" + __slots__ = ("name", "cache_name", "hits", "misses", "size_callback") - def __init__(self, name, size_callback, labels=[]): + def __init__(self, name, size_callback, cache_name): self.name = name + self.cache_name = cache_name - self.hits = CounterMetric(name + ":hits", labels=labels) - self.total = CounterMetric(name + ":total", labels=labels) + self.hits = 0 + self.misses = 0 - self.size = CallbackMetric( - name + ":size", - callback=size_callback, - labels=labels, - ) + self.size_callback = size_callback + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def render(self): + size = self.size_callback() + hits = self.hits + total = self.misses + self.hits + + return [ + """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits), + """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), + """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), + ] + + +class MemoryUsageMetric(object): + """Keeps track of the current memory usage, using psutil. + + The class will keep the current min/max/sum/counts of rss over the last + WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second + """ + + UPDATE_HZ = 2 # number of times to get memory per second + WINDOW_SIZE_SEC = 30 # the size of the window in seconds + + def __init__(self, hs, psutil): + clock = hs.get_clock() + self.memory_snapshots = [] + + self.process = psutil.Process() - def inc_hits(self, *values): - self.hits.inc(*values) - self.total.inc(*values) + clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ) - def inc_misses(self, *values): - self.total.inc(*values) + def _update_curr_values(self): + max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC + self.memory_snapshots.append(self.process.memory_info().rss) + self.memory_snapshots[:] = self.memory_snapshots[-max_size:] def render(self): - return self.hits.render() + self.total.render() + self.size.render() + if not self.memory_snapshots: + return [] + + max_rss = max(self.memory_snapshots) + min_rss = min(self.memory_snapshots) + sum_rss = sum(self.memory_snapshots) + len_rss = len(self.memory_snapshots) + + return [ + "process_psutil_rss:max %d" % max_rss, + "process_psutil_rss:min %d" % min_rss, + "process_psutil_rss:total %d" % sum_rss, + "process_psutil_rss:count %d" % len_rss, + ] diff --git a/synapse/metrics/process_collector.py b/synapse/metrics/process_collector.py new file mode 100644 index 0000000000..6fec3de399 --- /dev/null +++ b/synapse/metrics/process_collector.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + + +TICKS_PER_SEC = 100 +BYTES_PER_PAGE = 4096 + +HAVE_PROC_STAT = os.path.exists("/proc/stat") +HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") +HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits") +HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd") + +# Field indexes from /proc/self/stat, taken from the proc(5) manpage +STAT_FIELDS = { + "utime": 14, + "stime": 15, + "starttime": 22, + "vsize": 23, + "rss": 24, +} + + +stats = {} + +# In order to report process_start_time_seconds we need to know the +# machine's boot time, because the value in /proc/self/stat is relative to +# this +boot_time = None +if HAVE_PROC_STAT: + with open("/proc/stat") as _procstat: + for line in _procstat: + if line.startswith("btime "): + boot_time = int(line.split()[1]) + + +def update_resource_metrics(): + if HAVE_PROC_SELF_STAT: + global stats + with open("/proc/self/stat") as s: + line = s.read() + # line is PID (command) more stats go here ... + raw_stats = line.split(") ", 1)[1].split(" ") + + for (name, index) in STAT_FIELDS.iteritems(): + # subtract 3 from the index, because proc(5) is 1-based, and + # we've lost the first two fields in PID and COMMAND above + stats[name] = int(raw_stats[index - 3]) + + +def _count_fds(): + # Not every OS will have a /proc/self/fd directory + if not HAVE_PROC_SELF_FD: + return 0 + + return len(os.listdir("/proc/self/fd")) + + +def register_process_collector(process_metrics): + process_metrics.register_collector(update_resource_metrics) + + if HAVE_PROC_SELF_STAT: + process_metrics.register_callback( + "cpu_user_seconds_total", + lambda: float(stats["utime"]) / TICKS_PER_SEC + ) + process_metrics.register_callback( + "cpu_system_seconds_total", + lambda: float(stats["stime"]) / TICKS_PER_SEC + ) + process_metrics.register_callback( + "cpu_seconds_total", + lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC + ) + + process_metrics.register_callback( + "virtual_memory_bytes", + lambda: int(stats["vsize"]) + ) + process_metrics.register_callback( + "resident_memory_bytes", + lambda: int(stats["rss"]) * BYTES_PER_PAGE + ) + + process_metrics.register_callback( + "start_time_seconds", + lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC + ) + + if HAVE_PROC_SELF_FD: + process_metrics.register_callback( + "open_fds", + lambda: _count_fds() + ) + + if HAVE_PROC_SELF_LIMITS: + def _get_max_fds(): + with open("/proc/self/limits") as limits: + for line in limits: + if not line.startswith("Max open files "): + continue + # Line is Max open files $SOFT $HARD + return int(line.split()[3]) + return None + + process_metrics.register_callback( + "max_fds", + lambda: _get_max_fds() + ) diff --git a/synapse/notifier.py b/synapse/notifier.py index f00cd8c588..054ca59ad2 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -14,13 +14,15 @@ # limitations under the License. from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.util.logutils import log_function from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.metrics import Measure from synapse.types import StreamToken +from synapse.visibility import filter_events_for_client import synapse.metrics from collections import namedtuple @@ -66,10 +68,8 @@ class _NotifierUserStream(object): so that it can remove itself from the indexes in the Notifier class. """ - def __init__(self, user_id, rooms, current_token, time_now_ms, - appservice=None): + def __init__(self, user_id, rooms, current_token, time_now_ms): self.user_id = user_id - self.appservice = appservice self.rooms = set(rooms) self.current_token = current_token self.last_notified_ms = time_now_ms @@ -106,11 +106,6 @@ class _NotifierUserStream(object): notifier.user_to_user_stream.pop(self.user_id) - if self.appservice: - notifier.appservice_to_user_streams.get( - self.appservice, set() - ).discard(self) - def count_listeners(self): return len(self.notify_deferred.observers()) @@ -139,21 +134,22 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs): - self.hs = hs - self.user_to_user_stream = {} self.room_to_user_streams = {} - self.appservice_to_user_streams = {} self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] self.clock = hs.get_clock() + self.appservice_handler = hs.get_application_service_handler() - hs.get_distributor().observe( - "user_joined_room", self._user_joined_room - ) + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + else: + self.federation_sender = None + + self.state_handler = hs.get_state_handler() self.clock.looping_call( self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS @@ -171,8 +167,6 @@ class Notifier(object): all_user_streams |= x for x in self.user_to_user_stream.values(): all_user_streams.add(x) - for x in self.appservice_to_user_streams.values(): - all_user_streams |= x return sum(stream.count_listeners() for stream in all_user_streams) metrics.register_callback("listeners", count_listeners) @@ -185,11 +179,8 @@ class Notifier(object): "users", lambda: len(self.user_to_user_stream), ) - metrics.register_callback( - "appservices", - lambda: count(bool, self.appservice_to_user_streams.values()), - ) + @preserve_fn def on_new_room_event(self, event, room_stream_id, max_room_stream_id, extra_users=[]): """ Used by handlers to inform the notifier something has happened @@ -211,6 +202,7 @@ class Notifier(object): self.notify_replication() + @preserve_fn def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. @@ -228,60 +220,52 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) + @preserve_fn def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - self.hs.get_handlers().appservice_handler.notify_interested_services( - event - ) + self.appservice_handler.notify_interested_services(room_stream_id) - app_streams = set() - - for appservice in self.appservice_to_user_streams: - # TODO (kegan): Redundant appservice listener checks? - # App services will already be in the room_to_user_streams set, but - # that isn't enough. They need to be checked here in order to - # receive *invites* for users they are interested in. Does this - # make the room_to_user_streams check somewhat obselete? - if appservice.is_interested(event): - app_user_streams = self.appservice_to_user_streams.get( - appservice, set() - ) - app_streams |= app_user_streams + if self.federation_sender: + self.federation_sender.notify_new_events(room_stream_id) + + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + self._user_joined_room(event.state_key, event.room_id) self.on_new_event( "room_key", room_stream_id, users=extra_users, rooms=[event.room_id], - extra_streams=app_streams, ) - def on_new_event(self, stream_key, new_token, users=[], rooms=[], - extra_streams=set()): + @preserve_fn + def on_new_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend event wise. Will wake up all listeners for the given users and rooms. """ with PreserveLoggingContext(): - user_streams = set() + with Measure(self.clock, "on_new_event"): + user_streams = set() - for user in users: - user_stream = self.user_to_user_stream.get(str(user)) - if user_stream is not None: - user_streams.add(user_stream) + for user in users: + user_stream = self.user_to_user_stream.get(str(user)) + if user_stream is not None: + user_streams.add(user_stream) - for room in rooms: - user_streams |= self.room_to_user_streams.get(room, set()) + for room in rooms: + user_streams |= self.room_to_user_streams.get(room, set()) - time_now_ms = self.clock.time_msec() - for user_stream in user_streams: - try: - user_stream.notify(stream_key, new_token, time_now_ms) - except: - logger.exception("Failed to notify listener") + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify(stream_key, new_token, time_now_ms) + except: + logger.exception("Failed to notify listener") - self.notify_replication() + self.notify_replication() + @preserve_fn def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" @@ -296,7 +280,6 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - appservice = yield self.store.get_app_service_by_user_id(user_id) current_token = yield self.event_sources.get_current_token() if room_ids is None: rooms = yield self.store.get_rooms_for_user(user_id) @@ -304,7 +287,6 @@ class Notifier(object): user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, - appservice=appservice, current_token=current_token, time_now_ms=self.clock.time_msec(), ) @@ -398,8 +380,8 @@ class Notifier(object): ) if name == "room": - room_member_handler = self.hs.get_handlers().room_member_handler - new_events = yield room_member_handler._filter_events_for_client( + new_events = yield filter_events_for_client( + self.store, user.to_string(), new_events, is_peeking=is_peeking, @@ -448,9 +430,10 @@ class Notifier(object): @defer.inlineCallbacks def _is_world_readable(self, room_id): - state = yield self.hs.get_state_handler().get_current_state( + state = yield self.state_handler.get_current_state( room_id, - EventTypes.RoomHistoryVisibility + EventTypes.RoomHistoryVisibility, + "", ) if state and "history_visibility" in state.content: defer.returnValue(state.content["history_visibility"] == "world_readable") @@ -479,14 +462,8 @@ class Notifier(object): s = self.room_to_user_streams.setdefault(room, set()) s.add(user_stream) - if user_stream.appservice: - self.appservice_to_user_stream.setdefault( - user_stream.appservice, set() - ).add(user_stream) - - def _user_joined_room(self, user, room_id): - user = str(user) - new_user_stream = self.user_to_user_stream.get(user) + def _user_joined_room(self, user_id, room_id): + new_user_stream = self.user_to_user_stream.get(user_id) if new_user_stream is not None: room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) @@ -503,13 +480,14 @@ class Notifier(object): def wait_for_replication(self, callback, timeout): """Wait for an event to happen. - :param callback: - Gets called whenever an event happens. If this returns a truthy - value then ``wait_for_replication`` returns, otherwise it waits - for another event. - :param int timeout: - How many milliseconds to wait for callback return a truthy value. - :returns: + Args: + callback: Gets called whenever an event happens. If this returns a + truthy value then ``wait_for_replication`` returns, otherwise + it waits for another event. + timeout: How many milliseconds to wait for callback return a truthy + value. + + Returns: A deferred that resolves with the value returned by the callback. """ listener = _NotificationListener(None) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 65ef1b68a3..edf45dc599 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -13,333 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from synapse.streams.config import PaginationConfig -from synapse.types import StreamToken -from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure - -import synapse.util.async -from .push_rule_evaluator import evaluator_for_user_id - -import logging -import random - -logger = logging.getLogger(__name__) - - -_NEXT_ID = 1 - - -def _get_next_id(): - global _NEXT_ID - _id = _NEXT_ID - _NEXT_ID += 1 - return _id - - -# Pushers could now be moved to pull out of the event_push_actions table instead -# of listening on the event stream: this would avoid them having to run the -# rules again. -class Pusher(object): - INITIAL_BACKOFF = 1000 - MAX_BACKOFF = 60 * 60 * 1000 - GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - - def __init__(self, _hs, user_id, app_id, - app_display_name, device_display_name, pushkey, pushkey_ts, - data, last_token, last_success, failing_since): - self.hs = _hs - self.evStreamHandler = self.hs.get_handlers().event_stream_handler - self.store = self.hs.get_datastore() - self.clock = self.hs.get_clock() - self.user_id = user_id - self.app_id = app_id - self.app_display_name = app_display_name - self.device_display_name = device_display_name - self.pushkey = pushkey - self.pushkey_ts = pushkey_ts - self.data = data - self.last_token = last_token - self.last_success = last_success # not actually used - self.backoff_delay = Pusher.INITIAL_BACKOFF - self.failing_since = failing_since - self.alive = True - self.badge = None - - self.name = "Pusher-%d" % (_get_next_id(),) - - # The last value of last_active_time that we saw - self.last_last_active_time = 0 - self.has_unread = True - - @defer.inlineCallbacks - def get_context_for_event(self, ev): - name_aliases = yield self.store.get_room_name_and_aliases( - ev['room_id'] - ) - - ctx = {'aliases': name_aliases[1]} - if name_aliases[0] is not None: - ctx['name'] = name_aliases[0] - - their_member_events_for_room = yield self.store.get_current_state( - room_id=ev['room_id'], - event_type='m.room.member', - state_key=ev['user_id'] - ) - for mev in their_member_events_for_room: - if mev.content['membership'] == 'join' and 'displayname' in mev.content: - dn = mev.content['displayname'] - if dn is not None: - ctx['sender_display_name'] = dn - - defer.returnValue(ctx) - - @defer.inlineCallbacks - def start(self): - with LoggingContext(self.name): - if not self.last_token: - # First-time setup: get a token to start from (we can't - # just start from no token, ie. 'now' - # because we need the result to be reproduceable in case - # we fail to dispatch the push) - config = PaginationConfig(from_token=None, limit='1') - chunk = yield self.evStreamHandler.get_stream( - self.user_id, config, timeout=0, affect_presence=False - ) - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token( - self.app_id, self.pushkey, self.user_id, self.last_token - ) - logger.info("New pusher %s for user %s starting from token %s", - self.pushkey, self.user_id, self.last_token) - - else: - logger.info( - "Old pusher %s for user %s starting", - self.pushkey, self.user_id, - ) - - wait = 0 - while self.alive: - try: - if wait > 0: - yield synapse.util.async.sleep(wait) - with Measure(self.clock, "push"): - yield self.get_and_dispatch() - wait = 0 - except: - if wait == 0: - wait = 1 - else: - wait = min(wait * 2, 1800) - logger.exception( - "Exception in pusher loop for pushkey %s. Pausing for %ds", - self.pushkey, wait - ) - - @defer.inlineCallbacks - def get_and_dispatch(self): - from_tok = StreamToken.from_string(self.last_token) - config = PaginationConfig(from_token=from_tok, limit='1') - timeout = (300 + random.randint(-60, 60)) * 1000 - chunk = yield self.evStreamHandler.get_stream( - self.user_id, config, timeout=timeout, affect_presence=False, - only_keys=("room", "receipt",), - ) - - # limiting to 1 may get 1 event plus 1 presence event, so - # pick out the actual event - single_event = None - read_receipt = None - for c in chunk['chunk']: - if 'event_id' in c: # Hmmm... - single_event = c - elif c['type'] == 'm.receipt': - read_receipt = c - - have_updated_badge = False - if read_receipt: - for receipt_part in read_receipt['content'].values(): - if 'm.read' in receipt_part: - if self.user_id in receipt_part['m.read'].keys(): - have_updated_badge = True - - if not single_event: - if have_updated_badge: - yield self.update_badge() - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token( - self.app_id, - self.pushkey, - self.user_id, - self.last_token - ) - return - - if not self.alive: - return - - processed = False - - rule_evaluator = yield \ - evaluator_for_user_id( - self.user_id, single_event['room_id'], self.store - ) - - actions = yield rule_evaluator.actions_for_event(single_event) - tweaks = rule_evaluator.tweaks_for_actions(actions) - - if 'notify' in actions: - self.badge = yield self._get_badge_count() - rejected = yield self.dispatch_push(single_event, tweaks, self.badge) - self.has_unread = True - if isinstance(rejected, list) or isinstance(rejected, tuple): - processed = True - for pk in rejected: - if pk != self.pushkey: - # for sanity, we only remove the pushkey if it - # was the one we actually sent... - logger.warn( - ("Ignoring rejected pushkey %s because we" - " didn't send it"), pk - ) - else: - logger.info( - "Pushkey %s was rejected: removing", - pk - ) - yield self.hs.get_pusherpool().remove_pusher( - self.app_id, pk, self.user_id - ) - else: - if have_updated_badge: - yield self.update_badge() - processed = True - - if not self.alive: - return - - if processed: - self.backoff_delay = Pusher.INITIAL_BACKOFF - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_token, - self.clock.time_msec() - ) - if self.failing_since: - self.failing_since = None - yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since) - else: - if not self.failing_since: - self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since - ) - - if (self.failing_since and - self.failing_since < - self.clock.time_msec() - Pusher.GIVE_UP_AFTER): - # we really only give up so that if the URL gets - # fixed, we don't suddenly deliver a load - # of old notifications. - logger.warn("Giving up on a notification to user %s, " - "pushkey %s", - self.user_id, self.pushkey) - self.backoff_delay = Pusher.INITIAL_BACKOFF - self.last_token = chunk['end'] - yield self.store.update_pusher_last_token( - self.app_id, - self.pushkey, - self.user_id, - self.last_token - ) - - self.failing_since = None - yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since - ) - else: - logger.warn("Failed to dispatch push for user %s " - "(failing for %dms)." - "Trying again in %dms", - self.user_id, - self.clock.time_msec() - self.failing_since, - self.backoff_delay) - yield synapse.util.async.sleep(self.backoff_delay / 1000.0) - self.backoff_delay *= 2 - if self.backoff_delay > Pusher.MAX_BACKOFF: - self.backoff_delay = Pusher.MAX_BACKOFF - - def stop(self): - self.alive = False - - def dispatch_push(self, p, tweaks, badge): - """ - Overridden by implementing classes to actually deliver the notification - Args: - p: The event to notify for as a single event from the event stream - Returns: If the notification was delivered, an array containing any - pushkeys that were rejected by the push gateway. - False if the notification could not be delivered (ie. - should be retried). - """ - pass - - @defer.inlineCallbacks - def update_badge(self): - new_badge = yield self._get_badge_count() - if self.badge != new_badge: - self.badge = new_badge - yield self.send_badge(self.badge) - - def send_badge(self, badge): - """ - Overridden by implementing classes to send an updated badge count - """ - pass - - @defer.inlineCallbacks - def _get_badge_count(self): - invites, joins = yield defer.gatherResults([ - self.store.get_invites_for_user(self.user_id), - self.store.get_rooms_for_user(self.user_id), - ], consumeErrors=True) - - my_receipts_by_room = yield self.store.get_receipts_for_user( - self.user_id, - "m.read", - ) - - badge = len(invites) - - for r in joins: - if r.room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[r.room_id] - - notifs = yield ( - self.store.get_unread_event_push_actions_by_room_for_user( - r.room_id, self.user_id, last_unread_event_id - ) - ) - badge += notifs["notify_count"] - defer.returnValue(badge) - class PusherConfigException(Exception): def __init__(self, msg): diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 84efcdd184..3f75d3f921 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,7 +15,9 @@ from twisted.internet import defer -from .bulk_push_rule_evaluator import evaluator_for_room_id +from .bulk_push_rule_evaluator import evaluator_for_event + +from synapse.util.metrics import Measure import logging @@ -25,6 +27,7 @@ logger = logging.getLogger(__name__) class ActionGenerator: def __init__(self, hs): self.hs = hs + self.clock = hs.get_clock() self.store = hs.get_datastore() # really we want to get all user ids and all profile tags too, # since we want the actions for each profile tag for every user and @@ -34,14 +37,16 @@ class ActionGenerator: # tag (ie. we just need all the users). @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, context, handler): - bulk_evaluator = yield evaluator_for_room_id( - event.room_id, self.hs, self.store - ) - - actions_by_user = yield bulk_evaluator.action_for_event_by_user( - event, handler, context.current_state - ) + def handle_push_actions_for_event(self, event, context): + with Measure(self.clock, "evaluator_for_event"): + bulk_evaluator = yield evaluator_for_event( + event, self.hs, self.store, context + ) + + with Measure(self.clock, "action_for_event_by_user"): + actions_by_user = yield bulk_evaluator.action_for_event_by_user( + event, context + ) context.push_actions = [ (uid, actions) for uid, actions in actions_by_user.items() diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 86a2998bcc..85effdfa46 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,9 +19,11 @@ import copy def list_with_base_rules(rawrules): """Combine the list of rules set by the user with the default push rules - :param list rawrules: The rules the user has modified or set. - :returns: A new list with the rules set by the user combined with the - defaults. + Args: + rawrules(list): The rules the user has modified or set. + + Returns: + A new list with the rules set by the user combined with the defaults. """ ruleslist = [] @@ -77,7 +79,7 @@ def make_base_append_rules(kind, modified_base_rules): rules = [] if kind == 'override': - rules = BASE_APPEND_OVRRIDE_RULES + rules = BASE_APPEND_OVERRIDE_RULES elif kind == 'underride': rules = BASE_APPEND_UNDERRIDE_RULES elif kind == 'content': @@ -146,7 +148,7 @@ BASE_PREPEND_OVERRIDE_RULES = [ ] -BASE_APPEND_OVRRIDE_RULES = [ +BASE_APPEND_OVERRIDE_RULES = [ { 'rule_id': 'global/override/.m.rule.suppress_notices', 'conditions': [ @@ -160,34 +162,67 @@ BASE_APPEND_OVRRIDE_RULES = [ 'actions': [ 'dont_notify', ] - } -] - - -BASE_APPEND_UNDERRIDE_RULES = [ + }, + # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event + # otherwise invites will be matched by .m.rule.member_event { - 'rule_id': 'global/underride/.m.rule.call', + 'rule_id': 'global/override/.m.rule.invite_for_me', 'conditions': [ { 'kind': 'event_match', 'key': 'type', - 'pattern': 'm.call.invite', - '_id': '_call', - } + 'pattern': 'm.room.member', + '_id': '_member', + }, + { + 'kind': 'event_match', + 'key': 'content.membership', + 'pattern': 'invite', + '_id': '_invite_member', + }, + { + 'kind': 'event_match', + 'key': 'state_key', + 'pattern_type': 'user_id' + }, ], 'actions': [ 'notify', { 'set_tweak': 'sound', - 'value': 'ring' + 'value': 'default' }, { 'set_tweak': 'highlight', 'value': False } ] }, + # Will we sometimes want to know about people joining and leaving? + # Perhaps: if so, this could be expanded upon. Seems the most usual case + # is that we don't though. We add this override rule so that even if + # the room rule is set to notify, we don't get notifications about + # join/leave/avatar/displayname events. + # See also: https://matrix.org/jira/browse/SYN-607 + { + 'rule_id': 'global/override/.m.rule.member_event', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + '_id': '_member', + } + ], + 'actions': [ + 'dont_notify' + ] + }, + # This was changed from underride to override so it's closer in priority + # to the content rules where the user name highlight rule lives. This + # way a room rule is lower priority than both but a custom override rule + # is higher priority than both. { - 'rule_id': 'global/underride/.m.rule.contains_display_name', + 'rule_id': 'global/override/.m.rule.contains_display_name', 'conditions': [ { 'kind': 'contains_display_name' @@ -203,6 +238,33 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, +] + + +BASE_APPEND_UNDERRIDE_RULES = [ + { + 'rule_id': 'global/underride/.m.rule.call', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.call.invite', + '_id': '_call', + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'ring' + }, { + 'set_tweak': 'highlight', + 'value': False + } + ] + }, + # XXX: once m.direct is standardised everywhere, we should use it to detect + # a DM from the user's perspective rather than this heuristic. { 'rule_id': 'global/underride/.m.rule.room_one_to_one', 'conditions': [ @@ -229,26 +291,22 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, + # XXX: this is going to fire for events which aren't m.room.messages + # but are encrypted (e.g. m.call.*)... { - 'rule_id': 'global/underride/.m.rule.invite_for_me', + 'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one', 'conditions': [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', - }, - { - 'kind': 'event_match', - 'key': 'content.membership', - 'pattern': 'invite', - '_id': '_invite_member', + 'kind': 'room_member_count', + 'is': '2', + '_id': 'member_count', }, { 'kind': 'event_match', - 'key': 'state_key', - 'pattern_type': 'user_id' - }, + 'key': 'type', + 'pattern': 'm.room.encrypted', + '_id': '_encrypted', + } ], 'actions': [ 'notify', @@ -261,25 +319,6 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, - # This is too simple: https://matrix.org/jira/browse/SYN-607 - # Removing for now - # { - # 'rule_id': 'global/underride/.m.rule.member_event', - # 'conditions': [ - # { - # 'kind': 'event_match', - # 'key': 'type', - # 'pattern': 'm.room.member', - # '_id': '_member', - # } - # ], - # 'actions': [ - # 'notify', { - # 'set_tweak': 'highlight', - # 'value': False - # } - # ] - # }, { 'rule_id': 'global/underride/.m.rule.message', 'conditions': [ @@ -296,6 +335,25 @@ BASE_APPEND_UNDERRIDE_RULES = [ 'value': False } ] + }, + # XXX: this is going to fire for events which aren't m.room.messages + # but are encrypted (e.g. m.call.*)... + { + 'rule_id': 'global/underride/.m.rule.encrypted', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.encrypted', + '_id': '_encrypted', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': False + } + ] } ] @@ -312,7 +370,7 @@ for r in BASE_PREPEND_OVERRIDE_RULES: r['default'] = True BASE_RULE_IDS.add(r['rule_id']) -for r in BASE_APPEND_OVRRIDE_RULES: +for r in BASE_APPEND_OVERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True BASE_RULE_IDS.add(r['rule_id']) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 87d5061fb0..be55598c43 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -14,71 +14,38 @@ # limitations under the License. import logging -import ujson as json from twisted.internet import defer -from .baserules import list_with_base_rules from .push_rule_evaluator import PushRuleEvaluatorForEvent from synapse.api.constants import EventTypes +from synapse.visibility import filter_events_for_clients_context logger = logging.getLogger(__name__) -def decode_rule_json(rule): - rule['conditions'] = json.loads(rule['conditions']) - rule['actions'] = json.loads(rule['actions']) - return rule - - @defer.inlineCallbacks -def _get_rules(room_id, user_ids, store): - rules_by_user = yield store.bulk_get_push_rules(user_ids) - rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) - - rules_by_user = { - uid: list_with_base_rules([ - decode_rule_json(rule_list) - for rule_list in rules_by_user.get(uid, []) - ]) - for uid in user_ids - } - - # We apply the rules-enabled map here: bulk_get_push_rules doesn't - # fetch disabled rules, but this won't account for any server default - # rules the user has disabled, so we need to do this too. - for uid in user_ids: - if uid not in rules_enabled_by_user: - continue - - user_enabled_map = rules_enabled_by_user[uid] - - for i, rule in enumerate(rules_by_user[uid]): - rule_id = rule['rule_id'] - - if rule_id in user_enabled_map: - if rule.get('enabled', True) != bool(user_enabled_map[rule_id]): - # Rules are cached across users. - rule = dict(rule) - rule['enabled'] = bool(user_enabled_map[rule_id]) - rules_by_user[uid][i] = rule - - defer.returnValue(rules_by_user) - - -@defer.inlineCallbacks -def evaluator_for_room_id(room_id, hs, store): - results = yield store.get_receipts_for_room(room_id, "m.read") - user_ids = [ - row["user_id"] for row in results - if hs.is_mine_id(row["user_id"]) - ] - rules_by_user = yield _get_rules(room_id, user_ids, store) +def evaluator_for_event(event, hs, store, context): + rules_by_user = yield store.bulk_get_push_rules_for_room( + event, context + ) + + # if this event is an invite event, we may need to run rules for the user + # who's been invited, otherwise they won't get told they've been invited + if event.type == 'm.room.member' and event.content['membership'] == 'invite': + invited_user = event.state_key + if invited_user and hs.is_mine_id(invited_user): + has_pusher = yield store.user_has_pusher(invited_user) + if has_pusher: + rules_by_user = dict(rules_by_user) + rules_by_user[invited_user] = yield store.get_push_rules_for_user( + invited_user + ) defer.returnValue(BulkPushRuleEvaluator( - room_id, rules_by_user, user_ids, store + event.room_id, rules_by_user, store )) @@ -91,34 +58,41 @@ class BulkPushRuleEvaluator: the same logic to run the actual rules, but could be optimised further (see https://matrix.org/jira/browse/SYN-562) """ - def __init__(self, room_id, rules_by_user, users_in_room, store): + def __init__(self, room_id, rules_by_user, store): self.room_id = room_id self.rules_by_user = rules_by_user - self.users_in_room = users_in_room self.store = store @defer.inlineCallbacks - def action_for_event_by_user(self, event, handler, current_state): + def action_for_event_by_user(self, event, context): actions_by_user = {} - users_dict = yield self.store.are_guests(self.rules_by_user.keys()) + # None of these users can be peeking since this list of users comes + # from the set of users in the room, so we know for sure they're all + # actually in the room. + user_tuples = [ + (u, False) for u in self.rules_by_user.keys() + ] - filtered_by_user = yield handler.filter_events_for_clients( - users_dict.items(), [event], {event.event_id: current_state} + filtered_by_user = yield filter_events_for_clients_context( + self.store, user_tuples, [event], {event.event_id: context} ) - evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room)) + room_members = yield self.store.get_joined_users_from_context( + event, context + ) - condition_cache = {} + evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) - display_names = {} - for ev in current_state.values(): - nm = ev.content.get("displayname", None) - if nm and ev.type == EventTypes.Member: - display_names[ev.state_key] = nm + condition_cache = {} for uid, rules in self.rules_by_user.items(): - display_name = display_names.get(uid, None) + display_name = None + member_ev_id = context.current_state_ids.get((EventTypes.Member, uid)) + if member_ev_id: + member_ev = yield self.store.get_event(member_ev_id, allow_none=True) + if member_ev: + display_name = member_ev.content.get("displayname", None) filtered = filtered_by_user[uid] if len(filtered) == 0: diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index ae9db9ec2f..e0331b2d2d 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -13,29 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.push.baserules import list_with_base_rules - from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) import copy -import simplejson as json -def format_push_rules_for_user(user, rawrules, enabled_map): +def format_push_rules_for_user(user, ruleslist): """Converts a list of rawrules and a enabled map into nested dictionaries to match the Matrix client-server format for push rules""" - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) + ruleslist = copy.deepcopy(ruleslist) rules = {'global': {}, 'device': {}} @@ -60,9 +50,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map): template_rule = _rule_to_template(r) if template_rule: - if r['rule_id'] in enabled_map: - template_rule['enabled'] = enabled_map[r['rule_id']] - elif 'enabled' in r: + if 'enabled' in r: template_rule['enabled'] = r['enabled'] else: template_rule['enabled'] = True diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py new file mode 100644 index 0000000000..2eb325c7c7 --- /dev/null +++ b/synapse/push/emailpusher.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer, reactor +from twisted.internet.error import AlreadyCalled, AlreadyCancelled + +import logging + +from synapse.util.metrics import Measure +from synapse.util.logcontext import LoggingContext + +from mailer import Mailer + +logger = logging.getLogger(__name__) + +# The amount of time we always wait before ever emailing about a notification +# (to give the user a chance to respond to other push or notice the window) +DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000 + +# THROTTLE is the minimum time between mail notifications sent for a given room. +# Each room maintains its own throttle counter, but each new mail notification +# sends the pending notifications for all rooms. +THROTTLE_START_MS = 10 * 60 * 1000 +THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h +# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours +THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day + +# If no event triggers a notification for this long after the previous, +# the throttle is released. +# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new +# notification when things get going again... +THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000) + +# does each email include all unread notifs, or just the ones which have happened +# since the last mail? +# XXX: this is currently broken as it includes ones from parted rooms(!) +INCLUDE_ALL_UNREAD_NOTIFS = False + + +class EmailPusher(object): + """ + A pusher that sends email notifications about events (approximately) + when they happen. + This shares quite a bit of code with httpusher: it would be good to + factor out the common parts + """ + def __init__(self, hs, pusherdict): + self.hs = hs + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() + self.pusher_id = pusherdict['id'] + self.user_id = pusherdict['user_name'] + self.app_id = pusherdict['app_id'] + self.email = pusherdict['pushkey'] + self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.timed_call = None + self.throttle_params = None + + # See httppusher + self.max_stream_ordering = None + + self.processing = False + + if self.hs.config.email_enable_notifs: + if 'data' in pusherdict and 'brand' in pusherdict['data']: + app_name = pusherdict['data']['brand'] + else: + app_name = self.hs.config.email_app_name + + self.mailer = Mailer(self.hs, app_name) + else: + self.mailer = None + + @defer.inlineCallbacks + def on_started(self): + if self.mailer is not None: + self.throttle_params = yield self.store.get_throttle_params_by_room( + self.pusher_id + ) + yield self._process() + + def on_stop(self): + if self.timed_call: + try: + self.timed_call.cancel() + except (AlreadyCalled, AlreadyCancelled): + pass + self.timed_call = None + + @defer.inlineCallbacks + def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + yield self._process() + + def on_new_receipts(self, min_stream_id, max_stream_id): + # We could wake up and cancel the timer but there tend to be quite a + # lot of read receipts so it's probably less work to just let the + # timer fire + return defer.succeed(None) + + @defer.inlineCallbacks + def on_timer(self): + self.timed_call = None + yield self._process() + + @defer.inlineCallbacks + def _process(self): + if self.processing: + return + + with LoggingContext("emailpush._process"): + with Measure(self.clock, "emailpush._process"): + try: + self.processing = True + # if the max ordering changes while we're running _unsafe_process, + # call it again, and so on until we've caught up. + while True: + starting_max_ordering = self.max_stream_ordering + try: + yield self._unsafe_process() + except: + logger.exception("Exception processing notifs") + if self.max_stream_ordering == starting_max_ordering: + break + finally: + self.processing = False + + @defer.inlineCallbacks + def _unsafe_process(self): + """ + Main logic of the push loop without the wrapper function that sets + up logging, measures and guards against multiple instances of it + being run. + """ + start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering + fn = self.store.get_unread_push_actions_for_user_in_range_for_email + unprocessed = yield fn(self.user_id, start, self.max_stream_ordering) + + soonest_due_at = None + + if not unprocessed: + yield self.save_last_stream_ordering_and_success(self.max_stream_ordering) + return + + for push_action in unprocessed: + received_at = push_action['received_ts'] + if received_at is None: + received_at = 0 + notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS + + room_ready_at = self.room_ready_to_notify_at( + push_action['room_id'] + ) + + should_notify_at = max(notif_ready_at, room_ready_at) + + if should_notify_at < self.clock.time_msec(): + # one of our notifications is ready for sending, so we send + # *one* email updating the user on their notifications, + # we then consider all previously outstanding notifications + # to be delivered. + + reason = { + 'room_id': push_action['room_id'], + 'now': self.clock.time_msec(), + 'received_at': received_at, + 'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS, + 'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']), + 'throttle_ms': self.get_room_throttle_ms(push_action['room_id']), + } + + yield self.send_notification(unprocessed, reason) + + yield self.save_last_stream_ordering_and_success(max([ + ea['stream_ordering'] for ea in unprocessed + ])) + + # we update the throttle on all the possible unprocessed push actions + for ea in unprocessed: + yield self.sent_notif_update_throttle( + ea['room_id'], ea + ) + break + else: + if soonest_due_at is None or should_notify_at < soonest_due_at: + soonest_due_at = should_notify_at + + if self.timed_call is not None: + try: + self.timed_call.cancel() + except (AlreadyCalled, AlreadyCancelled): + pass + self.timed_call = None + + if soonest_due_at is not None: + self.timed_call = reactor.callLater( + self.seconds_until(soonest_due_at), self.on_timer + ) + + @defer.inlineCallbacks + def save_last_stream_ordering_and_success(self, last_stream_ordering): + self.last_stream_ordering = last_stream_ordering + yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, self.email, self.user_id, + last_stream_ordering, self.clock.time_msec() + ) + + def seconds_until(self, ts_msec): + return (ts_msec - self.clock.time_msec()) / 1000 + + def get_room_throttle_ms(self, room_id): + if room_id in self.throttle_params: + return self.throttle_params[room_id]["throttle_ms"] + else: + return 0 + + def get_room_last_sent_ts(self, room_id): + if room_id in self.throttle_params: + return self.throttle_params[room_id]["last_sent_ts"] + else: + return 0 + + def room_ready_to_notify_at(self, room_id): + """ + Determines whether throttling should prevent us from sending an email + for the given room + Returns: The timestamp when we are next allowed to send an email notif + for this room + """ + last_sent_ts = self.get_room_last_sent_ts(room_id) + throttle_ms = self.get_room_throttle_ms(room_id) + + may_send_at = last_sent_ts + throttle_ms + return may_send_at + + @defer.inlineCallbacks + def sent_notif_update_throttle(self, room_id, notified_push_action): + # We have sent a notification, so update the throttle accordingly. + # If the event that triggered the notif happened more than + # THROTTLE_RESET_AFTER_MS after the previous one that triggered a + # notif, we release the throttle. Otherwise, the throttle is increased. + time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before( + notified_push_action['stream_ordering'] + ) + + time_of_this_notifs = notified_push_action['received_ts'] + + if time_of_previous_notifs is not None and time_of_this_notifs is not None: + gap = time_of_this_notifs - time_of_previous_notifs + else: + # if we don't know the arrival time of one of the notifs (it was not + # stored prior to email notification code) then assume a gap of + # zero which will just not reset the throttle + gap = 0 + + current_throttle_ms = self.get_room_throttle_ms(room_id) + + if gap > THROTTLE_RESET_AFTER_MS: + new_throttle_ms = THROTTLE_START_MS + else: + if current_throttle_ms == 0: + new_throttle_ms = THROTTLE_START_MS + else: + new_throttle_ms = min( + current_throttle_ms * THROTTLE_MULTIPLIER, + THROTTLE_MAX_MS + ) + self.throttle_params[room_id] = { + "last_sent_ts": self.clock.time_msec(), + "throttle_ms": new_throttle_ms + } + yield self.store.set_throttle_params( + self.pusher_id, room_id, self.throttle_params[room_id] + ) + + @defer.inlineCallbacks + def send_notification(self, push_actions, reason): + logger.info("Sending notif email for user %r", self.user_id) + + yield self.mailer.send_notification_mail( + self.app_id, self.user_id, self.email, push_actions, reason + ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 9be4869360..c0f8176e3d 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -13,60 +13,248 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.push import Pusher, PusherConfigException +from synapse.push import PusherConfigException -from twisted.internet import defer +from twisted.internet import defer, reactor +from twisted.internet.error import AlreadyCalled, AlreadyCancelled import logging +import push_rule_evaluator +import push_tools + +from synapse.util.logcontext import LoggingContext +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -class HttpPusher(Pusher): - def __init__(self, _hs, user_id, app_id, - app_display_name, device_display_name, pushkey, pushkey_ts, - data, last_token, last_success, failing_since): - super(HttpPusher, self).__init__( - _hs, - user_id, - app_id, - app_display_name, - device_display_name, - pushkey, - pushkey_ts, - data, - last_token, - last_success, - failing_since +class HttpPusher(object): + INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes + MAX_BACKOFF_SEC = 60 * 60 + + # This one's in ms because we compare it against the clock + GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 + + def __init__(self, hs, pusherdict): + self.hs = hs + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() + self.state_handler = self.hs.get_state_handler() + self.user_id = pusherdict['user_name'] + self.app_id = pusherdict['app_id'] + self.app_display_name = pusherdict['app_display_name'] + self.device_display_name = pusherdict['device_display_name'] + self.pushkey = pusherdict['pushkey'] + self.pushkey_ts = pusherdict['ts'] + self.data = pusherdict['data'] + self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC + self.failing_since = pusherdict['failing_since'] + self.timed_call = None + self.processing = False + + # This is the highest stream ordering we know it's safe to process. + # When new events arrive, we'll be given a window of new events: we + # should honour this rather than just looking for anything higher + # because of potential out-of-order event serialisation. This starts + # off as None though as we don't know any better. + self.max_stream_ordering = None + + if 'data' not in pusherdict: + raise PusherConfigException( + "No 'data' key for HTTP pusher" + ) + self.data = pusherdict['data'] + + self.name = "%s/%s/%s" % ( + pusherdict['user_name'], + pusherdict['app_id'], + pusherdict['pushkey'], ) - if 'url' not in data: + + if 'url' not in self.data: raise PusherConfigException( "'url' required in data for HTTP pusher" ) - self.url = data['url'] - self.http_client = _hs.get_simple_http_client() + self.url = self.data['url'] + self.http_client = hs.get_simple_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url['url'] @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): - # we probably do not want to push for every presence update - # (we may want to be able to set up notifications when specific - # people sign in, but we'd want to only deliver the pertinent ones) - # Actually, presence events will not get this far now because we - # need to filter them out in the main Pusher code. - if 'event_id' not in event: - defer.returnValue(None) + def on_started(self): + yield self._process() - ctx = yield self.get_context_for_event(event) + @defer.inlineCallbacks + def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + yield self._process() + + @defer.inlineCallbacks + def on_new_receipts(self, min_stream_id, max_stream_id): + # Note that the min here shouldn't be relied upon to be accurate. + + # We could check the receipts are actually m.read receipts here, + # but currently that's the only type of receipt anyway... + with LoggingContext("push.on_new_receipts"): + with Measure(self.clock, "push.on_new_receipts"): + badge = yield push_tools.get_badge_count( + self.hs.get_datastore(), self.user_id + ) + yield self._send_badge(badge) + + @defer.inlineCallbacks + def on_timer(self): + yield self._process() + + def on_stop(self): + if self.timed_call: + try: + self.timed_call.cancel() + except (AlreadyCalled, AlreadyCancelled): + pass + self.timed_call = None + + @defer.inlineCallbacks + def _process(self): + if self.processing: + return + + with LoggingContext("push._process"): + with Measure(self.clock, "push._process"): + try: + self.processing = True + # if the max ordering changes while we're running _unsafe_process, + # call it again, and so on until we've caught up. + while True: + starting_max_ordering = self.max_stream_ordering + try: + yield self._unsafe_process() + except: + logger.exception("Exception processing notifs") + if self.max_stream_ordering == starting_max_ordering: + break + finally: + self.processing = False + + @defer.inlineCallbacks + def _unsafe_process(self): + """ + Looks for unset notifications and dispatch them, in order + Never call this directly: use _process which will only allow this to + run once per pusher. + """ + + fn = self.store.get_unread_push_actions_for_user_in_range_for_http + unprocessed = yield fn( + self.user_id, self.last_stream_ordering, self.max_stream_ordering + ) + + for push_action in unprocessed: + processed = yield self._process_one(push_action) + if processed: + self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC + self.last_stream_ordering = push_action['stream_ordering'] + yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, self.pushkey, self.user_id, + self.last_stream_ordering, + self.clock.time_msec() + ) + if self.failing_since: + self.failing_since = None + yield self.store.update_pusher_failing_since( + self.app_id, self.pushkey, self.user_id, + self.failing_since + ) + else: + if not self.failing_since: + self.failing_since = self.clock.time_msec() + yield self.store.update_pusher_failing_since( + self.app_id, self.pushkey, self.user_id, + self.failing_since + ) + + if ( + self.failing_since and + self.failing_since < + self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS + ): + # we really only give up so that if the URL gets + # fixed, we don't suddenly deliver a load + # of old notifications. + logger.warn("Giving up on a notification to user %s, " + "pushkey %s", + self.user_id, self.pushkey) + self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC + self.last_stream_ordering = push_action['stream_ordering'] + yield self.store.update_pusher_last_stream_ordering( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering + ) + + self.failing_since = None + yield self.store.update_pusher_failing_since( + self.app_id, + self.pushkey, + self.user_id, + self.failing_since + ) + else: + logger.info("Push failed: delaying for %ds", self.backoff_delay) + self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer) + self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) + break + + @defer.inlineCallbacks + def _process_one(self, push_action): + if 'notify' not in push_action['actions']: + defer.returnValue(True) + + tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions']) + badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + + event = yield self.store.get_event(push_action['event_id'], allow_none=True) + if event is None: + defer.returnValue(True) # It's been redacted + rejected = yield self.dispatch_push(event, tweaks, badge) + if rejected is False: + defer.returnValue(False) + + if isinstance(rejected, list) or isinstance(rejected, tuple): + for pk in rejected: + if pk != self.pushkey: + # for sanity, we only remove the pushkey if it + # was the one we actually sent... + logger.warn( + ("Ignoring rejected pushkey %s because we" + " didn't send it"), pk + ) + else: + logger.info( + "Pushkey %s was rejected: removing", + pk + ) + yield self.hs.remove_pusher( + self.app_id, pk, self.user_id + ) + defer.returnValue(True) + + @defer.inlineCallbacks + def _build_notification_dict(self, event, tweaks, badge): + ctx = yield push_tools.get_context_for_event( + self.store, self.state_handler, event, self.user_id + ) d = { 'notification': { - 'id': event['event_id'], - 'room_id': event['room_id'], - 'type': event['type'], - 'sender': event['user_id'], + 'id': event.event_id, # deprecated: remove soon + 'event_id': event.event_id, + 'room_id': event.room_id, + 'type': event.type, + 'sender': event.user_id, 'counts': { # -- we don't mark messages as read yet so # we have no way of knowing # Just set the badge to 1 until we have read receipts @@ -84,14 +272,14 @@ class HttpPusher(Pusher): ] } } - if event['type'] == 'm.room.member': - d['notification']['membership'] = event['content']['membership'] - d['notification']['user_is_target'] = event['state_key'] == self.user_id + if event.type == 'm.room.member': + d['notification']['membership'] = event.content['membership'] + d['notification']['user_is_target'] = event.state_key == self.user_id if 'content' in event: - d['notification']['content'] = event['content'] + d['notification']['content'] = event.content - if len(ctx['aliases']): - d['notification']['room_alias'] = ctx['aliases'][0] + # We no longer send aliases separately, instead, we send the human + # readable name of the room, which may be an alias. if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: d['notification']['sender_display_name'] = ctx['sender_display_name'] if 'name' in ctx and len(ctx['name']) > 0: @@ -115,7 +303,7 @@ class HttpPusher(Pusher): defer.returnValue(rejected) @defer.inlineCallbacks - def send_badge(self, badge): + def _send_badge(self, badge): logger.info("Sending updated badge count %d to %r", badge, self.user_id) d = { 'notification': { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py new file mode 100644 index 0000000000..53551632b6 --- /dev/null +++ b/synapse/push/mailer.py @@ -0,0 +1,533 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from twisted.mail.smtp import sendmail + +import email.utils +import email.mime.multipart +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart + +from synapse.util.async import concurrently_execute +from synapse.push.presentable_names import ( + calculate_room_name, name_from_member_event, descriptor_from_member_events +) +from synapse.types import UserID +from synapse.api.errors import StoreError +from synapse.api.constants import EventTypes +from synapse.visibility import filter_events_for_client + +import jinja2 +import bleach + +import time +import urllib + +import logging +logger = logging.getLogger(__name__) + + +MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \ + "in the %(room)s room..." +MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." +MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." +MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..." +MESSAGES_IN_ROOM_AND_OTHERS = \ + "You have messages on %(app)s in the %(room)s room and others..." +MESSAGES_FROM_PERSON_AND_OTHERS = \ + "You have messages on %(app)s from %(person)s and others..." +INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \ + "%(room)s room on %(app)s..." +INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." + +CONTEXT_BEFORE = 1 +CONTEXT_AFTER = 1 + +# From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js +ALLOWED_TAGS = [ + 'font', # custom to matrix for IRC-style font coloring + 'del', # for markdown + # deliberately no h1/h2 to stop people shouting. + 'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol', + 'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div', + 'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre' +] +ALLOWED_ATTRS = { + # custom ones first: + "font": ["color"], # custom to matrix + "a": ["href", "name", "target"], # remote target: custom to matrix + # We don't currently allow img itself by default, but this + # would make sense if we did + "img": ["src"], +} +# When bleach release a version with this option, we can specify schemes +# ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"] + + +class Mailer(object): + def __init__(self, hs, app_name): + self.hs = hs + self.store = self.hs.get_datastore() + self.auth_handler = self.hs.get_auth_handler() + self.state_handler = self.hs.get_state_handler() + loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) + self.app_name = app_name + logger.info("Created Mailer for app_name %s" % app_name) + env = jinja2.Environment(loader=loader) + env.filters["format_ts"] = format_ts_filter + env.filters["mxc_to_http"] = self.mxc_to_http_filter + self.notif_template_html = env.get_template( + self.hs.config.email_notif_template_html + ) + self.notif_template_text = env.get_template( + self.hs.config.email_notif_template_text + ) + + @defer.inlineCallbacks + def send_notification_mail(self, app_id, user_id, email_address, + push_actions, reason): + try: + from_string = self.hs.config.email_notif_from % { + "app": self.app_name + } + except TypeError: + from_string = self.hs.config.email_notif_from + + raw_from = email.utils.parseaddr(from_string)[1] + raw_to = email.utils.parseaddr(email_address)[1] + + if raw_to == '': + raise RuntimeError("Invalid 'to' address") + + rooms_in_order = deduped_ordered_list( + [pa['room_id'] for pa in push_actions] + ) + + notif_events = yield self.store.get_events( + [pa['event_id'] for pa in push_actions] + ) + + notifs_by_room = {} + for pa in push_actions: + notifs_by_room.setdefault(pa["room_id"], []).append(pa) + + # collect the current state for all the rooms in which we have + # notifications + state_by_room = {} + + try: + user_display_name = yield self.store.get_profile_displayname( + UserID.from_string(user_id).localpart + ) + if user_display_name is None: + user_display_name = user_id + except StoreError: + user_display_name = user_id + + @defer.inlineCallbacks + def _fetch_room_state(room_id): + room_state = yield self.state_handler.get_current_state_ids(room_id) + state_by_room[room_id] = room_state + + # Run at most 3 of these at once: sync does 10 at a time but email + # notifs are much less realtime than sync so we can afford to wait a bit. + yield concurrently_execute(_fetch_room_state, rooms_in_order, 3) + + # actually sort our so-called rooms_in_order list, most recent room first + rooms_in_order.sort( + key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0) + ) + + rooms = [] + + for r in rooms_in_order: + roomvars = yield self.get_room_vars( + r, user_id, notifs_by_room[r], notif_events, state_by_room[r] + ) + rooms.append(roomvars) + + reason['room_name'] = yield calculate_room_name( + self.store, state_by_room[reason['room_id']], user_id, + fallback_to_members=True + ) + + summary_text = yield self.make_summary_text( + notifs_by_room, state_by_room, notif_events, user_id, reason + ) + + template_vars = { + "user_display_name": user_display_name, + "unsubscribe_link": self.make_unsubscribe_link( + user_id, app_id, email_address + ), + "summary_text": summary_text, + "app_name": self.app_name, + "rooms": rooms, + "reason": reason, + } + + html_text = self.notif_template_html.render(**template_vars) + html_part = MIMEText(html_text, "html", "utf8") + + plain_text = self.notif_template_text.render(**template_vars) + text_part = MIMEText(plain_text, "plain", "utf8") + + multipart_msg = MIMEMultipart('alternative') + multipart_msg['Subject'] = "[%s] %s" % (self.app_name, summary_text) + multipart_msg['From'] = from_string + multipart_msg['To'] = email_address + multipart_msg['Date'] = email.utils.formatdate() + multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg.attach(text_part) + multipart_msg.attach(html_part) + + logger.info("Sending email push notification to %s" % email_address) + # logger.debug(html_text) + + yield sendmail( + self.hs.config.email_smtp_host, + raw_from, raw_to, multipart_msg.as_string(), + port=self.hs.config.email_smtp_port + ) + + @defer.inlineCallbacks + def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): + my_member_event_id = room_state_ids[("m.room.member", user_id)] + my_member_event = yield self.store.get_event(my_member_event_id) + is_invite = my_member_event.content["membership"] == "invite" + + room_name = yield calculate_room_name(self.store, room_state_ids, user_id) + + room_vars = { + "title": room_name, + "hash": string_ordinal_total(room_id), # See sender avatar hash + "notifs": [], + "invite": is_invite, + "link": self.make_room_link(room_id), + } + + if not is_invite: + for n in notifs: + notifvars = yield self.get_notif_vars( + n, user_id, notif_events[n['event_id']], room_state_ids + ) + + # merge overlapping notifs together. + # relies on the notifs being in chronological order. + merge = False + if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]: + prev_messages = room_vars['notifs'][-1]['messages'] + for message in notifvars['messages']: + pm = filter(lambda pm: pm['id'] == message['id'], prev_messages) + if pm: + if not message["is_historical"]: + pm[0]["is_historical"] = False + merge = True + elif merge: + # we're merging, so append any remaining messages + # in this notif to the previous one + prev_messages.append(message) + + if not merge: + room_vars['notifs'].append(notifvars) + + defer.returnValue(room_vars) + + @defer.inlineCallbacks + def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): + results = yield self.store.get_events_around( + notif['room_id'], notif['event_id'], + before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER + ) + + ret = { + "link": self.make_notif_link(notif), + "ts": notif['received_ts'], + "messages": [], + } + + the_events = yield filter_events_for_client( + self.store, user_id, results["events_before"] + ) + the_events.append(notif_event) + + for event in the_events: + messagevars = yield self.get_message_vars(notif, event, room_state_ids) + if messagevars is not None: + ret['messages'].append(messagevars) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_message_vars(self, notif, event, room_state_ids): + if event.type != EventTypes.Message: + return + + sender_state_event_id = room_state_ids[("m.room.member", event.sender)] + sender_state_event = yield self.store.get_event(sender_state_event_id) + sender_name = name_from_member_event(sender_state_event) + sender_avatar_url = sender_state_event.content.get("avatar_url") + + # 'hash' for deterministically picking default images: use + # sender_hash % the number of default images to choose from + sender_hash = string_ordinal_total(event.sender) + + msgtype = event.content.get("msgtype") + + ret = { + "msgtype": msgtype, + "is_historical": event.event_id != notif['event_id'], + "id": event.event_id, + "ts": event.origin_server_ts, + "sender_name": sender_name, + "sender_avatar_url": sender_avatar_url, + "sender_hash": sender_hash, + } + + if msgtype == "m.text": + self.add_text_message_vars(ret, event) + elif msgtype == "m.image": + self.add_image_message_vars(ret, event) + + if "body" in event.content: + ret["body_text_plain"] = event.content["body"] + + defer.returnValue(ret) + + def add_text_message_vars(self, messagevars, event): + msgformat = event.content.get("format") + + messagevars["format"] = msgformat + + formatted_body = event.content.get("formatted_body") + body = event.content.get("body") + + if msgformat == "org.matrix.custom.html" and formatted_body: + messagevars["body_text_html"] = safe_markup(formatted_body) + elif body: + messagevars["body_text_html"] = safe_text(body) + + return messagevars + + def add_image_message_vars(self, messagevars, event): + messagevars["image_url"] = event.content["url"] + + return messagevars + + @defer.inlineCallbacks + def make_summary_text(self, notifs_by_room, room_state_ids, + notif_events, user_id, reason): + if len(notifs_by_room) == 1: + # Only one room has new stuff + room_id = notifs_by_room.keys()[0] + + # If the room has some kind of name, use it, but we don't + # want the generated-from-names one here otherwise we'll + # end up with, "new message from Bob in the Bob room" + room_name = yield calculate_room_name( + self.store, room_state_ids[room_id], user_id, fallback_to_members=False + ) + + my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)] + my_member_event = yield self.store.get_event(my_member_event_id) + if my_member_event.content["membership"] == "invite": + inviter_member_event_id = room_state_ids[room_id][ + ("m.room.member", my_member_event.sender) + ] + inviter_member_event = yield self.store.get_event( + inviter_member_event_id + ) + inviter_name = name_from_member_event(inviter_member_event) + + if room_name is None: + defer.returnValue(INVITE_FROM_PERSON % { + "person": inviter_name, + "app": self.app_name + }) + else: + defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % { + "person": inviter_name, + "room": room_name, + "app": self.app_name, + }) + + sender_name = None + if len(notifs_by_room[room_id]) == 1: + # There is just the one notification, so give some detail + event = notif_events[notifs_by_room[room_id][0]["event_id"]] + if ("m.room.member", event.sender) in room_state_ids[room_id]: + state_event_id = room_state_ids[room_id][ + ("m.room.member", event.sender) + ] + state_event = yield self.store.get_event(state_event_id) + sender_name = name_from_member_event(state_event) + + if sender_name is not None and room_name is not None: + defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % { + "person": sender_name, + "room": room_name, + "app": self.app_name, + }) + elif sender_name is not None: + defer.returnValue(MESSAGE_FROM_PERSON % { + "person": sender_name, + "app": self.app_name, + }) + else: + # There's more than one notification for this room, so just + # say there are several + if room_name is not None: + defer.returnValue(MESSAGES_IN_ROOM % { + "room": room_name, + "app": self.app_name, + }) + else: + # If the room doesn't have a name, say who the messages + # are from explicitly to avoid, "messages in the Bob room" + sender_ids = list(set([ + notif_events[n['event_id']].sender + for n in notifs_by_room[room_id] + ])) + + member_events = yield self.store.get_events([ + room_state_ids[room_id][("m.room.member", s)] + for s in sender_ids + ]) + + defer.returnValue(MESSAGES_FROM_PERSON % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + }) + else: + # Stuff's happened in multiple different rooms + + # ...but we still refer to the 'reason' room which triggered the mail + if reason['room_name'] is not None: + defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % { + "room": reason['room_name'], + "app": self.app_name, + }) + else: + # If the reason room doesn't have a name, say who the messages + # are from explicitly to avoid, "messages in the Bob room" + sender_ids = list(set([ + notif_events[n['event_id']].sender + for n in notifs_by_room[reason['room_id']] + ])) + + member_events = yield self.store.get_events([ + room_state_ids[room_id][("m.room.member", s)] + for s in sender_ids + ]) + + defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + }) + + def make_room_link(self, room_id): + # need /beta for Universal Links to work on iOS + if self.app_name == "Vector": + return "https://vector.im/beta/#/room/%s" % (room_id,) + else: + return "https://matrix.to/#/%s" % (room_id,) + + def make_notif_link(self, notif): + # need /beta for Universal Links to work on iOS + if self.app_name == "Vector": + return "https://vector.im/beta/#/room/%s/%s" % ( + notif['room_id'], notif['event_id'] + ) + else: + return "https://matrix.to/#/%s/%s" % ( + notif['room_id'], notif['event_id'] + ) + + def make_unsubscribe_link(self, user_id, app_id, email_address): + params = { + "access_token": self.auth_handler.generate_delete_pusher_token(user_id), + "app_id": app_id, + "pushkey": email_address, + } + + # XXX: make r0 once API is stable + return "%s_matrix/client/unstable/pushers/remove?%s" % ( + self.hs.config.public_baseurl, + urllib.urlencode(params), + ) + + def mxc_to_http_filter(self, value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + serverAndMediaId = value[6:] + fragment = None + if '#' in serverAndMediaId: + (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) + fragment = "#" + fragment + + params = { + "width": width, + "height": height, + "method": resize_method, + } + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + self.hs.config.public_baseurl, + serverAndMediaId, + urllib.urlencode(params), + fragment or "", + ) + + +def safe_markup(raw_html): + return jinja2.Markup(bleach.linkify(bleach.clean( + raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS, + # bleach master has this, but it isn't released yet + # protocols=ALLOWED_SCHEMES, + strip=True + ))) + + +def safe_text(raw_text): + """ + Process text: treat it as HTML but escape any tags (ie. just escape the + HTML) then linkify it. + """ + return jinja2.Markup(bleach.linkify(bleach.clean( + raw_text, tags=[], attributes={}, + strip=False + ))) + + +def deduped_ordered_list(l): + seen = set() + ret = [] + for item in l: + if item not in seen: + seen.add(item) + ret.append(item) + return ret + + +def string_ordinal_total(s): + tot = 0 + for c in s: + tot += ord(c) + return tot + + +def format_ts_filter(value, format): + return time.strftime(format, time.localtime(value / 1000)) diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py new file mode 100644 index 0000000000..277da3cd35 --- /dev/null +++ b/synapse/push/presentable_names.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import re +import logging + +logger = logging.getLogger(__name__) + +# intentionally looser than what aliases we allow to be registered since +# other HSes may allow aliases that we would not +ALIAS_RE = re.compile(r"^#.*:.+$") + +ALL_ALONE = "Empty Room" + + +@defer.inlineCallbacks +def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True, + fallback_to_single_member=True): + """ + Works out a user-facing name for the given room as per Matrix + spec recommendations. + Does not yet support internationalisation. + Args: + room_state: Dictionary of the room's state + user_id: The ID of the user to whom the room name is being presented + fallback_to_members: If False, return None instead of generating a name + based on the room's members if the room has no + title or aliases. + + Returns: + (string or None) A human readable name for the room. + """ + # does it have a name? + if ("m.room.name", "") in room_state_ids: + m_room_name = yield store.get_event( + room_state_ids[("m.room.name", "")], allow_none=True + ) + if m_room_name and m_room_name.content and m_room_name.content["name"]: + defer.returnValue(m_room_name.content["name"]) + + # does it have a canonical alias? + if ("m.room.canonical_alias", "") in room_state_ids: + canon_alias = yield store.get_event( + room_state_ids[("m.room.canonical_alias", "")], allow_none=True + ) + if ( + canon_alias and canon_alias.content and canon_alias.content["alias"] and + _looks_like_an_alias(canon_alias.content["alias"]) + ): + defer.returnValue(canon_alias.content["alias"]) + + # at this point we're going to need to search the state by all state keys + # for an event type, so rearrange the data structure + room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) + + # right then, any aliases at all? + if "m.room.aliases" in room_state_bytype_ids: + m_room_aliases = room_state_bytype_ids["m.room.aliases"] + for alias_id in m_room_aliases.values(): + alias_event = yield store.get_event( + alias_id, allow_none=True + ) + if alias_event and alias_event.content.get("aliases"): + the_aliases = alias_event.content["aliases"] + if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): + defer.returnValue(the_aliases[0]) + + if not fallback_to_members: + defer.returnValue(None) + + my_member_event = None + if ("m.room.member", user_id) in room_state_ids: + my_member_event = yield store.get_event( + room_state_ids[("m.room.member", user_id)], allow_none=True + ) + + if ( + my_member_event is not None and + my_member_event.content['membership'] == "invite" + ): + if ("m.room.member", my_member_event.sender) in room_state_ids: + inviter_member_event = yield store.get_event( + room_state_ids[("m.room.member", my_member_event.sender)], + allow_none=True, + ) + if inviter_member_event: + if fallback_to_single_member: + defer.returnValue( + "Invite from %s" % ( + name_from_member_event(inviter_member_event), + ) + ) + else: + return + else: + defer.returnValue("Room Invite") + + # we're going to have to generate a name based on who's in the room, + # so find out who is in the room that isn't the user. + if "m.room.member" in room_state_bytype_ids: + member_events = yield store.get_events( + room_state_bytype_ids["m.room.member"].values() + ) + all_members = [ + ev for ev in member_events.values() + if ev.content['membership'] == "join" or ev.content['membership'] == "invite" + ] + # Sort the member events oldest-first so the we name people in the + # order the joined (it should at least be deterministic rather than + # dictionary iteration order) + all_members.sort(key=lambda e: e.origin_server_ts) + other_members = [m for m in all_members if m.state_key != user_id] + else: + other_members = [] + all_members = [] + + if len(other_members) == 0: + if len(all_members) == 1: + # self-chat, peeked room with 1 participant, + # or inbound invite, or outbound 3PID invite. + if all_members[0].sender == user_id: + if "m.room.third_party_invite" in room_state_bytype_ids: + third_party_invites = ( + room_state_bytype_ids["m.room.third_party_invite"].values() + ) + + if len(third_party_invites) > 0: + # technically third party invite events are not member + # events, but they are close enough + + # FIXME: no they're not - they look nothing like a member; + # they have a great big encrypted thing as their name to + # prevent leaking the 3PID name... + # return "Inviting %s" % ( + # descriptor_from_member_events(third_party_invites) + # ) + defer.returnValue("Inviting email address") + else: + defer.returnValue(ALL_ALONE) + else: + defer.returnValue(name_from_member_event(all_members[0])) + else: + defer.returnValue(ALL_ALONE) + elif len(other_members) == 1 and not fallback_to_single_member: + return + else: + defer.returnValue(descriptor_from_member_events(other_members)) + + +def descriptor_from_member_events(member_events): + if len(member_events) == 0: + return "nobody" + elif len(member_events) == 1: + return name_from_member_event(member_events[0]) + elif len(member_events) == 2: + return "%s and %s" % ( + name_from_member_event(member_events[0]), + name_from_member_event(member_events[1]), + ) + else: + return "%s and %d others" % ( + name_from_member_event(member_events[0]), + len(member_events) - 1, + ) + + +def name_from_member_event(member_event): + if ( + member_event.content and "displayname" in member_event.content and + member_event.content["displayname"] + ): + return member_event.content["displayname"] + return member_event.state_key + + +def _state_as_two_level_dict(state): + ret = {} + for k, v in state.items(): + ret.setdefault(k[0], {})[k[1]] = v + return ret + + +def _looks_like_an_alias(string): + return ALIAS_RE.match(string) is not None diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 51f73a5b78..4db76f18bd 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -13,12 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from .baserules import list_with_base_rules - import logging -import simplejson as json import re from synapse.types import UserID @@ -32,22 +27,6 @@ IS_GLOB = re.compile(r'[\?\*\[\]]') INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") -@defer.inlineCallbacks -def evaluator_for_user_id(user_id, room_id, store): - rawrules = yield store.get_push_rules_for_user(user_id) - enabled_map = yield store.get_push_rules_enabled_for_user(user_id) - our_member_event = yield store.get_current_state( - room_id=room_id, - event_type='m.room.member', - state_key=user_id, - ) - - defer.returnValue(PushRuleEvaluator( - user_id, rawrules, enabled_map, - room_id, our_member_event, store - )) - - def _room_member_count(ev, condition, room_member_count): if 'is' not in condition: return False @@ -74,110 +53,14 @@ def _room_member_count(ev, condition, room_member_count): return False -class PushRuleEvaluator: - DEFAULT_ACTIONS = [] - - def __init__(self, user_id, raw_rules, enabled_map, room_id, - our_member_event, store): - self.user_id = user_id - self.room_id = room_id - self.our_member_event = our_member_event - self.store = store - - rules = [] - for raw_rule in raw_rules: - rule = dict(raw_rule) - rule['conditions'] = json.loads(raw_rule['conditions']) - rule['actions'] = json.loads(raw_rule['actions']) - rules.append(rule) - - self.rules = list_with_base_rules(rules) - - self.enabled_map = enabled_map - - @staticmethod - def tweaks_for_actions(actions): - tweaks = {} - for a in actions: - if not isinstance(a, dict): - continue - if 'set_tweak' in a and 'value' in a: - tweaks[a['set_tweak']] = a['value'] - return tweaks - - @defer.inlineCallbacks - def actions_for_event(self, ev): - """ - This should take into account notification settings that the user - has configured both globally and per-room when we have the ability - to do such things. - """ - if ev['user_id'] == self.user_id: - # let's assume you probably know about messages you sent yourself - defer.returnValue([]) - - room_id = ev['room_id'] - - # get *our* member event for display name matching - my_display_name = None - - if self.our_member_event: - my_display_name = self.our_member_event[0].content.get("displayname") - - room_members = yield self.store.get_users_in_room(room_id) - room_member_count = len(room_members) - - evaluator = PushRuleEvaluatorForEvent(ev, room_member_count) - - for r in self.rules: - enabled = self.enabled_map.get(r['rule_id'], None) - if enabled is not None and not enabled: - continue - - if not r.get("enabled", True): - continue - - conditions = r['conditions'] - actions = r['actions'] - - # ignore rules with no actions (we have an explict 'dont_notify') - if len(actions) == 0: - logger.warn( - "Ignoring rule id %s with no actions for user %s", - r['rule_id'], self.user_id - ) - continue - - matches = True - for c in conditions: - matches = evaluator.matches( - c, self.user_id, my_display_name - ) - if not matches: - break - - logger.debug( - "Rule %s %s", - r['rule_id'], "matches" if matches else "doesn't match" - ) - - if matches: - logger.debug( - "%s matches for user %s, event %s", - r['rule_id'], self.user_id, ev['event_id'] - ) - - # filter out dont_notify as we treat an empty actions list - # as dont_notify, and this doesn't take up a row in our database - actions = [x for x in actions if x != 'dont_notify'] - - defer.returnValue(actions) - - logger.debug( - "No rules match for user %s, event %s", - self.user_id, ev['event_id'] - ) - defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS) +def tweaks_for_actions(actions): + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if 'set_tweak' in a and 'value' in a: + tweaks[a['set_tweak']] = a['value'] + return tweaks class PushRuleEvaluatorForEvent(object): diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py new file mode 100644 index 0000000000..b47bf1f92b --- /dev/null +++ b/synapse/push/push_tools.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from synapse.push.presentable_names import ( + calculate_room_name, name_from_member_event +) +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred + + +@defer.inlineCallbacks +def get_badge_count(store, user_id): + invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(store.get_invited_rooms_for_user)(user_id), + preserve_fn(store.get_rooms_for_user)(user_id), + ], consumeErrors=True)) + + my_receipts_by_room = yield store.get_receipts_for_user( + user_id, "m.read", + ) + + badge = len(invites) + + for r in joins: + if r.room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[r.room_id] + + notifs = yield ( + store.get_unread_event_push_actions_by_room_for_user( + r.room_id, user_id, last_unread_event_id + ) + ) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if notifs["notify_count"] else 0 + defer.returnValue(badge) + + +@defer.inlineCallbacks +def get_context_for_event(store, state_handler, ev, user_id): + ctx = {} + + room_state_ids = yield state_handler.get_current_state_ids(ev.room_id) + + # we no longer bother setting room_alias, and make room_name the + # human-readable name instead, be that m.room.name, an alias or + # a list of people in the room + name = yield calculate_room_name( + store, room_state_ids, user_id, fallback_to_single_member=False + ) + if name: + ctx['name'] = name + + sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] + sender_state_event = yield store.get_event(sender_state_event_id) + ctx['sender_display_name'] = name_from_member_event(sender_state_event) + + defer.returnValue(ctx) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py new file mode 100644 index 0000000000..de9c33b936 --- /dev/null +++ b/synapse/push/pusher.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from httppusher import HttpPusher + +import logging +logger = logging.getLogger(__name__) + +# We try importing this if we can (it will fail if we don't +# have the optional email dependencies installed). We don't +# yet have the config to know if we need the email pusher, +# but importing this after daemonizing seems to fail +# (even though a simple test of importing from a daemonized +# process works fine) +try: + from synapse.push.emailpusher import EmailPusher +except: + pass + + +def create_pusher(hs, pusherdict): + logger.info("trying to create_pusher for %r", pusherdict) + + PUSHER_TYPES = { + "http": HttpPusher, + } + + logger.info("email enable notifs: %r", hs.config.email_enable_notifs) + if hs.config.email_enable_notifs: + PUSHER_TYPES["email"] = EmailPusher + logger.info("defined email pusher type") + + if pusherdict['kind'] in PUSHER_TYPES: + logger.info("found pusher") + return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0b463c6fdb..3837be523d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -16,9 +16,9 @@ from twisted.internet import defer -from .httppusher import HttpPusher -from synapse.push import PusherConfigException -from synapse.util.logcontext import preserve_fn +import pusher +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.async import run_on_reactor import logging @@ -28,10 +28,10 @@ logger = logging.getLogger(__name__) class PusherPool: def __init__(self, _hs): self.hs = _hs + self.start_pushers = _hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.pushers = {} - self.last_pusher_started = -1 @defer.inlineCallbacks def start(self): @@ -48,7 +48,8 @@ class PusherPool: # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. - self._create_pusher({ + pusher.create_pusher(self.hs, { + "id": None, "user_name": user_id, "kind": kind, "app_id": app_id, @@ -58,10 +59,18 @@ class PusherPool: "ts": time_now_msec, "lang": lang, "data": data, - "last_token": None, + "last_stream_ordering": None, "last_success": None, "failing_since": None }) + + # create the pusher setting last_stream_ordering to the current maximum + # stream ordering in event_push_actions, so it will process + # pushes from this point onwards. + last_stream_ordering = ( + yield self.store.get_latest_push_action_stream_ordering() + ) + yield self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -73,6 +82,7 @@ class PusherPool: pushkey_ts=time_now_msec, lang=lang, data=data, + last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) yield self._refresh_pusher(app_id, pushkey, user_id) @@ -92,40 +102,67 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id, except_token_ids=[]): + def remove_pushers_by_user(self, user_id, except_access_token_id=None): all = yield self.store.get_all_pushers() logger.info( - "Removing all pushers for user %s except access tokens ids %r", - user_id, except_token_ids + "Removing all pushers for user %s except access tokens id %r", + user_id, except_access_token_id ) for p in all: - if p['user_name'] == user_id and p['access_token'] not in except_token_ids: + if p['user_name'] == user_id and p['access_token'] != except_access_token_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] ) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) - def _create_pusher(self, pusherdict): - if pusherdict['kind'] == 'http': - return HttpPusher( - self.hs, - user_id=pusherdict['user_name'], - app_id=pusherdict['app_id'], - app_display_name=pusherdict['app_display_name'], - device_display_name=pusherdict['device_display_name'], - pushkey=pusherdict['pushkey'], - pushkey_ts=pusherdict['ts'], - data=pusherdict['data'], - last_token=pusherdict['last_token'], - last_success=pusherdict['last_success'], - failing_since=pusherdict['failing_since'] + @defer.inlineCallbacks + def on_new_notifications(self, min_stream_id, max_stream_id): + yield run_on_reactor() + try: + users_affected = yield self.store.get_push_action_users_in_range( + min_stream_id, max_stream_id ) - else: - raise PusherConfigException( - "Unknown pusher type '%s' for user %s" % - (pusherdict['kind'], pusherdict['user_name']) + + deferreds = [] + + for u in users_affected: + if u in self.pushers: + for p in self.pushers[u].values(): + deferreds.append( + preserve_fn(p.on_new_notifications)( + min_stream_id, max_stream_id + ) + ) + + yield preserve_context_over_deferred(defer.gatherResults(deferreds)) + except: + logger.exception("Exception in pusher on_new_notifications") + + @defer.inlineCallbacks + def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + yield run_on_reactor() + try: + # Need to subtract 1 from the minimum because the lower bound here + # is not inclusive + updated_receipts = yield self.store.get_all_updated_receipts( + min_stream_id - 1, max_stream_id ) + # This returns a tuple, user_id is at index 3 + users_affected = set([r[3] for r in updated_receipts]) + + deferreds = [] + + for u in users_affected: + if u in self.pushers: + for p in self.pushers[u].values(): + deferreds.append( + preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) + ) + + yield preserve_context_over_deferred(defer.gatherResults(deferreds)) + except: + logger.exception("Exception in pusher on_new_receipts") @defer.inlineCallbacks def _refresh_pusher(self, app_id, pushkey, user_id): @@ -143,33 +180,40 @@ class PusherPool: self._start_pushers([p]) def _start_pushers(self, pushers): + if not self.start_pushers: + logger.info("Not starting pushers because they are disabled in the config") + return logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: try: - p = self._create_pusher(pusherdict) - except PusherConfigException: - logger.exception("Couldn't start a pusher: caught PusherConfigException") + p = pusher.create_pusher(self.hs, pusherdict) + except: + logger.exception("Couldn't start a pusher: caught Exception") continue if p: - fullid = "%s:%s:%s" % ( + appid_pushkey = "%s:%s" % ( pusherdict['app_id'], pusherdict['pushkey'], - pusherdict['user_name'] ) - if fullid in self.pushers: - self.pushers[fullid].stop() - self.pushers[fullid] = p - preserve_fn(p.start)() + byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + preserve_fn(p.on_started)() logger.info("Started pushers") @defer.inlineCallbacks def remove_pusher(self, app_id, pushkey, user_id): - fullid = "%s:%s:%s" % (app_id, pushkey, user_id) - if fullid in self.pushers: - logger.info("Stopping pusher %s", fullid) - self.pushers[fullid].stop() - del self.pushers[fullid] + appid_pushkey = "%s:%s" % (app_id, pushkey) + + byuser = self.pushers.get(user_id, {}) + + if appid_pushkey in byuser: + logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) + byuser[appid_pushkey].on_stop() + del byuser[appid_pushkey] yield self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 0a6043ae8d..3742a25b37 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -36,11 +36,25 @@ REQUIREMENTS = { "blist": ["blist"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pymacaroons-pynacl": ["pymacaroons"], + "msgpack-python>=0.3.0": ["msgpack"], } CONDITIONAL_REQUIREMENTS = { "web_client": { "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"], - } + }, + "preview_url": { + "netaddr>=0.7.18": ["netaddr"], + }, + "email.enable_notifs": { + "Jinja2>=2.8": ["Jinja2>=2.8"], + "bleach>=1.4.2": ["bleach>=1.4.2"], + }, + "matrix-synapse-ldap3": { + "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], + }, + "psutil": { + "psutil>=2.0.0": ["psutil>=2.0.0"], + }, } @@ -55,6 +69,7 @@ def requirements(config=None, include_conditional=False): def github_link(project, version, egg): return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg) + DEPENDENCY_LINKS = { } @@ -142,6 +157,7 @@ def list_requirements(): result.append(requirement) return result + if __name__ == "__main__": import sys sys.stdout.writelines(req + "\n" for req in list_requirements()) diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py new file mode 100644 index 0000000000..c05a50d7a6 --- /dev/null +++ b/synapse/replication/expire_cache.py @@ -0,0 +1,60 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json_bytes, request_handler +from synapse.http.servlet import parse_json_object_from_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + + +class ExpireCacheResource(Resource): + """ + HTTP endpoint for expiring storage caches. + + POST /_synapse/replication/expire_cache HTTP/1.1 + Content-Type: application/json + + { + "invalidate": [ + { + "name": "func_name", + "keys": ["key1", "key2"] + } + ] + } + """ + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.store = hs.get_datastore() + self.version_string = hs.version_string + self.clock = hs.get_clock() + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @request_handler() + def _async_render_POST(self, request): + content = parse_json_object_from_request(request) + + for row in content["invalidate"]: + name = row["name"] + keys = tuple(row["keys"]) + + getattr(self.store, name).invalidate(keys) + + respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/presence_resource.py b/synapse/replication/presence_resource.py new file mode 100644 index 0000000000..fc18130ab4 --- /dev/null +++ b/synapse/replication/presence_resource.py @@ -0,0 +1,59 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json_bytes, request_handler +from synapse.http.servlet import parse_json_object_from_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + + +class PresenceResource(Resource): + """ + HTTP endpoint for marking users as syncing. + + POST /_synapse/replication/presence HTTP/1.1 + Content-Type: application/json + + { + "process_id": "<process_id>", + "syncing_users": ["<user_id>"] + } + """ + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.clock = hs.get_clock() + self.presence_handler = hs.get_presence_handler() + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @request_handler() + @defer.inlineCallbacks + def _async_render_POST(self, request): + content = parse_json_object_from_request(request) + + process_id = content["process_id"] + syncing_user_ids = content["syncing_users"] + + yield self.presence_handler.update_external_syncs( + process_id, set(syncing_user_ids) + ) + + respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/pusher_resource.py b/synapse/replication/pusher_resource.py new file mode 100644 index 0000000000..9b01ab3c13 --- /dev/null +++ b/synapse/replication/pusher_resource.py @@ -0,0 +1,54 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json_bytes, request_handler +from synapse.http.servlet import parse_json_object_from_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + + +class PusherResource(Resource): + """ + HTTP endpoint for deleting rejected pushers + """ + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.clock = hs.get_clock() + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @request_handler() + @defer.inlineCallbacks + def _async_render_POST(self, request): + content = parse_json_object_from_request(request) + + for remove in content["remove"]: + yield self.store.delete_pusher_by_app_id_pushkey_user_id( + remove["app_id"], + remove["push_key"], + remove["user_id"], + ) + + self.notifier.on_new_replication_data() + + respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c1ae0fbc7..d79b421cba 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -15,6 +15,10 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.http.server import request_handler, finish_request +from synapse.replication.pusher_resource import PusherResource +from synapse.replication.presence_resource import PresenceResource +from synapse.replication.expire_cache import ExpireCacheResource +from synapse.api.errors import SynapseError from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -38,6 +42,10 @@ STREAM_NAMES = ( ("backfill",), ("push_rules",), ("pushers",), + ("caches",), + ("to_device",), + ("public_rooms",), + ("federation",), ) @@ -67,6 +75,7 @@ class ReplicationResource(Resource): * "backfill": Old events that have been backfilled from other servers. * "push_rules": Per user changes to push rules. * "pushers": Per user changes to their pushers. + * "caches": Cache invalidations. The API takes two additional query parameters: @@ -76,7 +85,7 @@ class ReplicationResource(Resource): The response is a JSON object with keys for each stream with updates. Under each key is a JSON object with: - * "postion": The current position of the stream. + * "position": The current position of the stream. * "field_names": The names of the fields in each row. * "rows": The updates as an array of arrays. @@ -101,17 +110,22 @@ class ReplicationResource(Resource): long-polling this replication API for new data on those streams. """ - isLeaf = True - def __init__(self, hs): Resource.__init__(self) # Resource is old-style, so no super() self.version_string = hs.version_string self.store = hs.get_datastore() self.sources = hs.get_event_sources() - self.presence_handler = hs.get_handlers().presence_handler - self.typing_handler = hs.get_handlers().typing_notification_handler + self.presence_handler = hs.get_presence_handler() + self.typing_handler = hs.get_typing_handler() + self.federation_sender = hs.get_federation_sender() self.notifier = hs.notifier + self.clock = hs.get_clock() + self.config = hs.get_config() + + self.putChild("remove_pushers", PusherResource(hs)) + self.putChild("syncing_users", PresenceResource(hs)) + self.putChild("expire_cache", ExpireCacheResource(hs)) def render_GET(self, request): self._async_render_GET(request) @@ -123,6 +137,9 @@ class ReplicationResource(Resource): backfill_token = yield self.store.get_current_backfill_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() + caches_token = self.store.get_cache_stream_token() + public_rooms_token = self.store.get_current_public_room_stream_id() + federation_token = self.federation_sender.get_current_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -133,40 +150,78 @@ class ReplicationResource(Resource): backfill_token, push_rules_token, pushers_token, + 0, # State stream is no longer a thing + caches_token, + int(stream_token.to_device_key), + int(public_rooms_token), + int(federation_token), )) - @request_handler + @request_handler() @defer.inlineCallbacks def _async_render_GET(self, request): limit = parse_integer(request, "limit", 100) timeout = parse_integer(request, "timeout", 10 * 1000) request.setHeader(b"Content-Type", b"application/json") - writer = _Writer(request) - @defer.inlineCallbacks - def replicate(): - current_token = yield self.current_replication_token() - logger.info("Replicating up to %r", current_token) + request_streams = { + name: parse_integer(request, name) + for names in STREAM_NAMES for name in names + } + request_streams["streams"] = parse_string(request, "streams") - yield self.account_data(writer, current_token, limit) - yield self.events(writer, current_token, limit) - yield self.presence(writer, current_token) # TODO: implement limit - yield self.typing(writer, current_token) # TODO: implement limit - yield self.receipts(writer, current_token, limit) - yield self.push_rules(writer, current_token, limit) - yield self.pushers(writer, current_token, limit) - self.streams(writer, current_token) + federation_ack = parse_integer(request, "federation_ack", None) - logger.info("Replicated %d rows", writer.total) - defer.returnValue(writer.total) + def replicate(): + return self.replicate( + request_streams, limit, + federation_ack=federation_ack + ) - yield self.notifier.wait_for_replication(replicate, timeout) + writer = yield self.notifier.wait_for_replication(replicate, timeout) + result = writer.finish() - writer.finish() + for stream_name, stream_content in result.items(): + logger.info( + "Replicating %d rows of %s from %s -> %s", + len(stream_content["rows"]), + stream_name, + request_streams.get(stream_name), + stream_content["position"], + ) - def streams(self, writer, current_token): - request_token = parse_string(writer.request, "streams") + request.write(json.dumps(result, ensure_ascii=False)) + finish_request(request) + + @defer.inlineCallbacks + def replicate(self, request_streams, limit, federation_ack=None): + writer = _Writer() + current_token = yield self.current_replication_token() + logger.debug("Replicating up to %r", current_token) + + if limit == 0: + raise SynapseError(400, "Limit cannot be 0") + + yield self.account_data(writer, current_token, limit, request_streams) + yield self.events(writer, current_token, limit, request_streams) + # TODO: implement limit + yield self.presence(writer, current_token, request_streams) + yield self.typing(writer, current_token, request_streams) + yield self.receipts(writer, current_token, limit, request_streams) + yield self.push_rules(writer, current_token, limit, request_streams) + yield self.pushers(writer, current_token, limit, request_streams) + yield self.caches(writer, current_token, limit, request_streams) + yield self.to_device(writer, current_token, limit, request_streams) + yield self.public_rooms(writer, current_token, limit, request_streams) + self.federation(writer, current_token, limit, request_streams, federation_ack) + self.streams(writer, current_token, request_streams) + + logger.debug("Replicated %d rows", writer.total) + defer.returnValue(writer) + + def streams(self, writer, current_token, request_streams): + request_token = request_streams.get("streams") streams = [] @@ -191,166 +246,288 @@ class ReplicationResource(Resource): ) @defer.inlineCallbacks - def events(self, writer, current_token, limit): - request_events = parse_integer(writer.request, "events") - request_backfill = parse_integer(writer.request, "backfill") + def events(self, writer, current_token, limit, request_streams): + request_events = request_streams.get("events") + request_backfill = request_streams.get("backfill") if request_events is not None or request_backfill is not None: if request_events is None: request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill - events_rows, backfill_rows = yield self.store.get_all_new_events( + + no_new_tokens = ( + request_events == current_token.events + and request_backfill == current_token.backfill + ) + if no_new_tokens: + return + + res = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit ) + + upto_events_token = _position_from_rows( + res.new_forward_events, current_token.events + ) + + upto_backfill_token = _position_from_rows( + res.new_backfill_events, current_token.backfill + ) + + if request_events != upto_events_token: + writer.write_header_and_rows("events", res.new_forward_events, ( + "position", "internal", "json", "state_group" + ), position=upto_events_token) + + if request_backfill != upto_backfill_token: + writer.write_header_and_rows("backfill", res.new_backfill_events, ( + "position", "internal", "json", "state_group", + ), position=upto_backfill_token) + writer.write_header_and_rows( - "events", events_rows, ("position", "internal", "json") + "forward_ex_outliers", res.forward_ex_outliers, + ("position", "event_id", "state_group"), ) writer.write_header_and_rows( - "backfill", backfill_rows, ("position", "internal", "json") + "backward_ex_outliers", res.backward_ex_outliers, + ("position", "event_id", "state_group"), + ) + writer.write_header_and_rows( + "state_resets", res.state_resets, ("position",), ) @defer.inlineCallbacks - def presence(self, writer, current_token): + def presence(self, writer, current_token, request_streams): current_position = current_token.presence - request_presence = parse_integer(writer.request, "presence") + request_presence = request_streams.get("presence") - if request_presence is not None: + if request_presence is not None and request_presence != current_position: presence_rows = yield self.presence_handler.get_all_presence_updates( request_presence, current_position ) + upto_token = _position_from_rows(presence_rows, current_position) writer.write_header_and_rows("presence", presence_rows, ( "position", "user_id", "state", "last_active_ts", "last_federation_update_ts", "last_user_sync_ts", "status_msg", "currently_active", - )) + ), position=upto_token) @defer.inlineCallbacks - def typing(self, writer, current_token): - current_position = current_token.presence + def typing(self, writer, current_token, request_streams): + current_position = current_token.typing - request_typing = parse_integer(writer.request, "typing") + request_typing = request_streams.get("typing") + + if request_typing is not None and request_typing != current_position: + # If they have a higher token than current max, we can assume that + # they had been talking to a previous instance of the master. Since + # we reset the token on restart, the best (but hacky) thing we can + # do is to simply resend down all the typing notifications. + if request_typing > current_position: + request_typing = 0 - if request_typing is not None: typing_rows = yield self.typing_handler.get_all_typing_updates( request_typing, current_position ) + upto_token = _position_from_rows(typing_rows, current_position) writer.write_header_and_rows("typing", typing_rows, ( "position", "room_id", "typing" - )) + ), position=upto_token) @defer.inlineCallbacks - def receipts(self, writer, current_token, limit): + def receipts(self, writer, current_token, limit, request_streams): current_position = current_token.receipts - request_receipts = parse_integer(writer.request, "receipts") + request_receipts = request_streams.get("receipts") - if request_receipts is not None: + if request_receipts is not None and request_receipts != current_position: receipts_rows = yield self.store.get_all_updated_receipts( request_receipts, current_position, limit ) + upto_token = _position_from_rows(receipts_rows, current_position) writer.write_header_and_rows("receipts", receipts_rows, ( "position", "room_id", "receipt_type", "user_id", "event_id", "data" - )) + ), position=upto_token) @defer.inlineCallbacks - def account_data(self, writer, current_token, limit): + def account_data(self, writer, current_token, limit, request_streams): current_position = current_token.account_data - user_account_data = parse_integer(writer.request, "user_account_data") - room_account_data = parse_integer(writer.request, "room_account_data") - tag_account_data = parse_integer(writer.request, "tag_account_data") + user_account_data = request_streams.get("user_account_data") + room_account_data = request_streams.get("room_account_data") + tag_account_data = request_streams.get("tag_account_data") if user_account_data is not None or room_account_data is not None: if user_account_data is None: user_account_data = current_position if room_account_data is None: room_account_data = current_position + + no_new_tokens = ( + user_account_data == current_position + and room_account_data == current_position + ) + if no_new_tokens: + return + user_rows, room_rows = yield self.store.get_all_updated_account_data( user_account_data, room_account_data, current_position, limit ) + + upto_users_token = _position_from_rows(user_rows, current_position) + upto_rooms_token = _position_from_rows(room_rows, current_position) + writer.write_header_and_rows("user_account_data", user_rows, ( "position", "user_id", "type", "content" - )) + ), position=upto_users_token) writer.write_header_and_rows("room_account_data", room_rows, ( "position", "user_id", "room_id", "type", "content" - )) + ), position=upto_rooms_token) if tag_account_data is not None: tag_rows = yield self.store.get_all_updated_tags( tag_account_data, current_position, limit ) + upto_tag_token = _position_from_rows(tag_rows, current_position) writer.write_header_and_rows("tag_account_data", tag_rows, ( "position", "user_id", "room_id", "tags" - )) + ), position=upto_tag_token) @defer.inlineCallbacks - def push_rules(self, writer, current_token, limit): + def push_rules(self, writer, current_token, limit, request_streams): current_position = current_token.push_rules - push_rules = parse_integer(writer.request, "push_rules") + push_rules = request_streams.get("push_rules") - if push_rules is not None: + if push_rules is not None and push_rules != current_position: rows = yield self.store.get_all_push_rule_updates( push_rules, current_position, limit ) + upto_token = _position_from_rows(rows, current_position) writer.write_header_and_rows("push_rules", rows, ( "position", "event_stream_ordering", "user_id", "rule_id", "op", "priority_class", "priority", "conditions", "actions" - )) + ), position=upto_token) @defer.inlineCallbacks - def pushers(self, writer, current_token, limit): + def pushers(self, writer, current_token, limit, request_streams): current_position = current_token.pushers - pushers = parse_integer(writer.request, "pushers") - if pushers is not None: + pushers = request_streams.get("pushers") + + if pushers is not None and pushers != current_position: updated, deleted = yield self.store.get_all_updated_pushers( pushers, current_position, limit ) + upto_token = _position_from_rows(updated, current_position) writer.write_header_and_rows("pushers", updated, ( "position", "user_id", "access_token", "profile_tag", "kind", "app_id", "app_display_name", "device_display_name", "pushkey", "ts", "lang", "data" - )) - writer.write_header_and_rows("deleted", deleted, ( + ), position=upto_token) + writer.write_header_and_rows("deleted_pushers", deleted, ( "position", "user_id", "app_id", "pushkey" - )) + ), position=upto_token) + + @defer.inlineCallbacks + def caches(self, writer, current_token, limit, request_streams): + current_position = current_token.caches + + caches = request_streams.get("caches") + + if caches is not None and caches != current_position: + updated_caches = yield self.store.get_all_updated_caches( + caches, current_position, limit + ) + upto_token = _position_from_rows(updated_caches, current_position) + writer.write_header_and_rows("caches", updated_caches, ( + "position", "cache_func", "keys", "invalidation_ts" + ), position=upto_token) + + @defer.inlineCallbacks + def to_device(self, writer, current_token, limit, request_streams): + current_position = current_token.to_device + + to_device = request_streams.get("to_device") + + if to_device is not None and to_device != current_position: + to_device_rows = yield self.store.get_all_new_device_messages( + to_device, current_position, limit + ) + upto_token = _position_from_rows(to_device_rows, current_position) + writer.write_header_and_rows("to_device", to_device_rows, ( + "position", "user_id", "device_id", "message_json" + ), position=upto_token) + + @defer.inlineCallbacks + def public_rooms(self, writer, current_token, limit, request_streams): + current_position = current_token.public_rooms + + public_rooms = request_streams.get("public_rooms") + + if public_rooms is not None and public_rooms != current_position: + public_rooms_rows = yield self.store.get_all_new_public_rooms( + public_rooms, current_position, limit + ) + upto_token = _position_from_rows(public_rooms_rows, current_position) + writer.write_header_and_rows("public_rooms", public_rooms_rows, ( + "position", "room_id", "visibility" + ), position=upto_token) + + def federation(self, writer, current_token, limit, request_streams, federation_ack): + if self.config.send_federation: + return + + current_position = current_token.federation + + federation = request_streams.get("federation") + + if federation is not None and federation != current_position: + federation_rows = self.federation_sender.get_replication_rows( + federation, limit, federation_ack=federation_ack, + ) + upto_token = _position_from_rows(federation_rows, current_position) + writer.write_header_and_rows("federation", federation_rows, ( + "position", "type", "content", + ), position=upto_token) class _Writer(object): """Writes the streams as a JSON object as the response to the request""" - def __init__(self, request): + def __init__(self): self.streams = {} - self.request = request self.total = 0 def write_header_and_rows(self, name, rows, fields, position=None): - if not rows: - return - if position is None: - position = rows[-1][0] + if rows: + position = rows[-1][0] + else: + return self.streams[name] = { - "position": str(position), + "position": position if type(position) is int else str(position), "field_names": fields, "rows": rows, } self.total += len(rows) + def __nonzero__(self): + return bool(self.total) + def finish(self): - self.request.write(json.dumps(self.streams, ensure_ascii=False)) - finish_request(self.request) + return self.streams class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers" + "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", + "federation", ))): __slots__ = [] @@ -365,3 +542,20 @@ class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( def __str__(self): return "_".join(str(value) for value in self) + + +def _position_from_rows(rows, current_position): + """Calculates a position to return for a stream. Ideally we want to return the + position of the last row, as that will be the most correct. However, if there + are no rows we fall back to using the current position to stop us from + repeatedly hitting the storage layer unncessarily thinking there are updates. + (Not all advances of the token correspond to an actual update) + + We can't just always return the current position, as we often limit the + number of rows we replicate, and so the stream may lag. The assumption is + that if the storage layer returns no new rows then we are not lagging and + we are at the `current_position`. + """ + if rows: + return rows[-1][0] + return current_position diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py new file mode 100644 index 0000000000..18076e0f3b --- /dev/null +++ b/synapse/replication/slave/storage/_base.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine +from twisted.internet import defer + +from ._slaved_id_tracker import SlavedIdTracker + +import logging + +logger = logging.getLogger(__name__) + + +class BaseSlavedStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(BaseSlavedStore, self).__init__(hs) + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = SlavedIdTracker( + db_conn, "cache_invalidation_stream", "stream_id", + ) + else: + self._cache_id_gen = None + + self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" + self.http_client = hs.get_simple_http_client() + + def stream_positions(self): + pos = {} + if self._cache_id_gen: + pos["caches"] = self._cache_id_gen.get_current_token() + return pos + + def process_replication(self, result): + stream = result.get("caches") + if stream: + for row in stream["rows"]: + ( + position, cache_func, keys, invalidation_ts, + ) = row + + try: + getattr(self, cache_func).invalidate(tuple(keys)) + except AttributeError: + logger.info("Got unexpected cache_func: %r", cache_func) + self._cache_id_gen.advance(int(stream["position"])) + return defer.succeed(None) + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + txn.call_after(cache_func.invalidate, keys) + txn.call_after(self._send_invalidation_poke, cache_func, keys) + + @defer.inlineCallbacks + def _send_invalidation_poke(self, cache_func, keys): + try: + yield self.http_client.post_json_get_json(self.expire_cache_url, { + "invalidate": [{ + "name": cache_func.__name__, + "keys": list(keys), + }] + }) + except: + logger.exception("Failed to poke on expire_cache") diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 0000000000..24b5c79d4a --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.util.id_generators import _load_current_id + + +class SlavedIdTracker(object): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + for table, column in extra_tables: + self.advance(_load_current_id(db_conn, table, column)) + + def advance(self, new_id): + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self): + return self._current diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py new file mode 100644 index 0000000000..735c03c7eb --- /dev/null +++ b/synapse/replication/slave/storage/account_data.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore +from synapse.storage.account_data import AccountDataStore +from synapse.storage.tags import TagsStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class SlavedAccountDataStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedAccountDataStore, self).__init__(db_conn, hs) + self._account_data_id_gen = SlavedIdTracker( + db_conn, "account_data_max_stream_id", "stream_id", + ) + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", + self._account_data_id_gen.get_current_token(), + ) + + get_account_data_for_user = ( + AccountDataStore.__dict__["get_account_data_for_user"] + ) + + get_global_account_data_by_type_for_users = ( + AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] + ) + + get_global_account_data_by_type_for_user = ( + AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] + ) + + get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] + + get_updated_tags = DataStore.get_updated_tags.__func__ + get_updated_account_data_for_user = ( + DataStore.get_updated_account_data_for_user.__func__ + ) + + def get_max_account_data_stream_id(self): + return self._account_data_id_gen.get_current_token() + + def stream_positions(self): + result = super(SlavedAccountDataStore, self).stream_positions() + position = self._account_data_id_gen.get_current_token() + result["user_account_data"] = position + result["room_account_data"] = position + result["tag_account_data"] = position + return result + + def process_replication(self, result): + stream = result.get("user_account_data") + if stream: + self._account_data_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + position, user_id, data_type = row[:3] + self.get_global_account_data_by_type_for_user.invalidate( + (data_type, user_id,) + ) + self.get_account_data_for_user.invalidate((user_id,)) + self._account_data_stream_cache.entity_has_changed( + user_id, position + ) + + stream = result.get("room_account_data") + if stream: + self._account_data_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + position, user_id = row[:2] + self.get_account_data_for_user.invalidate((user_id,)) + self._account_data_stream_cache.entity_has_changed( + user_id, position + ) + + stream = result.get("tag_account_data") + if stream: + self._account_data_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + position, user_id = row[:2] + self.get_tags_for_user.invalidate((user_id,)) + self._account_data_stream_cache.entity_has_changed( + user_id, position + ) + + return super(SlavedAccountDataStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py new file mode 100644 index 0000000000..a374f2f1a2 --- /dev/null +++ b/synapse/replication/slave/storage/appservice.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from synapse.storage import DataStore +from synapse.config.appservice import load_appservices + + +class SlavedApplicationServiceStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedApplicationServiceStore, self).__init__(db_conn, hs) + self.services_cache = load_appservices( + hs.config.server_name, + hs.config.app_service_config_files + ) + + get_app_service_by_token = DataStore.get_app_service_by_token.__func__ + get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ + get_app_services = DataStore.get_app_services.__func__ + get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ + create_appservice_txn = DataStore.create_appservice_txn.__func__ + get_appservices_by_state = DataStore.get_appservices_by_state.__func__ + get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__ + _get_last_txn = DataStore._get_last_txn.__func__ + complete_appservice_txn = DataStore.complete_appservice_txn.__func__ + get_appservice_state = DataStore.get_appservice_state.__func__ + set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ + set_appservice_state = DataStore.set_appservice_state.__func__ diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py new file mode 100644 index 0000000000..cc860f9f9b --- /dev/null +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class SlavedDeviceInboxStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + self._device_inbox_id_gen = SlavedIdTracker( + db_conn, "device_max_stream_id", "stream_id", + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + self._device_inbox_id_gen.get_current_token() + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + self._device_inbox_id_gen.get_current_token() + ) + + get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ + get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ + get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ + delete_messages_for_device = DataStore.delete_messages_for_device.__func__ + delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__ + + def stream_positions(self): + result = super(SlavedDeviceInboxStore, self).stream_positions() + result["to_device"] = self._device_inbox_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("to_device") + if stream: + self._device_inbox_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + stream_id = row[0] + entity = row[1] + + if entity.startswith("@"): + self._device_inbox_stream_cache.entity_has_changed( + entity, stream_id + ) + else: + self._device_federation_outbox_stream_cache.entity_has_changed( + entity, stream_id + ) + + return super(SlavedDeviceInboxStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py new file mode 100644 index 0000000000..7301d885f2 --- /dev/null +++ b/synapse/replication/slave/storage/directory.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from synapse.storage.directory import DirectoryStore + + +class DirectoryStore(BaseSlavedStore): + get_aliases_for_room = DirectoryStore.__dict__[ + "get_aliases_for_room" + ] diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py new file mode 100644 index 0000000000..64f18bbb3e --- /dev/null +++ b/synapse/replication/slave/storage/events.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.api.constants import EventTypes +from synapse.events import FrozenEvent +from synapse.storage import DataStore +from synapse.storage.roommember import RoomMemberStore +from synapse.storage.event_federation import EventFederationStore +from synapse.storage.event_push_actions import EventPushActionsStore +from synapse.storage.state import StateStore +from synapse.storage.stream import StreamStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +import ujson as json +import logging + + +logger = logging.getLogger(__name__) + + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedEventStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedEventStore, self).__init__(db_conn, hs) + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + events_max = self._stream_id_gen.get_current_token() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max, + ) + + self.stream_ordering_month_ago = 0 + self._stream_order_on_start = self.get_room_max_stream_ordering() + + # Cached functions can't be accessed through a class instance so we need + # to reach inside the __dict__ to extract them. + get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] + get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_latest_event_ids_in_room = EventFederationStore.__dict__[ + "get_latest_event_ids_in_room" + ] + _get_current_state_for_key = StateStore.__dict__[ + "_get_current_state_for_key" + ] + get_invited_rooms_for_user = RoomMemberStore.__dict__[ + "get_invited_rooms_for_user" + ] + get_unread_event_push_actions_by_room_for_user = ( + EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] + ) + _get_state_group_for_events = ( + StateStore.__dict__["_get_state_group_for_events"] + ) + _get_state_group_for_event = ( + StateStore.__dict__["_get_state_group_for_event"] + ) + _get_state_groups_from_groups = ( + StateStore.__dict__["_get_state_groups_from_groups"] + ) + _get_state_groups_from_groups_txn = ( + DataStore._get_state_groups_from_groups_txn.__func__ + ) + _get_state_group_from_group = ( + StateStore.__dict__["_get_state_group_from_group"] + ) + get_recent_event_ids_for_room = ( + StreamStore.__dict__["get_recent_event_ids_for_room"] + ) + + get_unread_push_actions_for_user_in_range_for_http = ( + DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ + ) + get_unread_push_actions_for_user_in_range_for_email = ( + DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ + ) + get_push_action_users_in_range = ( + DataStore.get_push_action_users_in_range.__func__ + ) + get_event = DataStore.get_event.__func__ + get_events = DataStore.get_events.__func__ + get_current_state = DataStore.get_current_state.__func__ + get_current_state_for_key = DataStore.get_current_state_for_key.__func__ + get_rooms_for_user_where_membership_is = ( + DataStore.get_rooms_for_user_where_membership_is.__func__ + ) + get_membership_changes_for_user = ( + DataStore.get_membership_changes_for_user.__func__ + ) + get_room_events_max_id = DataStore.get_room_events_max_id.__func__ + get_room_events_stream_for_room = ( + DataStore.get_room_events_stream_for_room.__func__ + ) + get_events_around = DataStore.get_events_around.__func__ + get_state_for_event = DataStore.get_state_for_event.__func__ + get_state_for_events = DataStore.get_state_for_events.__func__ + get_state_groups = DataStore.get_state_groups.__func__ + get_state_groups_ids = DataStore.get_state_groups_ids.__func__ + get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ + get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ + get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ + get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ + _get_joined_users_from_context = ( + RoomMemberStore.__dict__["_get_joined_users_from_context"] + ) + + get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ + get_room_events_stream_for_rooms = ( + DataStore.get_room_events_stream_for_rooms.__func__ + ) + is_host_joined = DataStore.is_host_joined.__func__ + _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"] + get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ + + _set_before_and_after = staticmethod(DataStore._set_before_and_after) + + _get_events = DataStore._get_events.__func__ + _get_events_from_cache = DataStore._get_events_from_cache.__func__ + + _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ + _enqueue_events = DataStore._enqueue_events.__func__ + _do_fetch = DataStore._do_fetch.__func__ + _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row = DataStore._get_event_from_row.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ + _get_state_for_groups = DataStore._get_state_for_groups.__func__ + _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ + _get_events_around_txn = DataStore._get_events_around_txn.__func__ + _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__ + + get_backfill_events = DataStore.get_backfill_events.__func__ + _get_backfill_events = DataStore._get_backfill_events.__func__ + get_missing_events = DataStore.get_missing_events.__func__ + _get_missing_events = DataStore._get_missing_events.__func__ + + get_auth_chain = DataStore.get_auth_chain.__func__ + get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ + _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ + + get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__ + + get_forward_extremeties_for_room = ( + DataStore.get_forward_extremeties_for_room.__func__ + ) + _get_forward_extremeties_for_room = ( + EventFederationStore.__dict__["_get_forward_extremeties_for_room"] + ) + + get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__ + + get_federation_out_pos = DataStore.get_federation_out_pos.__func__ + update_federation_out_pos = DataStore.update_federation_out_pos.__func__ + + def stream_positions(self): + result = super(SlavedEventStore, self).stream_positions() + result["events"] = self._stream_id_gen.get_current_token() + result["backfill"] = -self._backfill_id_gen.get_current_token() + return result + + def process_replication(self, result): + state_resets = set( + r[0] for r in result.get("state_resets", {"rows": []})["rows"] + ) + + stream = result.get("events") + if stream: + self._stream_id_gen.advance(int(stream["position"])) + + if stream["rows"]: + logger.info("Got %d event rows", len(stream["rows"])) + + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=False, state_resets=state_resets + ) + + stream = result.get("backfill") + if stream: + self._backfill_id_gen.advance(-int(stream["position"])) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=True, state_resets=state_resets + ) + + stream = result.get("forward_ex_outliers") + if stream: + self._stream_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + stream = result.get("backward_ex_outliers") + if stream: + self._backfill_id_gen.advance(-int(stream["position"])) + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + return super(SlavedEventStore, self).process_replication(result) + + def _process_replication_row(self, row, backfilled, state_resets): + position = row[0] + internal = json.loads(row[1]) + event_json = json.loads(row[2]) + event = FrozenEvent(event_json, internal_metadata_dict=internal) + self.invalidate_caches_for_event( + event, backfilled, reset_state=position in state_resets + ) + + def invalidate_caches_for_event(self, event, backfilled, reset_state): + if reset_state: + self._get_current_state_for_key.invalidate_all() + self.get_rooms_for_user.invalidate_all() + self.get_users_in_room.invalidate((event.room_id,)) + + self._invalidate_get_event_cache(event.event_id) + + self.get_latest_event_ids_in_room.invalidate((event.room_id,)) + + self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + (event.room_id,) + ) + + if not backfilled: + self._events_stream_cache.entity_has_changed( + event.room_id, event.internal_metadata.stream_ordering + ) + + # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + # (event.room_id,) + # ) + + if event.type == EventTypes.Redaction: + self._invalidate_get_event_cache(event.redacts) + + if event.type == EventTypes.Member: + self.get_rooms_for_user.invalidate((event.state_key,)) + self.get_users_in_room.invalidate((event.room_id,)) + self._membership_stream_cache.entity_has_changed( + event.state_key, event.internal_metadata.stream_ordering + ) + self.get_invited_rooms_for_user.invalidate((event.state_key,)) + + if not event.is_state(): + return + + if backfilled: + return + + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + return + + self._get_current_state_for_key.invalidate(( + event.room_id, event.type, event.state_key + )) diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py new file mode 100644 index 0000000000..819ed62881 --- /dev/null +++ b/synapse/replication/slave/storage/filtering.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from synapse.storage.filtering import FilteringStore + + +class SlavedFilteringStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedFilteringStore, self).__init__(db_conn, hs) + + # Filters are immutable so this cache doesn't need to be expired + get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py new file mode 100644 index 0000000000..dd2ae49e48 --- /dev/null +++ b/synapse/replication/slave/storage/keys.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from synapse.storage import DataStore +from synapse.storage.keys import KeyStore + + +class SlavedKeyStore(BaseSlavedStore): + _get_server_verify_key = KeyStore.__dict__[ + "_get_server_verify_key" + ] + + get_server_verify_keys = DataStore.get_server_verify_keys.__func__ + store_server_verify_key = DataStore.store_server_verify_key.__func__ + + get_server_certificate = DataStore.get_server_certificate.__func__ + store_server_certificate = DataStore.store_server_certificate.__func__ + + get_server_keys_json = DataStore.get_server_keys_json.__func__ + store_server_keys_json = DataStore.store_server_keys_json.__func__ diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py new file mode 100644 index 0000000000..703f4a49bf --- /dev/null +++ b/synapse/replication/slave/storage/presence.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage import DataStore + + +class SlavedPresenceStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedPresenceStore, self).__init__(db_conn, hs) + self._presence_id_gen = SlavedIdTracker( + db_conn, "presence_stream", "stream_id", + ) + + self._presence_on_startup = self._get_active_presence(db_conn) + + self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() + ) + + _get_active_presence = DataStore._get_active_presence.__func__ + take_presence_startup_info = DataStore.take_presence_startup_info.__func__ + get_presence_for_users = DataStore.get_presence_for_users.__func__ + + def get_current_presence_token(self): + return self._presence_id_gen.get_current_token() + + def stream_positions(self): + result = super(SlavedPresenceStore, self).stream_positions() + position = self._presence_id_gen.get_current_token() + result["presence"] = position + return result + + def process_replication(self, result): + stream = result.get("presence") + if stream: + self._presence_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + position, user_id = row[:2] + self.presence_stream_cache.entity_has_changed( + user_id, position + ) + + return super(SlavedPresenceStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py new file mode 100644 index 0000000000..21ceb0213a --- /dev/null +++ b/synapse/replication/slave/storage/push_rule.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .events import SlavedEventStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore +from synapse.storage.push_rule import PushRuleStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class SlavedPushRuleStore(SlavedEventStore): + def __init__(self, db_conn, hs): + super(SlavedPushRuleStore, self).__init__(db_conn, hs) + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id", + ) + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + self._push_rules_stream_id_gen.get_current_token(), + ) + + get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"] + get_push_rules_enabled_for_user = ( + PushRuleStore.__dict__["get_push_rules_enabled_for_user"] + ) + have_push_rules_changed_for_user = ( + DataStore.have_push_rules_changed_for_user.__func__ + ) + + def get_push_rules_stream_token(self): + return ( + self._push_rules_stream_id_gen.get_current_token(), + self._stream_id_gen.get_current_token(), + ) + + def stream_positions(self): + result = super(SlavedPushRuleStore, self).stream_positions() + result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("push_rules") + if stream: + for row in stream["rows"]: + position = row[0] + user_id = row[2] + self.get_push_rules_for_user.invalidate((user_id,)) + self.get_push_rules_enabled_for_user.invalidate((user_id,)) + self.push_rules_stream_cache.entity_has_changed( + user_id, position + ) + + self._push_rules_stream_id_gen.advance(int(stream["position"])) + + return super(SlavedPushRuleStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py new file mode 100644 index 0000000000..d88206b3bb --- /dev/null +++ b/synapse/replication/slave/storage/pushers.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.storage import DataStore + + +class SlavedPusherStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedPusherStore, self).__init__(db_conn, hs) + self._pushers_id_gen = SlavedIdTracker( + db_conn, "pushers", "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) + + get_all_pushers = DataStore.get_all_pushers.__func__ + get_pushers_by = DataStore.get_pushers_by.__func__ + get_pushers_by_app_id_and_pushkey = ( + DataStore.get_pushers_by_app_id_and_pushkey.__func__ + ) + _decode_pushers_rows = DataStore._decode_pushers_rows.__func__ + + def stream_positions(self): + result = super(SlavedPusherStore, self).stream_positions() + result["pushers"] = self._pushers_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("pushers") + if stream: + self._pushers_id_gen.advance(int(stream["position"])) + + stream = result.get("deleted_pushers") + if stream: + self._pushers_id_gen.advance(int(stream["position"])) + + return super(SlavedPusherStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py new file mode 100644 index 0000000000..ac9662d399 --- /dev/null +++ b/synapse/replication/slave/storage/receipts.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.storage import DataStore +from synapse.storage.receipts import ReceiptsStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedReceiptsStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedReceiptsStore, self).__init__(db_conn, hs) + + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) + + self._receipts_stream_cache = StreamChangeCache( + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + ) + + get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] + get_linearized_receipts_for_room = ( + ReceiptsStore.__dict__["get_linearized_receipts_for_room"] + ) + _get_linearized_receipts_for_rooms = ( + ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"] + ) + get_last_receipt_event_id_for_user = ( + ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"] + ) + + get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ + get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ + + get_linearized_receipts_for_rooms = ( + DataStore.get_linearized_receipts_for_rooms.__func__ + ) + + def stream_positions(self): + result = super(SlavedReceiptsStore, self).stream_positions() + result["receipts"] = self._receipts_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("receipts") + if stream: + self._receipts_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + position, room_id, receipt_type, user_id = row[:4] + self.invalidate_caches_for_receipt(room_id, receipt_type, user_id) + self._receipts_stream_cache.entity_has_changed(room_id, position) + + return super(SlavedReceiptsStore, self).process_replication(result) + + def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self.get_linearized_receipts_for_room.invalidate_many((room_id,)) + self.get_last_receipt_event_id_for_user.invalidate( + (user_id, room_id, receipt_type) + ) diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py new file mode 100644 index 0000000000..e27c7332d2 --- /dev/null +++ b/synapse/replication/slave/storage/registration.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from synapse.storage import DataStore +from synapse.storage.registration import RegistrationStore + + +class SlavedRegistrationStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedRegistrationStore, self).__init__(db_conn, hs) + + # TODO: use the cached version and invalidate deleted tokens + get_user_by_access_token = RegistrationStore.__dict__[ + "get_user_by_access_token" + ] + + _query_for_auth = DataStore._query_for_auth.__func__ + get_user_by_id = RegistrationStore.__dict__[ + "get_user_by_id" + ] diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py new file mode 100644 index 0000000000..23c613863f --- /dev/null +++ b/synapse/replication/slave/storage/room.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from synapse.storage import DataStore +from ._slaved_id_tracker import SlavedIdTracker + + +class RoomStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(RoomStore, self).__init__(db_conn, hs) + self._public_room_id_gen = SlavedIdTracker( + db_conn, "public_room_list_stream", "stream_id" + ) + + get_public_room_ids = DataStore.get_public_room_ids.__func__ + get_current_public_room_stream_id = ( + DataStore.get_current_public_room_stream_id.__func__ + ) + get_public_room_ids_at_stream_id = ( + DataStore.get_public_room_ids_at_stream_id.__func__ + ) + get_public_room_ids_at_stream_id_txn = ( + DataStore.get_public_room_ids_at_stream_id_txn.__func__ + ) + get_published_at_stream_id_txn = ( + DataStore.get_published_at_stream_id_txn.__func__ + ) + get_public_room_changes = DataStore.get_public_room_changes.__func__ + + def stream_positions(self): + result = super(RoomStore, self).stream_positions() + result["public_rooms"] = self._public_room_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("public_rooms") + if stream: + self._public_room_id_gen.advance(int(stream["position"])) + + return super(RoomStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py new file mode 100644 index 0000000000..fbb58f35da --- /dev/null +++ b/synapse/replication/slave/storage/transactions.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from synapse.storage import DataStore +from synapse.storage.transactions import TransactionStore + + +class TransactionStore(BaseSlavedStore): + get_destination_retry_timings = TransactionStore.__dict__[ + "get_destination_retry_timings" + ] + _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__ + set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__ + _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__ + + prep_send_transaction = DataStore.prep_send_transaction.__func__ + delivered_txn = DataStore.delivered_txn.__func__ diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 6688fa8fa0..f9f5a3e077 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,12 @@ from synapse.rest.client.v2_alpha import ( tokenrefresh, tags, account_data, + report_event, + openid, + notifications, + devices, + thirdparty, + sendtodevice, ) from synapse.http.server import JsonResource @@ -86,3 +92,9 @@ class ClientRestResource(JsonResource): tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) account_data.register_servlets(hs, client_resource) + report_event.register_servlets(hs, client_resource) + openid.register_servlets(hs, client_resource) + notifications.register_servlets(hs, client_resource) + devices.register_servlets(hs, client_resource) + thirdparty.register_servlets(hs, client_resource) + sendtodevice.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py new file mode 100644 index 0000000000..351170edbc --- /dev/null +++ b/synapse/rest/client/transactions.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains logic for storing HTTP PUT transactions. This is used +to ensure idempotency when performing PUTs using the REST API.""" +import logging + +from synapse.api.auth import get_access_token_from_request +from synapse.util.async import ObservableDeferred + +logger = logging.getLogger(__name__) + + +def get_transaction_key(request): + """A helper function which returns a transaction key that can be used + with TransactionCache for idempotent requests. + + Idempotency is based on the returned key being the same for separate + requests to the same endpoint. The key is formed from the HTTP request + path and the access_token for the requesting user. + + Args: + request (twisted.web.http.Request): The incoming request. Must + contain an access_token. + Returns: + str: A transaction key + """ + token = get_access_token_from_request(request) + return request.path + "/" + token + + +CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins + + +class HttpTransactionCache(object): + + def __init__(self, clock): + self.clock = clock + self.transactions = { + # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) + } + # Try to clean entries every 30 mins. This means entries will exist + # for at *LEAST* 30 mins, and at *MOST* 60 mins. + self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + + def fetch_or_execute_request(self, request, fn, *args, **kwargs): + """A helper function for fetch_or_execute which extracts + a transaction key from the given request. + + See: + fetch_or_execute + """ + return self.fetch_or_execute( + get_transaction_key(request), fn, *args, **kwargs + ) + + def fetch_or_execute(self, txn_key, fn, *args, **kwargs): + """Fetches the response for this transaction, or executes the given function + to produce a response for this transaction. + + Args: + txn_key (str): A key to ensure idempotency should fetch_or_execute be + called again at a later point in time. + fn (function): A function which returns a tuple of + (response_code, response_dict). + *args: Arguments to pass to fn. + **kwargs: Keyword arguments to pass to fn. + Returns: + Deferred which resolves to a tuple of (response_code, response_dict). + """ + try: + return self.transactions[txn_key][0].observe() + except (KeyError, IndexError): + pass # execute the function instead. + + deferred = fn(*args, **kwargs) + observable = ObservableDeferred(deferred) + self.transactions[txn_key] = (observable, self.clock.time_msec()) + return observable.observe() + + def _cleanup(self): + now = self.clock.time_msec() + for key in self.transactions.keys(): + ts = self.transactions[key][1] + if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period + del self.transactions[key] diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index aa05b3f023..af21661d7c 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(WhoisRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -46,5 +50,86 @@ class WhoisRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class PurgeMediaCacheRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/purge_media_cache") + + def __init__(self, hs): + self.media_repository = hs.get_media_repository() + super(PurgeMediaCacheRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + before_ts = request.args.get("before_ts", None) + if not before_ts: + raise SynapseError(400, "Missing 'before_ts' arg") + + logger.info("before_ts: %r", before_ts[0]) + + try: + before_ts = int(before_ts[0]) + except Exception: + raise SynapseError(400, "Invalid 'before_ts' arg") + + ret = yield self.media_repository.delete_old_remote_media(before_ts) + + defer.returnValue((200, ret)) + + +class PurgeHistoryRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + ) + + def __init__(self, hs): + super(PurgeHistoryRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_POST(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + yield self.handlers.message_handler.purge_history(room_id, event_id) + + defer.returnValue((200, {})) + + +class DeactivateAccountRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(DeactivateAccountRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # FIXME: Theoretically there is a race here wherein user resets password + # using threepid. + yield self.store.user_delete_access_tokens(target_user_id) + yield self.store.user_delete_threepids(target_user_id) + yield self.store.user_set_password_hash(target_user_id, None) + + defer.returnValue((200, {})) + + def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) + PurgeMediaCacheRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + PurgeHistoryRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 1c020b7e2c..c7aa0bbf59 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -18,7 +18,8 @@ from synapse.http.servlet import RestServlet from synapse.api.urls import CLIENT_PREFIX -from .transactions import HttpTransactionStore +from synapse.rest.client.transactions import HttpTransactionCache + import re import logging @@ -52,8 +53,11 @@ class ClientV1RestServlet(RestServlet): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.hs = hs - self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_v1auth() - self.txns = HttpTransactionStore() + self.txns = HttpTransactionCache(hs.get_clock()) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 60c5ec77aa..09d0831594 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -30,11 +30,16 @@ logger = logging.getLogger(__name__) def register_servlets(hs, http_server): ClientDirectoryServer(hs).register(http_server) + ClientDirectoryListServer(hs).register(http_server) class ClientDirectoryServer(ClientV1RestServlet): PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") + def __init__(self, hs): + super(ClientDirectoryServer, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) @@ -127,8 +132,9 @@ class ClientDirectoryServer(ClientV1RestServlet): room_alias = RoomAlias.from_string(room_alias) yield dir_handler.delete_association( - user.to_string(), room_alias + requester, user.to_string(), room_alias ) + logger.info( "User %s deleted alias %s", user.to_string(), @@ -136,3 +142,45 @@ class ClientDirectoryServer(ClientV1RestServlet): ) defer.returnValue((200, {})) + + +class ClientDirectoryListServer(ClientV1RestServlet): + PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$") + + def __init__(self, hs): + super(ClientDirectoryListServer, self).__init__(hs) + self.store = hs.get_datastore() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + room = yield self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Unknown room") + + defer.returnValue((200, { + "visibility": "public" if room["is_public"] else "private" + })) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + visibility = content.get("visibility", "public") + + yield self.handlers.directory_handler.edit_published_room_list( + requester, room_id, visibility, + ) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_DELETE(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + + yield self.handlers.directory_handler.edit_published_room_list( + requester, room_id, "private", + ) + + defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index d1afa0f0d5..701b6f549b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 + def __init__(self, hs): + super(EventStreamRestServlet, self).__init__(hs) + self.event_stream_handler = hs.get_event_stream_handler() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( @@ -45,30 +49,26 @@ class EventStreamRestServlet(ClientV1RestServlet): raise SynapseError(400, "Guest users must specify room_id param") if "room_id" in request.args: room_id = request.args["room_id"][0] - try: - handler = self.handlers.event_stream_handler - pagin_config = PaginationConfig.from_request(request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: - try: - timeout = int(request.args["timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - as_client_event = "raw" not in request.args - - chunk = yield handler.get_stream( - requester.user.to_string(), - pagin_config, - timeout=timeout, - as_client_event=as_client_event, - affect_presence=(not is_guest), - room_id=room_id, - is_guest=is_guest, - ) - except: - logger.exception("Event stream failed") - raise + + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + as_client_event = "raw" not in request.args + + chunk = yield self.event_stream_handler.get_stream( + requester.user.to_string(), + pagin_config, + timeout=timeout, + as_client_event=as_client_event, + affect_presence=(not is_guest), + room_id=room_id, + is_guest=is_guest, + ) defer.returnValue((200, chunk)) @@ -83,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventRestServlet, self).__init__(hs) self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() @defer.inlineCallbacks def on_GET(self, request, event_id): requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.event_handler - event = yield handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 36c3520567..478e21eea8 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -23,14 +23,17 @@ from .base import ClientV1RestServlet, client_path_patterns class InitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/initialSync$") + def __init__(self, hs): + super(InitialSyncRestServlet, self).__init__(hs) + self.initial_sync_handler = hs.get_initial_sync_handler() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) - handler = self.handlers.message_handler include_archived = request.args.get("archived", None) == ["true"] - content = yield handler.snapshot_all_rooms( + content = yield self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, as_client_event=as_client_event, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index fe593d07ce..093bc072f4 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -43,20 +43,25 @@ class LoginRestServlet(ClientV1RestServlet): SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" + JWT_TYPE = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled + self.jwt_enabled = hs.config.jwt_enabled + self.jwt_secret = hs.config.jwt_secret + self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url - self.cas_required_attributes = hs.config.cas_required_attributes - self.servername = hs.config.server_name - self.http_client = hs.get_simple_http_client() + self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): flows = [] + if self.jwt_enabled: + flows.append({"type": LoginRestServlet.JWT_TYPE}) if self.saml2_enabled: flows.append({"type": LoginRestServlet.SAML2_TYPE}) if self.cas_enabled: @@ -98,16 +103,9 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - elif self.cas_enabled and (login_submission["type"] == - LoginRestServlet.CAS_TYPE): - uri = "%s/proxyValidate" % (self.cas_server_url,) - args = { - "ticket": login_submission["ticket"], - "service": login_submission["service"] - } - body = yield self.http_client.get_raw(uri, args) - result = yield self.do_cas_login(body) + elif self.jwt_enabled and (login_submission["type"] == + LoginRestServlet.JWT_TYPE): + result = yield self.do_jwt_login(login_submission) defer.returnValue(result) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) @@ -133,16 +131,21 @@ class LoginRestServlet(ClientV1RestServlet): user_id, self.hs.hostname ).to_string() - auth_handler = self.handlers.auth_handler - user_id, access_token, refresh_token = yield auth_handler.login_with_password( + auth_handler = self.auth_handler + user_id = yield auth_handler.validate_password_login( user_id=user_id, - password=login_submission["password"]) - + password=login_submission["password"], + ) + device_id = yield self._register_device(user_id, login_submission) + access_token = yield auth_handler.get_access_token_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name"), + ) result = { "user_id": user_id, # may have changed "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -150,54 +153,68 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + device_id = yield self._register_device(user_id, login_submission) + access_token = yield auth_handler.get_access_token_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead @defer.inlineCallbacks - def do_cas_login(self, cas_response_body): - user, attributes = self.parse_cas_response(cas_response_body) + def do_jwt_login(self, login_submission): + token = login_submission.get("token", None) + if token is None: + raise LoginError( + 401, "Token field for JWT is missing", + errcode=Codes.UNAUTHORIZED + ) - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + import jwt + from jwt.exceptions import InvalidTokenError - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + try: + payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + except jwt.ExpiredSignatureError: + raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) + except InvalidTokenError: + raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + + user = payload.get("sub", None) + if user is None: + raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + auth_handler = self.auth_handler + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission ) + access_token = yield auth_handler.get_access_token_for_user_id( + registered_user_id, device_id, + login_submission.get("initial_device_display_name"), + ) + result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, } - else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -209,32 +226,25 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) class SAML2RestServlet(ClientV1RestServlet): @@ -243,6 +253,7 @@ class SAML2RestServlet(ClientV1RestServlet): def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) self.sp_config = hs.config.saml2_config_path + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): @@ -280,18 +291,6 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) -# TODO Delete this after all CAS clients switch to token login instead -class CasRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas", releases=()) - - def __init__(self, hs): - super(CasRestServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - - def on_GET(self, request): - return (200, {"serverUrl": self.cas_server_url}) - - class CasRedirectServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) @@ -311,7 +310,7 @@ class CasRedirectServlet(ClientV1RestServlet): service_param = urllib.urlencode({ "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) }) - request.redirect("%s?%s" % (self.cas_server_url, service_param)) + request.redirect("%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -323,6 +322,8 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes + self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request): @@ -354,14 +355,14 @@ class CasTicketServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if not user_exists: - user_id, _ = ( + auth_handler = self.auth_handler + registered_user_id = yield auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(user_id) + login_token = auth_handler.generate_short_term_login_token(registered_user_id) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) @@ -375,30 +376,36 @@ class CasTicketServlet(ClientV1RestServlet): return urlparse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + user = None + attributes = {} + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = (root[0].tag.endswith("authenticationSuccess")) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + except Exception: + logger.error("Error parsing CAS response", exc_info=1) + raise LoginError(401, "Invalid CAS response", + errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError(401, "Unsuccessful CAS response", + errcode=Codes.UNAUTHORIZED) + return user, attributes def register_servlets(hs, http_server): @@ -408,5 +415,3 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - CasRestServlet(hs).register(http_server) - # TODO PasswordResetRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 9bff02ee4e..1358d0acab 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, Codes +from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - try: - access_token = request.args["access_token"][0] - except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", - errcode=Codes.MISSING_TOKEN - ) + access_token = get_access_token_from_request(request) yield self.store.delete_access_token(access_token) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 27d9ed586b..eafdce865e 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -30,20 +30,24 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") + def __init__(self, hs): + super(PresenceStatusRestServlet, self).__init__(hs) + self.presence_handler = hs.get_presence_handler() + @defer.inlineCallbacks def on_GET(self, request, user_id): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.handlers.presence_handler.is_visible( + allowed = yield self.presence_handler.is_visible( observed_user=user, observer_user=requester.user, ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.handlers.presence_handler.get_state(target_user=user) + state = yield self.presence_handler.get_state(target_user=user) defer.returnValue((200, state)) @@ -74,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except: raise SynapseError(400, "Unable to parse state") - yield self.handlers.presence_handler.set_state(user, state) + yield self.presence_handler.set_state(user, state) defer.returnValue((200, {})) @@ -85,6 +89,10 @@ class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(PresenceListRestServlet, self).__init__(hs) + self.presence_handler = hs.get_presence_handler() + @defer.inlineCallbacks def on_GET(self, request, user_id): requester = yield self.auth.get_user_by_req(request) @@ -96,7 +104,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if requester.user != user: raise SynapseError(400, "Cannot get another user's presence list") - presence = yield self.handlers.presence_handler.get_presence_list( + presence = yield self.presence_handler.get_presence_list( observer_user=user, accepted=True ) @@ -123,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue invited_user = UserID.from_string(u) - yield self.handlers.presence_handler.send_presence_invite( + yield self.presence_handler.send_presence_invite( observer_user=user, observed_user=invited_user ) @@ -134,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue dropped_user = UserID.from_string(u) - yield self.handlers.presence_handler.drop( + yield self.presence_handler.drop( observer_user=user, observed_user=dropped_user ) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 65c4e2ebef..355e82474b 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") + def __init__(self, hs): + super(ProfileDisplaynameRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") + def __init__(self, hs): + super(ProfileAvatarURLRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(ProfileRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 02d837ee6a..6bb4821ec6 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.store.get_push_rules_for_user(user_id) + rules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - - rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) + rules = format_push_rules_for_user(requester.user, rules) path = request.postpath[1:] diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 9881f068c3..9a2ed6ed88 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,7 +17,11 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import ( + parse_json_object_from_request, parse_string, RestServlet +) +from synapse.http.server import finish_request +from synapse.api.errors import StoreError from .base import ClientV1RestServlet, client_path_patterns @@ -26,11 +30,48 @@ import logging logger = logging.getLogger(__name__) -class PusherRestServlet(ClientV1RestServlet): +class PushersRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/pushers$") + + def __init__(self, hs): + super(PushersRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + user = requester.user + + pushers = yield self.hs.get_datastore().get_pushers_by_user_id( + user.to_string() + ) + + allowed_keys = [ + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", + ] + + for p in pushers: + for k, v in p.items(): + if k not in allowed_keys: + del p[k] + + defer.returnValue((200, {"pushers": pushers})) + + def on_OPTIONS(self, _): + return 200, {} + + +class PushersSetRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/pushers/set$") def __init__(self, hs): - super(PusherRestServlet, self).__init__(hs) + super(PushersSetRestServlet, self).__init__(hs) self.notifier = hs.get_notifier() @defer.inlineCallbacks @@ -99,5 +140,57 @@ class PusherRestServlet(ClientV1RestServlet): return 200, {} +class PushersRemoveRestServlet(RestServlet): + """ + To allow pusher to be delete by clicking a link (ie. GET request) + """ + PATTERNS = client_path_patterns("/pushers/remove$") + SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>" + + def __init__(self, hs): + super(RestServlet, self).__init__() + self.hs = hs + self.notifier = hs.get_notifier() + self.auth = hs.get_v1auth() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + user = requester.user + + app_id = parse_string(request, "app_id", required=True) + pushkey = parse_string(request, "pushkey", required=True) + + pusher_pool = self.hs.get_pusherpool() + + try: + yield pusher_pool.remove_pusher( + app_id=app_id, + pushkey=pushkey, + user_id=user.to_string(), + ) + except StoreError as se: + if se.code != 404: + # This is fine: they're already unsubscribed + raise + + self.notifier.on_new_replication_data() + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % ( + len(PushersRemoveRestServlet.SUCCESS_HTML), + )) + request.write(PushersRemoveRestServlet.SUCCESS_HTML) + finish_request(request) + defer.returnValue(None) + + def on_OPTIONS(self, _): + return 200, {} + + def register_servlets(hs, http_server): - PusherRestServlet(hs).register(http_server) + PushersRestServlet(hs).register(http_server) + PushersSetRestServlet(hs).register(http_server) + PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index c6a2ef2ccc..ecf7e311a9 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,9 +18,11 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType +from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.http.servlet import parse_json_object_from_request +from synapse.types import create_requester from synapse.util.async import run_on_reactor @@ -52,6 +54,10 @@ class RegisterRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__(hs) # sessions are stored as: # self.sessions = { @@ -60,6 +66,8 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -290,18 +298,18 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_app_service(self, request, register_json, session): - if "access_token" not in request.args: - raise SynapseError(400, "Expected application service token.") + as_token = get_access_token_from_request(request) + if "user" not in register_json: raise SynapseError(400, "Expected 'user' key.") - as_token = request.args["access_token"][0] user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler - (user_id, token) = yield handler.appservice_register( + user_id = yield handler.appservice_register( user_localpart, as_token ) + token = yield self.auth_handler.issue_access_token(user_id) self._remove_session(session) defer.returnValue({ "user_id": user_id, @@ -324,6 +332,14 @@ class RegisterRestServlet(ClientV1RestServlet): raise SynapseError(400, "Shared secret registration is not enabled") user = register_json["user"].encode("utf-8") + password = register_json["password"].encode("utf-8") + admin = register_json.get("admin", None) + + # Its important to check as we use null bytes as HMAC field separators + if "\x00" in user: + raise SynapseError(400, "Invalid user") + if "\x00" in password: + raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface @@ -331,17 +347,21 @@ class RegisterRestServlet(ClientV1RestServlet): want_mac = hmac.new( key=self.hs.config.registration_shared_secret, - msg=user, digestmod=sha1, - ).hexdigest() - - password = register_json["password"].encode("utf-8") + ) + want_mac.update(user) + want_mac.update("\x00") + want_mac.update(password) + want_mac.update("\x00") + want_mac.update("admin" if admin else "notadmin") + want_mac = want_mac.hexdigest() if compare_digest(want_mac, got_mac): handler = self.handlers.registration_handler user_id, token = yield handler.register( localpart=user, password=password, + admin=bool(admin), ) self._remove_session(session) defer.returnValue({ @@ -355,5 +375,68 @@ class RegisterRestServlet(ClientV1RestServlet): ) +class CreateUserRestServlet(ClientV1RestServlet): + """Handles user creation via a server-to-server interface + """ + + PATTERNS = client_path_patterns("/createUser$", releases=()) + + def __init__(self, hs): + super(CreateUserRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_POST(self, request): + user_json = parse_json_object_from_request(request) + + access_token = get_access_token_from_request(request) + app_service = self.store.get_app_service_by_token( + access_token + ) + if not app_service: + raise SynapseError(403, "Invalid application service token.") + + requester = create_requester(app_service.sender) + + logger.debug("creating user: %s", user_json) + response = yield self._do_create(requester, user_json) + + defer.returnValue((200, response)) + + def on_OPTIONS(self, request): + return 403, {} + + @defer.inlineCallbacks + def _do_create(self, requester, user_json): + yield run_on_reactor() + + if "localpart" not in user_json: + raise SynapseError(400, "Expected 'localpart' key.") + + if "displayname" not in user_json: + raise SynapseError(400, "Expected 'displayname' key.") + + localpart = user_json["localpart"].encode("utf-8") + displayname = user_json["displayname"].encode("utf-8") + password_hash = user_json["password_hash"].encode("utf-8") \ + if user_json.get("password_hash") else None + + handler = self.handlers.registration_handler + user_id, token = yield handler.get_or_create_user( + requester=requester, + localpart=localpart, + displayname=displayname, + password_hash=password_hash + ) + + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + + def register_servlets(hs, http_server): RegisterRestServlet(hs).register(http_server) + CreateUserRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a1fa7daf79..3fb1f2deb3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -20,12 +20,16 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias -from synapse.events.utils import serialize_event -from synapse.http.servlet import parse_json_object_from_request +from synapse.events.utils import serialize_event, format_event_for_client_v2 +from synapse.http.servlet import ( + parse_json_object_from_request, parse_string, parse_integer +) import logging import urllib +import ujson as json logger = logging.getLogger(__name__) @@ -33,6 +37,10 @@ logger = logging.getLogger(__name__) class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here + def __init__(self, hs): + super(RoomCreateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) @@ -45,19 +53,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): client_path_patterns("/createRoom(?:/.*)?$"), self.on_OPTIONS) - @defer.inlineCallbacks def on_PUT(self, request, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request + ) @defer.inlineCallbacks def on_POST(self, request): @@ -72,8 +71,6 @@ class RoomCreateRestServlet(ClientV1RestServlet): def get_room_config(self, request): user_supplied_config = parse_json_object_from_request(request) - # default visibility - user_supplied_config.setdefault("visibility", "public") return user_supplied_config def on_OPTIONS(self, request): @@ -82,6 +79,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomStateEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" @@ -112,6 +113,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): requester = yield self.auth.get_user_by_req(request, allow_guest=True) + format = parse_string(request, "format", default="content", + allowed_values=["content", "event"]) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -126,7 +129,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): raise SynapseError( 404, "Event not found.", errcode=Codes.NOT_FOUND ) - defer.returnValue((200, data.get_dict()["content"])) + + if format == "event": + event = format_event_for_client_v2(data.get_dict()) + defer.returnValue((200, event)) + elif format == "content": + defer.returnValue((200, data.get_dict()["content"])) @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): @@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomSendEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") @@ -193,23 +205,17 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id, event_type, txn_id): return (200, "Not implemented") - @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request, room_id, event_type, txn_id) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_type, txn_id + ) # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): + def __init__(self, hs): + super(JoinRoomAliasServlet, self).__init__(hs) + self.handlers = hs.get_handlers() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -232,7 +238,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier - remote_room_hosts = None + try: + remote_room_hosts = request.args["server_name"] + except: + remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.handlers.room_member_handler room_alias = RoomAlias.from_string(room_identifier) @@ -250,24 +259,16 @@ class JoinRoomAliasServlet(ClientV1RestServlet): action="join", txn_id=txn_id, remote_room_hosts=remote_room_hosts, + content=content, third_party_signed=content.get("third_party_signed", None), ) defer.returnValue((200, {"room_id": room_id})) - @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request, room_identifier, txn_id) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) # TODO: Needs unit testing @@ -276,8 +277,65 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - handler = self.handlers.room_list_handler - data = yield handler.get_public_room_list() + server = parse_string(request, "server", default=None) + + try: + yield self.auth.get_user_by_req(request, allow_guest=True) + except AuthError as e: + # We allow people to not be authed if they're just looking at our + # room list, but require auth when we proxy the request. + # In both cases we call the auth function, as that has the side + # effect of logging who issued this request if an access token was + # provided. + if server: + raise e + else: + pass + + limit = parse_integer(request, "limit", 0) + since_token = parse_string(request, "since", None) + + handler = self.hs.get_room_list_handler() + if server: + data = yield handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + ) + else: + data = yield handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + ) + + defer.returnValue((200, data)) + + @defer.inlineCallbacks + def on_POST(self, request): + yield self.auth.get_user_by_req(request, allow_guest=True) + + server = parse_string(request, "server", default=None) + content = parse_json_object_from_request(request) + + limit = int(content.get("limit", 100)) + since_token = content.get("since", None) + search_filter = content.get("filter", None) + + handler = self.hs.get_room_list_handler() + if server: + data = yield handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + ) + else: + data = yield handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + search_filter=search_filter, + ) + defer.returnValue((200, data)) @@ -285,6 +343,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") + def __init__(self, hs): + super(RoomMemberListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -311,6 +373,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") + def __init__(self, hs): + super(RoomMessageListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -318,12 +384,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args + filter_bytes = request.args.get("filter", None) + if filter_bytes: + filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + event_filter=event_filter, ) defer.returnValue((200, msgs)) @@ -333,6 +406,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") + def __init__(self, hs): + super(RoomStateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -350,11 +427,15 @@ class RoomStateRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") + def __init__(self, hs): + super(RoomInitialSyncRestServlet, self).__init__(hs) + self.initial_sync_handler = hs.get_initial_sync_handler() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) - content = yield self.handlers.message_handler.room_initial_sync( + content = yield self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -370,6 +451,7 @@ class RoomEventContext(ClientV1RestServlet): def __init__(self, hs): super(RoomEventContext, self).__init__(hs) self.clock = hs.get_clock() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -405,9 +487,42 @@ class RoomEventContext(ClientV1RestServlet): defer.returnValue((200, results)) +class RoomForgetRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomForgetRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + + def register(self, http_server): + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") + register_txn_path(self, PATTERNS, http_server) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, txn_id=None): + requester = yield self.auth.get_user_by_req( + request, + allow_guest=False, + ) + + yield self.handlers.room_member_handler.forget( + user=requester.user, + room_id=room_id, + ) + + defer.returnValue((200, {})) + + def on_PUT(self, request, room_id, txn_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, txn_id + ) + + # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomMembershipRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" @@ -470,24 +585,17 @@ class RoomMembershipRestServlet(ClientV1RestServlet): return False return True - @defer.inlineCallbacks def on_PUT(self, request, room_id, membership_action, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST( - request, room_id, membership_action, txn_id + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, membership_action, txn_id ) - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) - class RoomRedactEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomRedactEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") register_txn_path(self, PATTERNS, http_server) @@ -512,19 +620,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): defer.returnValue((200, {"event_id": event.event_id})) - @defer.inlineCallbacks def on_PUT(self, request, room_id, event_id, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request, room_id, event_id, txn_id) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_id, txn_id + ) class RoomTypingRestServlet(ClientV1RestServlet): @@ -534,7 +633,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomTypingRestServlet, self).__init__(hs) - self.presence_handler = hs.get_handlers().presence_handler + self.presence_handler = hs.get_presence_handler() + self.typing_handler = hs.get_typing_handler() @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): @@ -545,19 +645,20 @@ class RoomTypingRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) - typing_handler = self.handlers.typing_notification_handler - yield self.presence_handler.bump_presence_active_time(requester.user) + # Limit timeout to stop people from setting silly typing timeouts. + timeout = min(content.get("timeout", 30000), 120000) + if content["typing"]: - yield typing_handler.started_typing( + yield self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, - timeout=content.get("timeout", 30000), + timeout=timeout, ) else: - yield typing_handler.stopped_typing( + yield self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, @@ -571,6 +672,10 @@ class SearchRestServlet(ClientV1RestServlet): "/search$" ) + def __init__(self, hs): + super(SearchRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) @@ -624,6 +729,7 @@ def register_servlets(hs, http_server): RoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py deleted file mode 100644 index bdccf464a5..0000000000 --- a/synapse/rest/client/v1/transactions.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains logic for storing HTTP PUT transactions. This is used -to ensure idempotency when performing PUTs using the REST API.""" -import logging - -logger = logging.getLogger(__name__) - - -# FIXME: elsewhere we use FooStore to indicate something in the storage layer... -class HttpTransactionStore(object): - - def __init__(self): - # { key : (txn_id, response) } - self.transactions = {} - - def get_response(self, key, txn_id): - """Retrieve a response for this request. - - Args: - key (str): A transaction-independent key for this request. Usually - this is a combination of the path (without the transaction id) - and the user's access token. - txn_id (str): The transaction ID for this request - Returns: - A tuple of (HTTP response code, response content) or None. - """ - try: - logger.debug("get_response TxnId: %s", txn_id) - (last_txn_id, response) = self.transactions[key] - if txn_id == last_txn_id: - logger.info("get_response: Returning a response for %s", txn_id) - return response - except KeyError: - pass - return None - - def store_response(self, key, txn_id, response): - """Stores an HTTP response tuple. - - Args: - key (str): A transaction-independent key for this request. Usually - this is a combination of the path (without the transaction id) - and the user's access token. - txn_id (str): The transaction ID for this request. - response (tuple): A tuple of (HTTP response code, response content) - """ - logger.debug("store_response TxnId: %s", txn_id) - self.transactions[key] = (txn_id, response) - - def store_client_transaction(self, request, txn_id, response): - """Stores the request/response pair of an HTTP transaction. - - Args: - request (twisted.web.http.Request): The twisted HTTP request. This - request must have the transaction ID as the last path segment. - response (tuple): A tuple of (response code, response dict) - txn_id (str): The transaction ID for this request. - """ - self.store_response(self._get_key(request), txn_id, response) - - def get_client_transaction(self, request, txn_id): - """Retrieves a stored response if there was one. - - Args: - request (twisted.web.http.Request): The twisted HTTP request. This - request must have the transaction ID as the last path segment. - txn_id (str): The transaction ID for this request. - Returns: - The response tuple. - Raises: - KeyError if the transaction was not found. - """ - response = self.get_response(self._get_key(request), txn_id) - if response is None: - raise KeyError("Transaction not found.") - return response - - def _get_key(self, request): - token = request.args["access_token"][0] - path_without_txn_id = request.path.rsplit("/", 1)[0] - return path_without_txn_id + "/" + token diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b6faa2b0e6..20e765f48f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -25,7 +25,9 @@ import logging logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,)): +def client_v2_patterns(path_regex, releases=(0,), + v2_alpha=True, + unstable=True): """Creates a regex compiled client path with the correct client path prefix. @@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)): Returns: SRE_Pattern """ - patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] - unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") - patterns.append(re.compile("^" + unstable_prefix + path_regex)) + patterns = [] + if v2_alpha: + patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) + if unstable: + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) for release in releases: new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 7f8a6a4cf7..eb49ad62e9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -28,14 +28,46 @@ import logging logger = logging.getLogger(__name__) +class PasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/password/email/requestToken$") + + def __init__(self, hs): + super(PasswordRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is None: + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class PasswordRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password") + PATTERNS = client_v2_patterns("/account/password$") def __init__(self, hs): super(PasswordRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -52,6 +84,7 @@ class PasswordRestServlet(RestServlet): defer.returnValue((401, result)) user_id = None + requester = None if LoginType.PASSWORD in result: # if using password, they should also be logged in @@ -88,15 +121,90 @@ class PasswordRestServlet(RestServlet): return 200, {} +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/deactivate$") + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + super(DeactivateAccountRestServlet, self).__init__() + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + authed, result, params, _ = yield self.auth_handler.check_auth([ + [LoginType.PASSWORD], + ], body, self.hs.get_ip_from_request(request)) + + if not authed: + defer.returnValue((401, result)) + + user_id = None + requester = None + + if LoginType.PASSWORD in result: + # if using password, they should also be logged in + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + if user_id != result[LoginType.PASSWORD]: + raise LoginError(400, "", Codes.UNKNOWN) + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + # FIXME: Theoretically there is a race here wherein user resets password + # using threepid. + yield self.store.user_delete_access_tokens(user_id) + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) + + defer.returnValue((200, {})) + + +class ThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") + + def __init__(self, hs): + self.hs = hs + super(ThreepidRequestTokenRestServlet, self).__init__() + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class ThreepidRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid$") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): @@ -156,5 +264,8 @@ class ThreepidRestServlet(RestServlet): def register_servlets(hs, http_server): + PasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + ThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 78181b7b18..8e5577148f 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -77,8 +77,10 @@ SUCCESS_TEMPLATE = """ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> <link rel="stylesheet" href="/_matrix/static/client/register/style.css"> <script> -if (window.onAuthDone != undefined) { +if (window.onAuthDone) { window.onAuthDone(); +} else if (window.opener && window.opener.postMessage) { + window.opener.postMessage("authDone", "*"); } </script> </head> @@ -104,7 +106,7 @@ class AuthRestServlet(RestServlet): super(AuthRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py new file mode 100644 index 0000000000..a1feaf3d54 --- /dev/null +++ b/synapse/rest/client/v2_alpha/devices.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api import constants, errors +from synapse.http import servlet +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DevicesRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + devices = yield self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + defer.returnValue((200, {"devices": devices})) + + +class DeviceRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", + releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.auth_handler = hs.get_auth_handler() + + @defer.inlineCallbacks + def on_GET(self, request, device_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + device = yield self.device_handler.get_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, device)) + + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + try: + body = servlet.parse_json_object_from_request(request) + + except errors.SynapseError as e: + if e.errcode == errors.Codes.NOT_JSON: + # deal with older clients which didn't pass a JSON dict + # the same as those that pass an empty dict + body = {} + else: + raise + + authed, result, params, _ = yield self.auth_handler.check_auth([ + [constants.LoginType.PASSWORD], + ], body, self.hs.get_ip_from_request(request)) + + if not authed: + defer.returnValue((401, result)) + + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_PUT(self, request, device_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + body = servlet.parse_json_object_from_request(request) + yield self.device_handler.update_device( + requester.user.to_string(), + device_id, + body + ) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 510f8b2c74..b4084fec62 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, SynapseError, StoreError, Codes from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -45,7 +45,7 @@ class GetFilterRestServlet(RestServlet): raise AuthError(403, "Cannot get filters for other users") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only get filters for local users") + raise AuthError(403, "Can only get filters for local users") try: filter_id = int(filter_id) @@ -59,8 +59,8 @@ class GetFilterRestServlet(RestServlet): ) defer.returnValue((200, filter.get_filter_json())) - except KeyError: - raise SynapseError(400, "No such filter") + except (KeyError, StoreError): + raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) class CreateFilterRestServlet(RestServlet): @@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): + target_user = UserID.from_string(user_id) requester = yield self.auth.get_user_by_req(request) @@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Cannot create filters for other users") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only create filters for local users") + raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - filter_id = yield self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content, diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 89ab39491c..08b7c99d57 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -13,24 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +import logging -from canonicaljson import encode_canonical_json +from twisted.internet import defer +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_integer +) from ._base import client_v2_patterns -import logging -import simplejson as json - logger = logging.getLogger(__name__) class KeyUploadServlet(RestServlet): """ - POST /keys/upload/<device_id> HTTP/1.1 + POST /keys/upload HTTP/1.1 Content-Type: application/json { @@ -53,65 +51,45 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", + releases=()) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(KeyUploadServlet, self).__init__() - self.store = hs.get_datastore() - self.clock = hs.get_clock() self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() - # TODO: Check that the device_id matches that in the authentication - # or derive the device_id from the authentication instead. - body = parse_json_object_from_request(request) - time_now = self.clock.time_msec() - - # TODO: Validate the JSON to make sure it has the right keys. - device_keys = body.get("device_keys", None) - if device_keys: - logger.info( - "Updating device_keys for device %r for user %s at %d", - device_id, user_id, time_now + if device_id is not None: + # passing the device_id here is deprecated; however, we allow it + # for now for compatibility with older clients. + if (requester.device_id is not None and + device_id != requester.device_id): + logger.warning("Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, device_id) + else: + device_id = requester.device_id + + if device_id is None: + raise SynapseError( + 400, + "To upload keys, you must pass device_id when authenticating" ) - # TODO: Sign the JSON with the server key - yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, - encode_canonical_json(device_keys) - ) - - one_time_keys = body.get("one_time_keys", None) - if one_time_keys: - logger.info( - "Adding %d one_time_keys for device %r for user %r at %d", - len(one_time_keys), device_id, user_id, time_now - ) - key_list = [] - for key_id, key_json in one_time_keys.items(): - algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, encode_canonical_json(key_json) - )) - - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, key_list - ) - - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue((200, {"one_time_key_counts": result})) - - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue((200, {"one_time_key_counts": result})) + result = yield self.e2e_keys_handler.upload_keys_for_user( + user_id, device_id, body + ) + defer.returnValue((200, result)) class KeyQueryServlet(RestServlet): @@ -162,63 +140,34 @@ class KeyQueryServlet(RestServlet): ) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(KeyQueryServlet, self).__init__() - self.store = hs.get_datastore() self.auth = hs.get_auth() - self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) - defer.returnValue(result) + result = yield self.e2e_keys_handler.query_devices(body, timeout) + defer.returnValue((200, result)) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + timeout = parse_integer(request, "timeout", 10 * 1000) auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] - result = yield self.handle_request( - {"device_keys": {user_id: device_ids}} + result = yield self.e2e_keys_handler.query_devices( + {"device_keys": {user_id: device_ids}}, + timeout, ) - defer.returnValue(result) - - @defer.inlineCallbacks - def handle_request(self, body): - local_query = [] - remote_queries = {} - for user_id, device_ids in body.get("device_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): - if not device_ids: - local_query.append((user_id, None)) - else: - for device_id in device_ids: - local_query.append((user_id, device_id)) - else: - remote_queries.setdefault(user.domain, {})[user_id] = list( - device_ids - ) - results = yield self.store.get_e2e_device_keys(local_query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[device_id] = json.loads( - json_bytes - ) - - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( - destination, {"device_keys": device_keys} - ) - for user_id, keys in remote_result["device_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - defer.returnValue((200, {"device_keys": json_result})) + defer.returnValue((200, result)) class OneTimeKeyServlet(RestServlet): @@ -250,59 +199,29 @@ class OneTimeKeyServlet(RestServlet): def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() - self.store = hs.get_datastore() self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): - yield self.auth.get_user_by_req(request) - result = yield self.handle_request( - {"one_time_keys": {user_id: {device_id: algorithm}}} + yield self.auth.get_user_by_req(request, allow_guest=True) + timeout = parse_integer(request, "timeout", 10 * 1000) + result = yield self.e2e_keys_handler.claim_one_time_keys( + {"one_time_keys": {user_id: {device_id: algorithm}}}, + timeout, ) - defer.returnValue(result) + defer.returnValue((200, result)) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) - defer.returnValue(result) - - @defer.inlineCallbacks - def handle_request(self, body): - local_query = [] - remote_queries = {} - for user_id, device_keys in body.get("one_time_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): - for device_id, algorithm in device_keys.items(): - local_query.append((user_id, device_id, algorithm)) - else: - remote_queries.setdefault(user.domain, {})[user_id] = ( - device_keys - ) - results = yield self.store.claim_e2e_one_time_keys(local_query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) - } - - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys} - ) - for user_id, keys in remote_result["one_time_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - - defer.returnValue((200, {"one_time_keys": json_result})) + result = yield self.e2e_keys_handler.claim_one_time_keys( + body, + timeout, + ) + defer.returnValue((200, result)) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py new file mode 100644 index 0000000000..fd2a3d69d4 --- /dev/null +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import ( + RestServlet, parse_string, parse_integer +) +from synapse.events.utils import ( + serialize_event, format_event_for_client_v2_without_room_id, +) + +from ._base import client_v2_patterns + +import logging + +logger = logging.getLogger(__name__) + + +class NotificationsServlet(RestServlet): + PATTERNS = client_v2_patterns("/notifications$", releases=()) + + def __init__(self, hs): + super(NotificationsServlet, self).__init__() + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + from_token = parse_string(request, "from", required=False) + limit = parse_integer(request, "limit", default=50) + only = parse_string(request, "only", required=False) + + limit = min(limit, 500) + + push_actions = yield self.store.get_push_actions_for_user( + user_id, from_token, limit, only_highlight=(only == "highlight") + ) + + receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + user_id, 'm.read' + ) + + notif_event_ids = [pa["event_id"] for pa in push_actions] + notif_events = yield self.store.get_events(notif_event_ids) + + returned_push_actions = [] + + next_token = None + + for pa in push_actions: + returned_pa = { + "room_id": pa["room_id"], + "profile_tag": pa["profile_tag"], + "actions": pa["actions"], + "ts": pa["received_ts"], + "event": serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ), + } + + if pa["room_id"] not in receipts_by_room: + returned_pa["read"] = False + else: + receipt = receipts_by_room[pa["room_id"]] + + returned_pa["read"] = ( + receipt["topological_ordering"], receipt["stream_ordering"] + ) >= ( + pa["topological_ordering"], pa["stream_ordering"] + ) + returned_push_actions.append(returned_pa) + next_token = pa["stream_ordering"] + + defer.returnValue((200, { + "notifications": returned_push_actions, + "next_token": next_token, + })) + + +def register_servlets(hs, http_server): + NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py new file mode 100644 index 0000000000..aa1cae8e1e --- /dev/null +++ b/synapse/rest/client/v2_alpha/openid.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._base import client_v2_patterns + +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import AuthError +from synapse.util.stringutils import random_string + +from twisted.internet import defer + +import logging + +logger = logging.getLogger(__name__) + + +class IdTokenServlet(RestServlet): + """ + Get a bearer token that may be passed to a third party to confirm ownership + of a matrix user id. + + The format of the response could be made compatible with the format given + in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse + + But instead of returning a signed "id_token" the response contains the + name of the issuing matrix homeserver. This means that for now the third + party will need to check the validity of the "id_token" against the + federation /openid/userinfo endpoint of the homeserver. + + Request: + + POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1 + + {} + + Response: + + HTTP/1.1 200 OK + { + "access_token": "ABDEFGH", + "token_type": "Bearer", + "matrix_server_name": "example.com", + "expires_in": 3600, + } + """ + PATTERNS = client_v2_patterns( + "/user/(?P<user_id>[^/]*)/openid/request_token" + ) + + EXPIRES_MS = 3600 * 1000 + + def __init__(self, hs): + super(IdTokenServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + + @defer.inlineCallbacks + def on_POST(self, request, user_id): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot request tokens for other users.") + + # Parse the request body to make sure it's JSON, but ignore the contents + # for now. + parse_json_object_from_request(request) + + token = random_string(24) + ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS + + yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + + defer.returnValue((200, { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS / 1000, + })) + + +def register_servlets(hs, http_server): + IdTokenServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index b831d8c95e..1fbff2edd8 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -36,8 +36,8 @@ class ReceiptRestServlet(RestServlet): super(ReceiptRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.receipts_handler = hs.get_handlers().receipts_handler - self.presence_handler = hs.get_handlers().presence_handler + self.receipts_handler = hs.get_receipts_handler() + self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d32c06c882..3e7a285e10 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -15,6 +15,8 @@ from twisted.internet import defer +import synapse +from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -41,27 +43,72 @@ else: logger = logging.getLogger(__name__) +class RegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/email/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RegisterRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if len(absent) > 0: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class RegisterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register") + PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() + body = parse_json_object_from_request(request) + kind = "user" if "kind" in request.args: kind = request.args["kind"][0] if kind == "guest": - ret = yield self._do_guest_registration() + ret = yield self._do_guest_registration(body) defer.returnValue(ret) return elif kind != "user": @@ -69,12 +116,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - if '/register/email/requestToken' in request.path: - ret = yield self.onEmailTokenRequest(request) - defer.returnValue(ret) - - body = parse_json_object_from_request(request) - # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None @@ -92,7 +133,7 @@ class RegisterRestServlet(RestServlet): desired_username = body['username'] appservice = None - if 'access_token' in request.args: + if has_access_token(request): appservice = yield self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes and shared secret auth which @@ -100,9 +141,16 @@ class RegisterRestServlet(RestServlet): # == Application Service Registration == if appservice: - result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0] - ) + # Set the desired user according to the AS API (which uses the + # 'user' key not 'username'). Since this is a new addition, we'll + # fallback to 'username' if they gave one. + desired_username = body.get("user", desired_username) + access_token = get_access_token_from_request(request) + + if isinstance(desired_username, basestring): + result = yield self._do_appservice_registration( + desired_username, access_token, body + ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -111,7 +159,7 @@ class RegisterRestServlet(RestServlet): # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( - desired_username, desired_password, body["mac"] + desired_username, desired_password, body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -122,6 +170,17 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) + if ( + 'initial_device_display_name' in body and + 'password' not in body + ): + # ignore 'initial_device_display_name' if sent without + # a password to work around a client bug where it sent + # the 'initial_device_display_name' param alone, wiping out + # the original registration params + logger.warn("Ignoring initial_device_display_name without password") + del body['initial_device_display_name'] + session_id = self.auth_handler.get_session_id(body) registered_user_id = None if session_id: @@ -151,12 +210,12 @@ class RegisterRestServlet(RestServlet): [LoginType.EMAIL_IDENTITY] ] - authed, result, params, session_id = yield self.auth_handler.check_auth( + authed, auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) if not authed: - defer.returnValue((401, result)) + defer.returnValue((401, auth_result)) return if registered_user_id is not None: @@ -164,78 +223,58 @@ class RegisterRestServlet(RestServlet): "Already registered user ID %r for this session", registered_user_id ) - access_token = yield self.auth_handler.issue_access_token(registered_user_id) - refresh_token = yield self.auth_handler.issue_refresh_token( - registered_user_id + # don't re-register the email address + add_email = False + else: + # NB: This may be from the auth handler and NOT from the POST + if 'password' not in params: + raise SynapseError(400, "Missing password.", + Codes.MISSING_PARAM) + + desired_username = params.get("username", None) + new_password = params.get("password", None) + guest_access_token = params.get("guest_access_token", None) + + (registered_user_id, _) = yield self.registration_handler.register( + localpart=desired_username, + password=new_password, + guest_access_token=guest_access_token, + generate_token=False, ) - defer.returnValue((200, { - "user_id": registered_user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - })) - - # NB: This may be from the auth handler and NOT from the POST - if 'password' not in params: - raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) - - desired_username = params.get("username", None) - new_password = params.get("password", None) - guest_access_token = params.get("guest_access_token", None) - - (user_id, token) = yield self.registration_handler.register( - localpart=desired_username, - password=new_password, - guest_access_token=guest_access_token, - ) - # remember that we've now registered that user account, and with what - # user ID (since the user may not have specified) - self.auth_handler.set_session_data( - session_id, "registered_user_id", user_id + # remember that we've now registered that user account, and with + # what user ID (since the user may not have specified) + self.auth_handler.set_session_data( + session_id, "registered_user_id", registered_user_id + ) + + add_email = True + + return_dict = yield self._create_registration_details( + registered_user_id, params ) - if result and LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - - for reqd in ['medium', 'address', 'validated_at']: - if reqd not in threepid: - logger.info("Can't add incomplete 3pid") - else: - yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], - ) - - if 'bind_email' in params and params['bind_email']: - logger.info("bind_email specified: binding") - - emailThreepid = result[LoginType.EMAIL_IDENTITY] - threepid_creds = emailThreepid['threepid_creds'] - logger.debug("Binding emails %s to %s" % ( - emailThreepid, user_id - )) - yield self.identity_handler.bind_threepid(threepid_creds, user_id) - else: - logger.info("bind_email not specified: not binding email") - - result = yield self._create_registration_details(user_id, token) - defer.returnValue((200, result)) + if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result: + threepid = auth_result[LoginType.EMAIL_IDENTITY] + yield self._register_email_threepid( + registered_user_id, threepid, return_dict["access_token"], + params.get("bind_email") + ) + + defer.returnValue((200, return_dict)) def on_OPTIONS(self, _): return 200, {} @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token): - (user_id, token) = yield self.registration_handler.appservice_register( + def _do_appservice_registration(self, username, as_token, body): + user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + defer.returnValue((yield self._create_registration_details(user_id, body))) @defer.inlineCallbacks - def _do_shared_secret_registration(self, username, password, mac): + def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -243,7 +282,7 @@ class RegisterRestServlet(RestServlet): # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface - got_mac = str(mac) + got_mac = str(body["mac"]) want_mac = hmac.new( key=self.hs.config.registration_shared_secret, @@ -256,59 +295,158 @@ class RegisterRestServlet(RestServlet): 403, "HMAC incorrect", ) - (user_id, token) = yield self.registration_handler.register( - localpart=username, password=password + (user_id, _) = yield self.registration_handler.register( + localpart=username, password=password, generate_token=False, ) - defer.returnValue((yield self._create_registration_details(user_id, token))) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, token): - refresh_token = yield self.auth_handler.issue_refresh_token(user_id) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - }) + result = yield self._create_registration_details(user_id, body) + defer.returnValue(result) @defer.inlineCallbacks - def onEmailTokenRequest(self, request): - body = parse_json_object_from_request(request) + def _register_email_threepid(self, user_id, threepid, token, bind_email): + """Add an email address as a 3pid identifier + + Also adds an email pusher for the email address, if configured in the + HS config + + Also optionally binds emails to the given user_id on the identity server + + Args: + user_id (str): id of user + threepid (object): m.login.email.identity auth response + token (str): access_token for the user + bind_email (bool): true if the client requested the email to be + bound at the identity server + Returns: + defer.Deferred: + """ + reqd = ('medium', 'address', 'validated_at') + if any(x not in threepid for x in reqd): + logger.info("Can't add incomplete 3pid") + defer.returnValue() + + yield self.auth_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) + # And we add an email pusher for them by default, but only + # if email notifications are enabled (so people don't start + # getting mail spam where they weren't before if email + # notifs are set up on a home server) + if (self.hs.config.email_enable_notifs and + self.hs.config.email_notif_for_new_users): + # Pull the ID of the access token back out of the db + # It would really make more sense for this to be passed + # up when the access token is saved, but that's quite an + # invasive change I'd rather do separately. + user_tuple = yield self.store.get_user_by_access_token( + token + ) + token_id = user_tuple["token_id"] + + yield self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) - if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + if bind_email: + logger.info("bind_email specified: binding") + logger.debug("Binding emails %s to %s" % ( + threepid, user_id + )) + yield self.identity_handler.bind_threepid( + threepid['threepid_creds'], user_id + ) + else: + logger.info("bind_email not specified: not binding email") - existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + @defer.inlineCallbacks + def _create_registration_details(self, user_id, params): + """Complete registration of newly-registered user + + Allocates device_id if one was not given; also creates access_token. + + Args: + (str) user_id: full canonical @user:id + (object) params: registration parameters, from which we pull + device_id and initial_device_name + Returns: + defer.Deferred: (object) dictionary for response from /register + """ + device_id = yield self._register_device(user_id, params) + + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - if existingUid is not None: - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + defer.returnValue({ + "user_id": user_id, + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + }) - ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) + def _register_device(self, user_id, params): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) params: registration parameters, from which we pull + device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + # register the user's device + device_id = params.get("device_id") + initial_display_name = params.get("initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) @defer.inlineCallbacks - def _do_guest_registration(self): + def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: defer.returnValue((403, "Guest access is disabled")) user_id, _ = yield self.registration_handler.register( generate_token=False, make_guest=True ) - access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + + # we don't allow guests to specify their own device_id, because + # we have nowhere to store it. + device_id = synapse.api.auth.GUEST_DEVICE_ID + initial_display_name = params.get("initial_device_display_name") + self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + + access_token = self.auth_handler.generate_access_token( + user_id, ["guest = true"] + ) defer.returnValue((200, { "user_id": user_id, + "device_id": device_id, "access_token": access_token, "home_server": self.hs.hostname, })) def register_servlets(hs, http_server): + RegisterRequestTokenRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py new file mode 100644 index 0000000000..8903e12405 --- /dev/null +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from ._base import client_v2_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class ReportEventRestServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$" + ) + + def __init__(self, hs): + super(ReportEventRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_POST(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + body = parse_json_object_from_request(request) + + yield self.store.add_event_report( + room_id=room_id, + event_id=event_id, + user_id=user_id, + reason=body.get("reason"), + content=body, + received_ts=self.clock.time_msec(), + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReportEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py new file mode 100644 index 0000000000..d607bd2970 --- /dev/null +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.http import servlet +from synapse.http.servlet import parse_json_object_from_request +from synapse.rest.client.transactions import HttpTransactionCache + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class SendToDeviceRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns( + "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", + releases=[], v2_alpha=False + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SendToDeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs.get_clock()) + self.device_message_handler = hs.get_device_message_handler() + + def on_PUT(self, request, message_type, txn_id): + return self.txns.fetch_or_execute_request( + request, self._put, request, message_type, txn_id + ) + + @defer.inlineCallbacks + def _put(self, request, message_type, txn_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + + sender_user_id = requester.user.to_string() + + yield self.device_message_handler.send_device_message( + sender_user_id, message_type, content["messages"] + ) + + response = (200, {}) + defer.returnValue(response) + + +def register_servlets(hs, http_server): + SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index de4a020ad4..7199ec883a 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -79,11 +79,10 @@ class SyncRestServlet(RestServlet): def __init__(self, hs): super(SyncRestServlet, self).__init__() self.auth = hs.get_auth() - self.event_stream_handler = hs.get_handlers().event_stream_handler - self.sync_handler = hs.get_handlers().sync_handler + self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() - self.presence_handler = hs.get_handlers().presence_handler + self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks def on_GET(self, request): @@ -98,6 +97,7 @@ class SyncRestServlet(RestServlet): request, allow_guest=True ) user = requester.user + device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") @@ -110,11 +110,13 @@ class SyncRestServlet(RestServlet): logger.info( "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r" % ( - user, timeout, since, set_presence, filter_id + " set_presence=%r, filter_id=%r, device_id=%r" % ( + user, timeout, since, set_presence, filter_id, device_id ) ) + request_key = (user, timeout, since, filter_id, full_state, device_id) + if filter_id: if filter_id.startswith('{'): try: @@ -134,6 +136,8 @@ class SyncRestServlet(RestServlet): user=user, filter_collection=filter, is_guest=requester.is_guest, + request_key=request_key, + device_id=device_id, ) if since is not None: @@ -144,7 +148,7 @@ class SyncRestServlet(RestServlet): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}) + yield self.presence_handler.set_state(user, {"presence": set_presence}, True) context = yield self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence, @@ -158,7 +162,7 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() joined = self.encode_joined( - sync_result.joined, time_now, requester.access_token_id + sync_result.joined, time_now, requester.access_token_id, filter.event_fields ) invited = self.encode_invited( @@ -166,11 +170,12 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id + sync_result.archived, time_now, requester.access_token_id, filter.event_fields ) response_content = { "account_data": {"events": sync_result.account_data}, + "to_device": {"events": sync_result.to_device}, "presence": self.encode_presence( sync_result.presence, time_now ), @@ -192,24 +197,27 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": formatted} - def encode_joined(self, rooms, time_now, token_id): + def encode_joined(self, rooms, time_now, token_id, event_fields): """ Encode the joined rooms in a sync result - :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync - results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the joined rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync + results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. + Returns: + dict[str, dict[str, object]]: the joined rooms list, in our + response format """ joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id + room, time_now, token_id, only_fields=event_fields ) return joined @@ -218,15 +226,17 @@ class SyncRestServlet(RestServlet): """ Encode the invited rooms in a sync result - :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing + Args: + rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Returns: + dict[str, dict[str, object]]: the invited rooms list, in our + response format """ invited = {} for room in rooms: @@ -244,48 +254,53 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, time_now, token_id): + def encode_archived(self, rooms, time_now, token_id, event_fields): """ Encode the archived rooms in a sync result - :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. + Returns: + dict[str, dict[str, object]]: The invited rooms list, in our + response format """ joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id, joined=False + room, time_now, token_id, joined=False, only_fields=event_fields ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True): + def encode_room(room, time_now, token_id, joined=True, only_fields=None): """ - :param JoinedSyncResult|ArchivedSyncResult room: sync result for a - single room - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - :param joined: True if the user is joined to this room - will mean - we handle ephemeral events - - :return: the room, encoded in our response format - :rtype: dict[str, object] + Args: + room (JoinedSyncResult|ArchivedSyncResult): sync result for a + single room + time_now (int): current time - used as a baseline for age + calculations + token_id (int): ID of the user's auth token - used for namespacing + of transaction IDs + joined (bool): True if the user is joined to this room - will mean + we handle ephemeral events + only_fields(list<str>): Optional. The list of event fields to include. + Returns: + dict[str, object]: the room, encoded in our response format """ def serialize(event): # TODO(mjark): Respect formatting requirements in the filter. return serialize_event( event, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, + only_event_fields=only_fields, ) state_dict = room.state diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py new file mode 100644 index 0000000000..31f94bc6e9 --- /dev/null +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from twisted.internet import defer + +from synapse.api.constants import ThirdPartyEntityKind +from synapse.http.servlet import RestServlet +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class ThirdPartyProtocolsServlet(RestServlet): + PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) + + def __init__(self, hs): + super(ThirdPartyProtocolsServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + yield self.auth.get_user_by_req(request) + + protocols = yield self.appservice_handler.get_3pe_protocols() + defer.returnValue((200, protocols)) + + +class ThirdPartyProtocolServlet(RestServlet): + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyProtocolServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + + protocols = yield self.appservice_handler.get_3pe_protocols( + only_protocol=protocol, + ) + if protocol in protocols: + defer.returnValue((200, protocols[protocol])) + else: + defer.returnValue((404, {"error": "Unknown protocol"})) + + +class ThirdPartyUserServlet(RestServlet): + PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyUserServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + + fields = request.args + fields.pop("access_token", None) + + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.USER, protocol, fields + ) + + defer.returnValue((200, results)) + + +class ThirdPartyLocationServlet(RestServlet): + PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyLocationServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + + fields = request.args + fields.pop("access_token", None) + + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.LOCATION, protocol, fields + ) + + defer.returnValue((200, results)) + + +def register_servlets(hs, http_server): + ThirdPartyProtocolsServlet(hs).register(http_server) + ThirdPartyProtocolServlet(hs).register(http_server) + ThirdPartyUserServlet(hs).register(http_server) + ThirdPartyLocationServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index a158c2209a..6e76b9e9c2 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -15,8 +15,8 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -30,26 +30,10 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - self.hs = hs - self.store = hs.get_datastore() @defer.inlineCallbacks def on_POST(self, request): - body = parse_json_object_from_request(request) - try: - old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_handlers().auth_handler - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) - defer.returnValue((200, { - "access_token": new_access_token, - "refresh_token": new_refresh_token, - })) - except KeyError: - raise SynapseError(400, "Missing required key 'refresh_token'.") - except StoreError: - raise AuthError(403, "Did not recognize refresh token") + raise AuthError(403, "tokenrefresh is no longer supported.") def register_servlets(hs, http_server): diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index ca5468c402..e984ea47db 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet): def on_GET(self, request): return (200, { - "versions": ["r0.0.1"] + "versions": [ + "r0.0.1", + "r0.1.0", + "r0.2.0", + ] }) diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py index 3db3838b7e..bd4fea5774 100644 --- a/synapse/rest/key/v1/server_key_resource.py +++ b/synapse/rest/key/v1/server_key_resource.py @@ -49,7 +49,6 @@ class LocalKey(Resource): """ def __init__(self, hs): - self.hs = hs self.version_string = hs.version_string self.response_body = encode_canonical_json( self.response_json_object(hs.config) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 93e5b1cbf0..ff95269ba8 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -19,8 +19,6 @@ from synapse.http.server import respond_with_json_bytes from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 from canonicaljson import encode_canonical_json -from hashlib import sha256 -from OpenSSL import crypto import logging @@ -48,8 +46,12 @@ class LocalKey(Resource): "expired_ts": # integer posix timestamp when the key expired. "key": # base64 encoded NACL verification key. } - } - "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert. + }, + "tls_fingerprints": [ # Fingerprints of the TLS certs this server uses. + { + "sha256": # base64 encoded sha256 fingerprint of the X509 cert + }, + ], "signatures": { "this.server.example.com": { "algorithm:version": # NACL signature for this server @@ -90,21 +92,14 @@ class LocalKey(Resource): u"expired_ts": key.expired, } - x509_certificate_bytes = crypto.dump_certificate( - crypto.FILETYPE_ASN1, - self.config.tls_certificate - ) - - sha256_fingerprint = sha256(x509_certificate_bytes).digest() + tls_fingerprints = self.config.tls_fingerprints json_object = { u"valid_until_ts": self.valid_until_ts, u"server_name": self.config.server_name, u"verify_keys": verify_keys, u"old_verify_keys": old_verify_keys, - u"tls_fingerprints": [{ - u"sha256": encode_base64(sha256_fingerprint), - }] + u"tls_fingerprints": tls_fingerprints, } for key in self.config.signing_key: json_object = sign_json( diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9552016fec..9fe2013657 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,6 +15,7 @@ from synapse.http.server import request_handler, respond_with_json_bytes from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.api.errors import SynapseError, Codes +from synapse.crypto.keyring import KeyLookupError from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -97,7 +98,7 @@ class RemoteKey(Resource): self.async_render_GET(request) return NOT_DONE_YET - @request_handler + @request_handler() @defer.inlineCallbacks def async_render_GET(self, request): if len(request.postpath) == 1: @@ -122,7 +123,7 @@ class RemoteKey(Resource): self.async_render_POST(request) return NOT_DONE_YET - @request_handler + @request_handler() @defer.inlineCallbacks def async_render_POST(self, request): content = parse_json_object_from_request(request) @@ -210,9 +211,10 @@ class RemoteKey(Resource): yield self.keyring.get_server_verify_key_v2_direct( server_name, key_ids ) + except KeyLookupError as e: + logger.info("Failed to fetch key: %s", e) except: logger.exception("Failed to get key for %r", server_name) - pass yield self.query_keys( request, query, query_remote_on_cache_miss=False ) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index d9fc045fc6..956bd5da75 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -15,14 +15,12 @@ from synapse.http.server import respond_with_json_bytes, finish_request -from synapse.util.stringutils import random_string from synapse.api.errors import ( - cs_exception, SynapseError, CodeMessageException, Codes, cs_error + Codes, cs_error ) from twisted.protocols.basic import FileSender from twisted.web import server, resource -from twisted.internet import defer import base64 import simplejson as json @@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource): """ isLeaf = True - def __init__(self, hs, directory, auth, external_addr): + def __init__(self, hs, directory): resource.Resource.__init__(self) self.hs = hs self.directory = directory - self.auth = auth - self.external_addr = external_addr.rstrip('/') - self.max_upload_size = hs.config.max_upload_size - - if not os.path.isdir(self.directory): - os.mkdir(self.directory) - logger.info("ContentRepoResource : Created %s directory.", - self.directory) - - @defer.inlineCallbacks - def map_request_to_name(self, request): - # auth the user - requester = yield self.auth.get_user_by_req(request) - - # namespace all file uploads on the user - prefix = base64.urlsafe_b64encode( - requester.user.to_string() - ).replace('=', '') - - # use a random string for the main portion - main_part = random_string(24) - - # suffix with a file extension if we can make one. This is nice to - # provide a hint to clients on the file information. We will also reuse - # this info to spit back the content type to the client. - suffix = "" - if request.requestHeaders.hasHeader("Content-Type"): - content_type = request.requestHeaders.getRawHeaders( - "Content-Type")[0] - suffix = "." + base64.urlsafe_b64encode(content_type) - if (content_type.split("/")[0].lower() in - ["image", "video", "audio"]): - file_ext = content_type.split("/")[-1] - # be a little paranoid and only allow a-z - file_ext = re.sub("[^a-z]", "", file_ext) - suffix += "." + file_ext - - file_name = prefix + main_part + suffix - file_path = os.path.join(self.directory, file_name) - logger.info("User %s is uploading a file to path %s", - request.user.user_id.to_string(), - file_path) - - # keep trying to make a non-clashing file, with a sensible max attempts - attempts = 0 - while os.path.exists(file_path): - main_part = random_string(24) - file_name = prefix + main_part + suffix - file_path = os.path.join(self.directory, file_name) - attempts += 1 - if attempts > 25: # really? Really? - raise SynapseError(500, "Unable to create file.") - - defer.returnValue(file_path) def render_GET(self, request): # no auth here on purpose, to allow anyone to view, even across home @@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource): return server.NOT_DONE_YET - def render_POST(self, request): - self._async_render(request) - return server.NOT_DONE_YET - def render_OPTIONS(self, request): respond_with_json_bytes(request, 200, {}, send_cors=True) return server.NOT_DONE_YET - - @defer.inlineCallbacks - def _async_render(self, request): - try: - # TODO: The checks here are a bit late. The content will have - # already been uploaded to a tmp file at this point - content_length = request.getHeader("Content-Length") - if content_length is None: - raise SynapseError( - msg="Request must specify a Content-Length", code=400 - ) - if int(content_length) > self.max_upload_size: - raise SynapseError( - msg="Upload request body is too large", - code=413, - ) - - fname = yield self.map_request_to_name(request) - - # TODO I have a suspicious feeling this is just going to block - with open(fname, "wb") as f: - f.write(request.content.read()) - - # FIXME (erikj): These should use constants. - file_name = os.path.basename(fname) - # FIXME: we can't assume what the repo's public mounted path is - # ...plus self-signed SSL won't work to remote clients anyway - # ...and we can't assume that it's SSL anyway, as we might want to - # serve it via the non-SSL listener... - url = "%s/_matrix/content/%s" % ( - self.external_addr, file_name - ) - - respond_with_json_bytes(request, 200, - json.dumps({"content_token": url}), - send_cors=True) - - except CodeMessageException as e: - logger.exception(e) - respond_with_json_bytes(request, e.code, - json.dumps(cs_exception(e))) - except Exception as e: - logger.error("Failed to store file: %s" % e) - respond_with_json_bytes( - request, - 500, - json.dumps({"error": "Internal server error"}), - send_cors=True) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py new file mode 100644 index 0000000000..b9600f2167 --- /dev/null +++ b/synapse/rest/media/v1/_base.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json, finish_request +from synapse.api.errors import ( + cs_error, Codes, SynapseError +) + +from twisted.internet import defer +from twisted.protocols.basic import FileSender + +from synapse.util.stringutils import is_ascii + +import os + +import logging +import urllib +import urlparse + +logger = logging.getLogger(__name__) + + +def parse_media_id(request): + try: + # This allows users to append e.g. /test.png to the URL. Useful for + # clients that parse the URL to see content type. + server_name, media_id = request.postpath[:2] + file_name = None + if len(request.postpath) > 2: + try: + file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") + except UnicodeDecodeError: + pass + return server_name, media_id, file_name + except: + raise SynapseError( + 404, + "Invalid media id token %r" % (request.postpath,), + Codes.UNKNOWN, + ) + + +def respond_404(request): + respond_with_json( + request, 404, + cs_error( + "Not found %r" % (request.postpath,), + code=Codes.NOT_FOUND, + ), + send_cors=True + ) + + +@defer.inlineCallbacks +def respond_with_file(request, media_type, file_path, + file_size=None, upload_name=None): + logger.debug("Responding with %r", file_path) + + if os.path.isfile(file_path): + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) + if upload_name: + if is_ascii(upload_name): + request.setHeader( + b"Content-Disposition", + b"inline; filename=%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + else: + request.setHeader( + b"Content-Disposition", + b"inline; filename*=utf-8''%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader( + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" + ) + if file_size is None: + stat = os.stat(file_path) + file_size = stat.st_size + + request.setHeader( + b"Content-Length", b"%d" % (file_size,) + ) + + with open(file_path, "rb") as f: + yield FileSender().beginFileTransfer(f, request) + + finish_request(request) + else: + respond_404(request) diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py deleted file mode 100644 index 58ef91c0b8..0000000000 --- a/synapse/rest/media/v1/base_resource.py +++ /dev/null @@ -1,459 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .thumbnailer import Thumbnailer - -from synapse.http.matrixfederationclient import MatrixFederationHttpClient -from synapse.http.server import respond_with_json, finish_request -from synapse.util.stringutils import random_string -from synapse.api.errors import ( - cs_error, Codes, SynapseError -) - -from twisted.internet import defer, threads -from twisted.web.resource import Resource -from twisted.protocols.basic import FileSender - -from synapse.util.async import ObservableDeferred -from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import preserve_context_over_fn - -import os - -import cgi -import logging -import urllib -import urlparse - -logger = logging.getLogger(__name__) - - -def parse_media_id(request): - try: - # This allows users to append e.g. /test.png to the URL. Useful for - # clients that parse the URL to see content type. - server_name, media_id = request.postpath[:2] - file_name = None - if len(request.postpath) > 2: - try: - file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") - except UnicodeDecodeError: - pass - return server_name, media_id, file_name - except: - raise SynapseError( - 404, - "Invalid media id token %r" % (request.postpath,), - Codes.UNKNOWN, - ) - - -class BaseMediaResource(Resource): - isLeaf = True - - def __init__(self, hs, filepaths): - Resource.__init__(self) - self.auth = hs.get_auth() - self.client = MatrixFederationHttpClient(hs) - self.clock = hs.get_clock() - self.server_name = hs.hostname - self.store = hs.get_datastore() - self.max_upload_size = hs.config.max_upload_size - self.max_image_pixels = hs.config.max_image_pixels - self.filepaths = filepaths - self.version_string = hs.version_string - self.downloads = {} - self.dynamic_thumbnails = hs.config.dynamic_thumbnails - self.thumbnail_requirements = hs.config.thumbnail_requirements - - def _respond_404(self, request): - respond_with_json( - request, 404, - cs_error( - "Not found %r" % (request.postpath,), - code=Codes.NOT_FOUND, - ), - send_cors=True - ) - - @staticmethod - def _makedirs(filepath): - dirname = os.path.dirname(filepath) - if not os.path.exists(dirname): - os.makedirs(dirname) - - def _get_remote_media(self, server_name, media_id): - key = (server_name, media_id) - download = self.downloads.get(key) - if download is None: - download = self._get_remote_media_impl(server_name, media_id) - download = ObservableDeferred( - download, - consumeErrors=True - ) - self.downloads[key] = download - - @download.addBoth - def callback(media_info): - del self.downloads[key] - return media_info - return download.observe() - - @defer.inlineCallbacks - def _get_remote_media_impl(self, server_name, media_id): - media_info = yield self.store.get_cached_remote_media( - server_name, media_id - ) - if not media_info: - media_info = yield self._download_remote_file( - server_name, media_id - ) - defer.returnValue(media_info) - - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id): - file_id = random_string(24) - - fname = self.filepaths.remote_media_filepath( - server_name, file_id - ) - self._makedirs(fname) - - try: - with open(fname, "wb") as f: - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) - length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, - ) - media_type = headers["Content-Type"][0] - time_now_ms = self.clock.time_msec() - - content_disposition = headers.get("Content-Disposition", None) - if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get("filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith("utf-8''"): - upload_name = upload_name_utf8[7:] - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get("filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii - - if upload_name: - upload_name = urlparse.unquote(upload_name) - try: - upload_name = upload_name.decode("utf-8") - except UnicodeDecodeError: - upload_name = None - else: - upload_name = None - - yield self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - except: - os.remove(fname) - raise - - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - yield self._generate_remote_thumbnails( - server_name, media_id, media_info - ) - - defer.returnValue(media_info) - - @defer.inlineCallbacks - def _respond_with_file(self, request, media_type, file_path, - file_size=None, upload_name=None): - logger.debug("Responding with %r", file_path) - - if os.path.isfile(file_path): - request.setHeader(b"Content-Type", media_type.encode("UTF-8")) - if upload_name: - if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) - if file_size is None: - stat = os.stat(file_path) - file_size = stat.st_size - - request.setHeader( - b"Content-Length", b"%d" % (file_size,) - ) - - with open(file_path, "rb") as f: - yield FileSender().beginFileTransfer(f, request) - - finish_request(request) - else: - self._respond_404(request) - - def _get_thumbnail_requirements(self, media_type): - return self.thumbnail_requirements.get(media_type, ()) - - def _generate_thumbnail(self, input_path, t_path, t_width, t_height, - t_method, t_type): - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - if t_method == "crop": - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - elif t_method == "scale": - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - else: - t_len = None - - return t_len - - @defer.inlineCallbacks - def _generate_local_exact_thumbnail(self, media_id, t_width, t_height, - t_method, t_type): - input_path = self.filepaths.local_media_filepath(media_id) - - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, - self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) - - if t_len: - yield self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) - - defer.returnValue(t_path) - - @defer.inlineCallbacks - def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id, - t_width, t_height, t_method, t_type): - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, - self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) - - if t_len: - yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ) - - defer.returnValue(t_path) - - @defer.inlineCallbacks - def _generate_local_thumbnails(self, media_id, media_info): - media_type = media_info["media_type"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return - - input_path = self.filepaths.local_media_filepath(media_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - local_thumbnails = [] - - def generate_thumbnails(): - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len - )) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len - )) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for l in local_thumbnails: - yield self.store.store_local_thumbnail(*l) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) - - @defer.inlineCallbacks - def _generate_remote_thumbnails(self, server_name, media_id, media_info): - media_type = media_info["media_type"] - file_id = media_info["filesystem_id"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return - - remote_thumbnails = [] - - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height - - def generate_thumbnails(): - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for r in remote_thumbnails: - yield self.store.store_remote_media_thumbnail(*r) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 1aad6b3551..dfb87ffd15 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base_resource import BaseMediaResource, parse_media_id -from synapse.http.server import request_handler +from ._base import parse_media_id, respond_with_file, respond_404 +from twisted.web.resource import Resource +from synapse.http.server import request_handler, set_cors_headers from twisted.web.server import NOT_DONE_YET from twisted.internet import defer @@ -24,14 +25,35 @@ import logging logger = logging.getLogger(__name__) -class DownloadResource(BaseMediaResource): +class DownloadResource(Resource): + isLeaf = True + + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.filepaths = media_repo.filepaths + self.media_repo = media_repo + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.version_string = hs.version_string + self.clock = hs.get_clock() + def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET - @request_handler + @request_handler() @defer.inlineCallbacks def _async_render_GET(self, request): + set_cors_headers(request) + request.setHeader( + "Content-Security-Policy", + "default-src 'none';" + " script-src 'none';" + " plugin-types application/pdf;" + " style-src 'unsafe-inline';" + " object-src 'self';" + ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: yield self._respond_local_file(request, media_id, name) @@ -44,7 +66,7 @@ class DownloadResource(BaseMediaResource): def _respond_local_file(self, request, media_id, name): media_info = yield self.store.get_local_media(media_id) if not media_info: - self._respond_404(request) + respond_404(request) return media_type = media_info["media_type"] @@ -52,14 +74,14 @@ class DownloadResource(BaseMediaResource): upload_name = name if name else media_info["upload_name"] file_path = self.filepaths.local_media_filepath(media_id) - yield self._respond_with_file( + yield respond_with_file( request, media_type, file_path, media_length, upload_name=upload_name, ) @defer.inlineCallbacks def _respond_remote_file(self, request, server_name, media_id, name): - media_info = yield self._get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media(server_name, media_id) media_type = media_info["media_type"] media_length = media_info["media_length"] @@ -70,7 +92,7 @@ class DownloadResource(BaseMediaResource): server_name, filesystem_id ) - yield self._respond_with_file( + yield respond_with_file( request, media_type, file_path, media_length, upload_name=upload_name, ) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 422ab86fb3..0137458f71 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -65,3 +65,9 @@ class MediaFilePaths(object): file_id[0:2], file_id[2:4], file_id[4:], file_name ) + + def remote_media_thumbnail_dir(self, server_name, file_id): + return os.path.join( + self.base_path, "remote_thumbnail", server_name, + file_id[0:2], file_id[2:4], file_id[4:], + ) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 7dfb027dd1..692e078419 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -17,15 +17,458 @@ from .upload_resource import UploadResource from .download_resource import DownloadResource from .thumbnail_resource import ThumbnailResource from .identicon_resource import IdenticonResource +from .preview_url_resource import PreviewUrlResource from .filepath import MediaFilePaths from twisted.web.resource import Resource +from .thumbnailer import Thumbnailer + +from synapse.http.matrixfederationclient import MatrixFederationHttpClient +from synapse.util.stringutils import random_string +from synapse.api.errors import SynapseError + +from twisted.internet import defer, threads + +from synapse.util.async import Linearizer +from synapse.util.stringutils import is_ascii +from synapse.util.logcontext import preserve_context_over_fn + +import os +import errno +import shutil + +import cgi import logging +import urlparse logger = logging.getLogger(__name__) +UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 + + +class MediaRepository(object): + def __init__(self, hs): + self.auth = hs.get_auth() + self.client = MatrixFederationHttpClient(hs) + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.max_upload_size = hs.config.max_upload_size + self.max_image_pixels = hs.config.max_image_pixels + self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.dynamic_thumbnails = hs.config.dynamic_thumbnails + self.thumbnail_requirements = hs.config.thumbnail_requirements + + self.remote_media_linearizer = Linearizer() + + self.recently_accessed_remotes = set() + + self.clock.looping_call( + self._update_recently_accessed_remotes, + UPDATE_RECENTLY_ACCESSED_REMOTES_TS + ) + + @defer.inlineCallbacks + def _update_recently_accessed_remotes(self): + media = self.recently_accessed_remotes + self.recently_accessed_remotes = set() + + yield self.store.update_cached_last_access_time( + media, self.clock.time_msec() + ) + + @staticmethod + def _makedirs(filepath): + dirname = os.path.dirname(filepath) + if not os.path.exists(dirname): + os.makedirs(dirname) + + @defer.inlineCallbacks + def create_content(self, media_type, upload_name, content, content_length, + auth_user): + media_id = random_string(24) + + fname = self.filepaths.local_media_filepath(media_id) + self._makedirs(fname) + + # This shouldn't block for very long because the content will have + # already been uploaded at this point. + with open(fname, "wb") as f: + f.write(content) + + yield self.store.store_local_media( + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + media_info = { + "media_type": media_type, + "media_length": content_length, + } + + yield self._generate_local_thumbnails(media_id, media_info) + + defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) + + @defer.inlineCallbacks + def get_remote_media(self, server_name, media_id): + key = (server_name, media_id) + with (yield self.remote_media_linearizer.queue(key)): + media_info = yield self._get_remote_media_impl(server_name, media_id) + defer.returnValue(media_info) + + @defer.inlineCallbacks + def _get_remote_media_impl(self, server_name, media_id): + media_info = yield self.store.get_cached_remote_media( + server_name, media_id + ) + if not media_info: + media_info = yield self._download_remote_file( + server_name, media_id + ) + else: + self.recently_accessed_remotes.add((server_name, media_id)) + yield self.store.update_cached_last_access_time( + [(server_name, media_id)], self.clock.time_msec() + ) + defer.returnValue(media_info) + + @defer.inlineCallbacks + def _download_remote_file(self, server_name, media_id): + file_id = random_string(24) + + fname = self.filepaths.remote_media_filepath( + server_name, file_id + ) + self._makedirs(fname) + + try: + with open(fname, "wb") as f: + request_path = "/".join(( + "/_matrix/media/v1/download", server_name, media_id, + )) + try: + length, headers = yield self.client.get_file( + server_name, request_path, output_stream=f, + max_size=self.max_upload_size, + ) + except Exception as e: + logger.warn("Failed to fetch remoted media %r", e) + raise SynapseError(502, "Failed to fetch remoted media") + + media_type = headers["Content-Type"][0] + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get("filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith("utf-8''"): + upload_name = upload_name_utf8[7:] + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get("filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii + + if upload_name: + upload_name = urlparse.unquote(upload_name) + try: + upload_name = upload_name.decode("utf-8") + except UnicodeDecodeError: + upload_name = None + else: + upload_name = None + + yield self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) + except: + os.remove(fname) + raise + + media_info = { + "media_type": media_type, + "media_length": length, + "upload_name": upload_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + } + + yield self._generate_remote_thumbnails( + server_name, media_id, media_info + ) + + defer.returnValue(media_info) + + def _get_thumbnail_requirements(self, media_type): + return self.thumbnail_requirements.get(media_type, ()) + + def _generate_thumbnail(self, input_path, t_path, t_width, t_height, + t_method, t_type): + thumbnailer = Thumbnailer(input_path) + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, m_height, self.max_image_pixels + ) + return + + if t_method == "crop": + t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + elif t_method == "scale": + t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + else: + t_len = None + + return t_len + + @defer.inlineCallbacks + def generate_local_exact_thumbnail(self, media_id, t_width, t_height, + t_method, t_type): + input_path = self.filepaths.local_media_filepath(media_id) + + t_path = self.filepaths.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + + t_len = yield preserve_context_over_fn( + threads.deferToThread, + self._generate_thumbnail, + input_path, t_path, t_width, t_height, t_method, t_type + ) + + if t_len: + yield self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + defer.returnValue(t_path) + + @defer.inlineCallbacks + def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, + t_width, t_height, t_method, t_type): + input_path = self.filepaths.remote_media_filepath(server_name, file_id) + + t_path = self.filepaths.remote_media_thumbnail( + server_name, file_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + + t_len = yield preserve_context_over_fn( + threads.deferToThread, + self._generate_thumbnail, + input_path, t_path, t_width, t_height, t_method, t_type + ) + + if t_len: + yield self.store.store_remote_media_thumbnail( + server_name, media_id, file_id, + t_width, t_height, t_type, t_method, t_len + ) + + defer.returnValue(t_path) + + @defer.inlineCallbacks + def _generate_local_thumbnails(self, media_id, media_info): + media_type = media_info["media_type"] + requirements = self._get_thumbnail_requirements(media_type) + if not requirements: + return + + input_path = self.filepaths.local_media_filepath(media_id) + thumbnailer = Thumbnailer(input_path) + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, m_height, self.max_image_pixels + ) + return + + local_thumbnails = [] + + def generate_thumbnails(): + scales = set() + crops = set() + for r_width, r_height, r_method, r_type in requirements: + if r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + scales.add(( + min(m_width, t_width), min(m_height, t_height), r_type, + )) + elif r_method == "crop": + crops.add((r_width, r_height, r_type)) + + for t_width, t_height, t_type in scales: + t_method = "scale" + t_path = self.filepaths.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + + local_thumbnails.append(( + media_id, t_width, t_height, t_type, t_method, t_len + )) + + for t_width, t_height, t_type in crops: + if (t_width, t_height, t_type) in scales: + # If the aspect ratio of the cropped thumbnail matches a purely + # scaled one then there is no point in calculating a separate + # thumbnail. + continue + t_method = "crop" + t_path = self.filepaths.local_media_thumbnail( + media_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + local_thumbnails.append(( + media_id, t_width, t_height, t_type, t_method, t_len + )) + + yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) + + for l in local_thumbnails: + yield self.store.store_local_thumbnail(*l) + + defer.returnValue({ + "width": m_width, + "height": m_height, + }) + + @defer.inlineCallbacks + def _generate_remote_thumbnails(self, server_name, media_id, media_info): + media_type = media_info["media_type"] + file_id = media_info["filesystem_id"] + requirements = self._get_thumbnail_requirements(media_type) + if not requirements: + return + + remote_thumbnails = [] + + input_path = self.filepaths.remote_media_filepath(server_name, file_id) + thumbnailer = Thumbnailer(input_path) + m_width = thumbnailer.width + m_height = thumbnailer.height + + def generate_thumbnails(): + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, m_height, self.max_image_pixels + ) + return + + scales = set() + crops = set() + for r_width, r_height, r_method, r_type in requirements: + if r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + scales.add(( + min(m_width, t_width), min(m_height, t_height), r_type, + )) + elif r_method == "crop": + crops.add((r_width, r_height, r_type)) + + for t_width, t_height, t_type in scales: + t_method = "scale" + t_path = self.filepaths.remote_media_thumbnail( + server_name, file_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + remote_thumbnails.append([ + server_name, media_id, file_id, + t_width, t_height, t_type, t_method, t_len + ]) + + for t_width, t_height, t_type in crops: + if (t_width, t_height, t_type) in scales: + # If the aspect ratio of the cropped thumbnail matches a purely + # scaled one then there is no point in calculating a separate + # thumbnail. + continue + t_method = "crop" + t_path = self.filepaths.remote_media_thumbnail( + server_name, file_id, t_width, t_height, t_type, t_method + ) + self._makedirs(t_path) + t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + remote_thumbnails.append([ + server_name, media_id, file_id, + t_width, t_height, t_type, t_method, t_len + ]) + + yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) + + for r in remote_thumbnails: + yield self.store.store_remote_media_thumbnail(*r) + + defer.returnValue({ + "width": m_width, + "height": m_height, + }) + + @defer.inlineCallbacks + def delete_old_remote_media(self, before_ts): + old_media = yield self.store.get_remote_media_before(before_ts) + + deleted = 0 + + for media in old_media: + origin = media["media_origin"] + media_id = media["media_id"] + file_id = media["filesystem_id"] + key = (origin, media_id) + + logger.info("Deleting: %r", key) + + with (yield self.remote_media_linearizer.queue(key)): + full_path = self.filepaths.remote_media_filepath(origin, file_id) + try: + os.remove(full_path) + except OSError as e: + logger.warn("Failed to remove file: %r", full_path) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( + origin, file_id + ) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + yield self.store.delete_remote_media(origin, media_id) + deleted += 1 + + defer.returnValue({"deleted": deleted}) + + class MediaRepositoryResource(Resource): """File uploading and downloading. @@ -73,8 +516,12 @@ class MediaRepositoryResource(Resource): def __init__(self, hs): Resource.__init__(self) - filepaths = MediaFilePaths(hs.config.media_store_path) - self.putChild("upload", UploadResource(hs, filepaths)) - self.putChild("download", DownloadResource(hs, filepaths)) - self.putChild("thumbnail", ThumbnailResource(hs, filepaths)) + + media_repo = hs.get_media_repository() + + self.putChild("upload", UploadResource(hs, media_repo)) + self.putChild("download", DownloadResource(hs, media_repo)) + self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) self.putChild("identicon", IdenticonResource()) + if hs.config.url_preview_enabled: + self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py new file mode 100644 index 0000000000..6a5a57102f --- /dev/null +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -0,0 +1,547 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer +from twisted.web.resource import Resource + +from synapse.api.errors import ( + SynapseError, Codes, +) +from synapse.util.stringutils import random_string +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.http.client import SpiderHttpClient +from synapse.http.server import ( + request_handler, respond_with_json_bytes +) +from synapse.util.async import ObservableDeferred +from synapse.util.stringutils import is_ascii + +import os +import re +import fnmatch +import cgi +import ujson as json +import urlparse +import itertools + +import logging +logger = logging.getLogger(__name__) + + +class PreviewUrlResource(Resource): + isLeaf = True + + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.version_string = hs.version_string + self.filepaths = media_repo.filepaths + self.max_spider_size = hs.config.max_spider_size + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.client = SpiderHttpClient(hs) + self.media_repo = media_repo + + self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist + + # simple memory cache mapping urls to OG metadata + self.cache = ExpiringCache( + cache_name="url_previews", + clock=self.clock, + # don't spider URLs more often than once an hour + expiry_ms=60 * 60 * 1000, + ) + self.cache.start() + + self.downloads = {} + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @request_handler() + @defer.inlineCallbacks + def _async_render_GET(self, request): + + # XXX: if get_user_by_req fails, what should we do in an async render? + requester = yield self.auth.get_user_by_req(request) + url = request.args.get("url")[0] + if "ts" in request.args: + ts = int(request.args.get("ts")[0]) + else: + ts = self.clock.time_msec() + + url_tuple = urlparse.urlsplit(url) + for entry in self.url_preview_url_blacklist: + match = True + for attrib in entry: + pattern = entry[attrib] + value = getattr(url_tuple, attrib) + logger.debug(( + "Matching attrib '%s' with value '%s' against" + " pattern '%s'" + ) % (attrib, value, pattern)) + + if value is None: + match = False + continue + + if pattern.startswith('^'): + if not re.match(pattern, getattr(url_tuple, attrib)): + match = False + continue + else: + if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern): + match = False + continue + if match: + logger.warn( + "URL %s blocked by url_blacklist entry %s", url, entry + ) + raise SynapseError( + 403, "URL blocked by url pattern blacklist entry", + Codes.UNKNOWN + ) + + # first check the memory cache - good to handle all the clients on this + # HS thundering away to preview the same URL at the same time. + og = self.cache.get(url) + if og: + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + return + + # then check the URL cache in the DB (which will also provide us with + # historical previews, if we have any) + cache_result = yield self.store.get_url_cache(url, ts) + if ( + cache_result and + cache_result["download_ts"] + cache_result["expires"] > ts and + cache_result["response_code"] / 100 == 2 + ): + respond_with_json_bytes( + request, 200, cache_result["og"].encode('utf-8'), + send_cors=True + ) + return + + # Ensure only one download for a given URL is active at a time + download = self.downloads.get(url) + if download is None: + download = self._download_url(url, requester.user) + download = ObservableDeferred( + download, + consumeErrors=True + ) + self.downloads[url] = download + + @download.addBoth + def callback(media_info): + del self.downloads[url] + return media_info + media_info = yield download.observe() + + # FIXME: we should probably update our cache now anyway, so that + # even if the OG calculation raises, we don't keep hammering on the + # remote server. For now, leave it uncached to aid debugging OG + # calculation problems + + logger.debug("got media_info of '%s'" % media_info) + + if _is_media(media_info['media_type']): + dims = yield self.media_repo._generate_local_thumbnails( + media_info['filesystem_id'], media_info + ) + + og = { + "og:description": media_info['download_name'], + "og:image": "mxc://%s/%s" % ( + self.server_name, media_info['filesystem_id'] + ), + "og:image:type": media_info['media_type'], + "matrix:image:size": media_info['media_length'], + } + + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % url) + + # define our OG response for this media + elif _is_html(media_info['media_type']): + # TODO: somehow stop a big HTML tree from exploding synapse's RAM + + file = open(media_info['filename']) + body = file.read() + file.close() + + # clobber the encoding from the content-type, or default to utf-8 + # XXX: this overrides any <meta/> or XML charset headers in the body + # which may pose problems, but so far seems to work okay. + match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) + encoding = match.group(1) if match else "utf-8" + + og = decode_and_calc_og(body, media_info['uri'], encoding) + + # pre-cache the image for posterity + # FIXME: it might be cleaner to use the same flow as the main /preview_url + # request itself and benefit from the same caching etc. But for now we + # just rely on the caching on the master request to speed things up. + if 'og:image' in og and og['og:image']: + image_info = yield self._download_url( + _rebase_url(og['og:image'], media_info['uri']), requester.user + ) + + if _is_media(image_info['media_type']): + # TODO: make sure we don't choke on white-on-transparent images + dims = yield self.media_repo._generate_local_thumbnails( + image_info['filesystem_id'], image_info + ) + if dims: + og["og:image:width"] = dims['width'] + og["og:image:height"] = dims['height'] + else: + logger.warn("Couldn't get dims for %s" % og["og:image"]) + + og["og:image"] = "mxc://%s/%s" % ( + self.server_name, image_info['filesystem_id'] + ) + og["og:image:type"] = image_info['media_type'] + og["matrix:image:size"] = image_info['media_length'] + else: + del og["og:image"] + else: + logger.warn("Failed to find any OG data in %s", url) + og = {} + + logger.debug("Calculated OG for %s as %s" % (url, og)) + + # store OG in ephemeral in-memory cache + self.cache[url] = og + + # store OG in history-aware DB cache + yield self.store.store_url_cache( + url, + media_info["response_code"], + media_info["etag"], + media_info["expires"], + json.dumps(og), + media_info["filesystem_id"], + media_info["created_ts"], + ) + + respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + + @defer.inlineCallbacks + def _download_url(self, url, user): + # TODO: we should probably honour robots.txt... except in practice + # we're most likely being explicitly triggered by a human rather than a + # bot, so are we really a robot? + + # XXX: horrible duplication with base_resource's _download_remote_file() + file_id = random_string(24) + + fname = self.filepaths.local_media_filepath(file_id) + self.media_repo._makedirs(fname) + + try: + with open(fname, "wb") as f: + logger.debug("Trying to get url '%s'" % url) + length, headers, uri, code = yield self.client.get_file( + url, output_stream=f, max_size=self.max_spider_size, + ) + # FIXME: pass through 404s and other error messages nicely + + media_type = headers["Content-Type"][0] + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + download_name = None + + # First check if there is a valid UTF-8 filename + download_name_utf8 = params.get("filename*", None) + if download_name_utf8: + if download_name_utf8.lower().startswith("utf-8''"): + download_name = download_name_utf8[7:] + + # If there isn't check for an ascii name. + if not download_name: + download_name_ascii = params.get("filename", None) + if download_name_ascii and is_ascii(download_name_ascii): + download_name = download_name_ascii + + if download_name: + download_name = urlparse.unquote(download_name) + try: + download_name = download_name.decode("utf-8") + except UnicodeDecodeError: + download_name = None + else: + download_name = None + + yield self.store.store_local_media( + media_id=file_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=download_name, + media_length=length, + user_id=user, + ) + + except Exception as e: + os.remove(fname) + raise SynapseError( + 500, ("Failed to download content: %s" % e), + Codes.UNKNOWN + ) + + defer.returnValue({ + "media_type": media_type, + "media_length": length, + "download_name": download_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + "filename": fname, + "uri": uri, + "response_code": code, + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + "expires": 60 * 60 * 1000, + "etag": headers["ETag"][0] if "ETag" in headers else None, + }) + + +def decode_and_calc_og(body, media_uri, request_encoding=None): + from lxml import etree + + try: + parser = etree.HTMLParser(recover=True, encoding=request_encoding) + tree = etree.fromstring(body, parser) + og = _calc_og(tree, media_uri) + except UnicodeDecodeError: + # blindly try decoding the body as utf-8, which seems to fix + # the charset mismatches on https://google.com + parser = etree.HTMLParser(recover=True, encoding=request_encoding) + tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) + og = _calc_og(tree, media_uri) + + return og + + +def _calc_og(tree, media_uri): + # suck our tree into lxml and define our OG response. + + # if we see any image URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those + # URLs to avoid DoSing the server) + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + if 'content' in tag.attrib: + og[tag.attrib['property']] = tag.attrib['content'] + + # TODO: grab article: meta tags too, e.g.: + + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + if 'og:title' not in og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + og['og:title'] = title[0].text.strip() if title else None + + if 'og:image' not in og: + # TODO: extract a favicon failing all else + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + ) + if meta_image: + og['og:image'] = _rebase_url(meta_image[0], media_uri) + else: + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted(images, key=lambda i: ( + -1 * float(i.attrib['width']) * float(i.attrib['height']) + )) + if not images: + images = tree.xpath("//img[@src]") + if images: + og['og:image'] = images[0].attrib['src'] + + if 'og:description' not in og: + meta_description = tree.xpath( + "//*/meta" + "[translate(@name, 'DESCRIPTION', 'description')='description']" + "/@content") + if meta_description: + og['og:description'] = meta_description[0] + else: + # grab any text nodes which are inside the <body/> tag... + # unless they are within an HTML5 semantic markup tag... + # <header/>, <nav/>, <aside/>, <footer/> + # ...or if they are within a <script/> or <style/> tag. + # This is a very very very coarse approximation to a plain text + # render of the page. + + # We don't just use XPATH here as that is slow on some machines. + + from lxml import etree + + TAGS_TO_REMOVE = ( + "header", "nav", "aside", "footer", "script", "style", etree.Comment + ) + + # Split all the text nodes into paragraphs (by splitting on new + # lines) + text_nodes = ( + re.sub(r'\s+', '\n', el).strip() + for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) + ) + og['og:description'] = summarize_paragraphs(text_nodes) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def _iterate_over_text(tree, *tags_to_ignore): + """Iterate over the tree returning text nodes in a depth first fashion, + skipping text nodes inside certain tags. + """ + # This is basically a stack that we extend using itertools.chain. + # This will either consist of an element to iterate over *or* a string + # to be returned. + elements = iter([tree]) + while True: + el = elements.next() + if isinstance(el, basestring): + yield el + elif el is not None and el.tag not in tags_to_ignore: + # el.text is the text before the first child, so we can immediately + # return it if the text exists. + if el.text: + yield el.text + + # We add to the stack all the elements children, interspersed with + # each child's tail text (if it exists). The tail text of a node + # is text that comes *after* the node, so we always include it even + # if we ignore the child node. + elements = itertools.chain( + itertools.chain.from_iterable( # Basically a flatmap + [child, child.tail] if child.tail else [child] + for child in el.iterchildren() + ), + elements + ) + + +def _rebase_url(url, base): + base = list(urlparse.urlparse(base)) + url = list(urlparse.urlparse(url)) + if not url[0]: # fix up schema + url[0] = base[0] or "http" + if not url[1]: # fix up hostname + url[1] = base[1] + if not url[2].startswith('/'): + url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] + return urlparse.urlunparse(url) + + +def _is_media(content_type): + if content_type.lower().startswith("image/"): + return True + + +def _is_html(content_type): + content_type = content_type.lower() + if ( + content_type.startswith("text/html") or + content_type.startswith("application/xhtml") + ): + return True + + +def summarize_paragraphs(text_nodes, min_size=200, max_size=500): + # Try to get a summary of between 200 and 500 words, respecting + # first paragraph and then word boundaries. + # TODO: Respect sentences? + + description = '' + + # Keep adding paragraphs until we get to the MIN_SIZE. + for text_node in text_nodes: + if len(description) < min_size: + text_node = re.sub(r'[\t \r\n]+', ' ', text_node) + description += text_node + '\n\n' + else: + break + + description = description.strip() + description = re.sub(r'[\t ]+', ' ', description) + description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) + + # If the concatenation of paragraphs to get above MIN_SIZE + # took us over MAX_SIZE, then we need to truncate mid paragraph + if len(description) > max_size: + new_desc = "" + + # This splits the paragraph into words, but keeping the + # (preceeding) whitespace intact so we can easily concat + # words back together. + for match in re.finditer("\s*\S+", description): + word = match.group() + + # Keep adding words while the total length is less than + # MAX_SIZE. + if len(word) + len(new_desc) < max_size: + new_desc += word + else: + # At this point the next word *will* take us over + # MAX_SIZE, but we also want to ensure that its not + # a huge word. If it is add it anyway and we'll + # truncate later. + if len(new_desc) < min_size: + new_desc += word + break + + # Double check that we're not over the limit + if len(new_desc) > max_size: + new_desc = new_desc[:max_size] + + # We always add an ellipsis because at the very least + # we chopped mid paragraph. + description = new_desc.strip() + u"…" + return description if description else None diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ab52499785..d8f54adc99 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -14,9 +14,10 @@ # limitations under the License. -from .base_resource import BaseMediaResource, parse_media_id +from ._base import parse_media_id, respond_404, respond_with_file +from twisted.web.resource import Resource from synapse.http.servlet import parse_string, parse_integer -from synapse.http.server import request_handler +from synapse.http.server import request_handler, set_cors_headers from twisted.web.server import NOT_DONE_YET from twisted.internet import defer @@ -26,16 +27,28 @@ import logging logger = logging.getLogger(__name__) -class ThumbnailResource(BaseMediaResource): +class ThumbnailResource(Resource): isLeaf = True + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.store = hs.get_datastore() + self.filepaths = media_repo.filepaths + self.media_repo = media_repo + self.dynamic_thumbnails = hs.config.dynamic_thumbnails + self.server_name = hs.hostname + self.version_string = hs.version_string + self.clock = hs.get_clock() + def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET - @request_handler + @request_handler() @defer.inlineCallbacks def _async_render_GET(self, request): + set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width") height = parse_integer(request, "height") @@ -69,9 +82,14 @@ class ThumbnailResource(BaseMediaResource): media_info = yield self.store.get_local_media(media_id) if not media_info: - self._respond_404(request) + respond_404(request) return + # if media_info["media_type"] == "image/svg+xml": + # file_path = self.filepaths.local_media_filepath(media_id) + # yield respond_with_file(request, media_info["media_type"], file_path) + # return + thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: @@ -86,7 +104,7 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.local_media_thumbnail( media_id, t_width, t_height, t_type, t_method, ) - yield self._respond_with_file(request, t_type, file_path) + yield respond_with_file(request, t_type, file_path) else: yield self._respond_default_thumbnail( @@ -100,9 +118,14 @@ class ThumbnailResource(BaseMediaResource): media_info = yield self.store.get_local_media(media_id) if not media_info: - self._respond_404(request) + respond_404(request) return + # if media_info["media_type"] == "image/svg+xml": + # file_path = self.filepaths.local_media_filepath(media_id) + # yield respond_with_file(request, media_info["media_type"], file_path) + # return + thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width @@ -114,18 +137,18 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.local_media_thumbnail( media_id, desired_width, desired_height, desired_type, desired_method, ) - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) return logger.debug("We don't have a local thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self._generate_local_exact_thumbnail( + file_path = yield self.media_repo.generate_local_exact_thumbnail( media_id, desired_width, desired_height, desired_method, desired_type ) if file_path: - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) else: yield self._respond_default_thumbnail( request, media_info, desired_width, desired_height, @@ -136,7 +159,12 @@ class ThumbnailResource(BaseMediaResource): def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, desired_width, desired_height, desired_method, desired_type): - media_info = yield self._get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media(server_name, media_id) + + # if media_info["media_type"] == "image/svg+xml": + # file_path = self.filepaths.remote_media_filepath(server_name, media_id) + # yield respond_with_file(request, media_info["media_type"], file_path) + # return thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -155,19 +183,19 @@ class ThumbnailResource(BaseMediaResource): server_name, file_id, desired_width, desired_height, desired_type, desired_method, ) - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) return logger.debug("We don't have a local thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self._generate_remote_exact_thumbnail( + file_path = yield self.media_repo.generate_remote_exact_thumbnail( server_name, file_id, media_id, desired_width, desired_height, desired_method, desired_type ) if file_path: - yield self._respond_with_file(request, desired_type, file_path) + yield respond_with_file(request, desired_type, file_path) else: yield self._respond_default_thumbnail( request, media_info, desired_width, desired_height, @@ -179,7 +207,12 @@ class ThumbnailResource(BaseMediaResource): height, method, m_type): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead. - media_info = yield self._get_remote_media(server_name, media_id) + media_info = yield self.media_repo.get_remote_media(server_name, media_id) + + # if media_info["media_type"] == "image/svg+xml": + # file_path = self.filepaths.remote_media_filepath(server_name, media_id) + # yield respond_with_file(request, media_info["media_type"], file_path) + # return thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -199,7 +232,7 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.remote_media_thumbnail( server_name, file_id, t_width, t_height, t_type, t_method, ) - yield self._respond_with_file(request, t_type, file_path, t_length) + yield respond_with_file(request, t_type, file_path, t_length) else: yield self._respond_default_thumbnail( request, media_info, width, height, method, m_type, @@ -208,6 +241,8 @@ class ThumbnailResource(BaseMediaResource): @defer.inlineCallbacks def _respond_default_thumbnail(self, request, media_info, width, height, method, m_type): + # XXX: how is this meant to work? store.get_default_thumbnails + # appears to always return [] so won't this always 404? media_type = media_info["media_type"] top_level_type = media_type.split("/")[0] sub_type = media_type.split("/")[-1].split(";")[0] @@ -223,7 +258,7 @@ class ThumbnailResource(BaseMediaResource): "_default", "_default", ) if not thumbnail_infos: - self._respond_404(request) + respond_404(request) return thumbnail_info = self._select_thumbnail( @@ -239,7 +274,7 @@ class ThumbnailResource(BaseMediaResource): file_path = self.filepaths.default_thumbnail( top_level_type, sub_type, t_width, t_height, t_type, t_method, ) - yield self.respond_with_file(request, t_type, file_path, t_length) + yield respond_with_file(request, t_type, file_path, t_length) def _select_thumbnail(self, desired_width, desired_height, desired_method, desired_type, thumbnail_infos): diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 9c7ad4ae85..b716d1d892 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -15,20 +15,34 @@ from synapse.http.server import respond_with_json, request_handler -from synapse.util.stringutils import random_string from synapse.api.errors import SynapseError from twisted.web.server import NOT_DONE_YET from twisted.internet import defer -from .base_resource import BaseMediaResource +from twisted.web.resource import Resource import logging logger = logging.getLogger(__name__) -class UploadResource(BaseMediaResource): +class UploadResource(Resource): + isLeaf = True + + def __init__(self, hs, media_repo): + Resource.__init__(self) + + self.media_repo = media_repo + self.filepaths = media_repo.filepaths + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.auth = hs.get_auth() + self.max_upload_size = hs.config.max_upload_size + self.version_string = hs.version_string + self.clock = hs.get_clock() + def render_POST(self, request): self._async_render_POST(request) return NOT_DONE_YET @@ -37,37 +51,7 @@ class UploadResource(BaseMediaResource): respond_with_json(request, 200, {}, send_cors=True) return NOT_DONE_YET - @defer.inlineCallbacks - def create_content(self, media_type, upload_name, content, content_length, - auth_user): - media_id = random_string(24) - - fname = self.filepaths.local_media_filepath(media_id) - self._makedirs(fname) - - # This shouldn't block for very long because the content will have - # already been uploaded at this point. - with open(fname, "wb") as f: - f.write(content) - - yield self.store.store_local_media( - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - ) - media_info = { - "media_type": media_type, - "media_length": content_length, - } - - yield self._generate_local_thumbnails(media_id, media_info) - - defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) - - @request_handler + @request_handler() @defer.inlineCallbacks def _async_render_POST(self, request): requester = yield self.auth.get_user_by_req(request) @@ -108,7 +92,7 @@ class UploadResource(BaseMediaResource): # disposition = headers.getRawHeaders("Content-Disposition")[0] # TODO(markjh): parse content-dispostion - content_uri = yield self.create_content( + content_uri = yield self.media_repo.create_content( media_type, upload_name, request.content.read(), content_length, requester.user ) diff --git a/synapse/server.py b/synapse/server.py index 368d615576..0bfb411269 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -19,29 +19,45 @@ # partial one for unit test mocking. # Imports required for the default HomeServer() implementation -from twisted.web.client import BrowserLikePolicyForHTTPS +import logging + from twisted.enterprise import adbapi +from twisted.web.client import BrowserLikePolicyForHTTPS +from synapse.api.auth import Auth +from synapse.api.filtering import Filtering +from synapse.api.ratelimiting import Ratelimiter +from synapse.appservice.api import ApplicationServiceApi +from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.crypto.keyring import Keyring +from synapse.events.builder import EventBuilderFactory from synapse.federation import initialize_http_replication +from synapse.federation.send_queue import FederationRemoteSendQueue +from synapse.federation.transport.client import TransportLayerClient +from synapse.federation.transaction_queue import TransactionQueue +from synapse.handlers import Handlers +from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.handlers.auth import AuthHandler +from synapse.handlers.devicemessage import DeviceMessageHandler +from synapse.handlers.device import DeviceHandler +from synapse.handlers.e2e_keys import E2eKeysHandler +from synapse.handlers.presence import PresenceHandler +from synapse.handlers.room_list import RoomListHandler +from synapse.handlers.sync import SyncHandler +from synapse.handlers.typing import TypingHandler +from synapse.handlers.events import EventHandler, EventStreamHandler +from synapse.handlers.initial_sync import InitialSyncHandler +from synapse.handlers.receipts import ReceiptsHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory +from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier -from synapse.api.auth import Auth -from synapse.handlers import Handlers +from synapse.push.pusherpool import PusherPool +from synapse.rest.media.v1.media_repository import MediaRepository from synapse.state import StateHandler from synapse.storage import DataStore +from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor -from synapse.streams.events import EventSources -from synapse.api.ratelimiting import Ratelimiter -from synapse.crypto.keyring import Keyring -from synapse.push.pusherpool import PusherPool -from synapse.events.builder import EventBuilderFactory -from synapse.api.filtering import Filtering - -from synapse.http.matrixfederationclient import MatrixFederationHttpClient - -import logging - logger = logging.getLogger(__name__) @@ -78,6 +94,20 @@ class HomeServer(object): 'auth', 'rest_servlet_factory', 'state_handler', + 'presence_handler', + 'sync_handler', + 'typing_handler', + 'room_list_handler', + 'auth_handler', + 'device_handler', + 'e2e_keys_handler', + 'event_handler', + 'event_stream_handler', + 'initial_sync_handler', + 'application_service_api', + 'application_service_scheduler', + 'application_service_handler', + 'device_message_handler', 'notifier', 'distributor', 'client_resource', @@ -97,6 +127,10 @@ class HomeServer(object): 'filtering', 'http_client_context_factory', 'simple_http_client', + 'media_repository', + 'federation_transport_client', + 'federation_sender', + 'receipts_handler', ] def __init__(self, hostname, **kwargs): @@ -164,6 +198,48 @@ class HomeServer(object): def build_state_handler(self): return StateHandler(self) + def build_presence_handler(self): + return PresenceHandler(self) + + def build_typing_handler(self): + return TypingHandler(self) + + def build_sync_handler(self): + return SyncHandler(self) + + def build_room_list_handler(self): + return RoomListHandler(self) + + def build_auth_handler(self): + return AuthHandler(self) + + def build_device_handler(self): + return DeviceHandler(self) + + def build_device_message_handler(self): + return DeviceMessageHandler(self) + + def build_e2e_keys_handler(self): + return E2eKeysHandler(self) + + def build_application_service_api(self): + return ApplicationServiceApi(self) + + def build_application_service_scheduler(self): + return ApplicationServiceScheduler(self) + + def build_application_service_handler(self): + return ApplicationServicesHandler(self) + + def build_event_handler(self): + return EventHandler(self) + + def build_event_stream_handler(self): + return EventStreamHandler(self) + + def build_initial_sync_handler(self): + return InitialSyncHandler(self) + def build_event_sources(self): return EventSources(self) @@ -193,6 +269,33 @@ class HomeServer(object): **self.db_config.get("args", {}) ) + def build_media_repository(self): + return MediaRepository(self) + + def build_federation_transport_client(self): + return TransportLayerClient(self) + + def build_federation_sender(self): + if self.should_send_federation(): + return TransactionQueue(self) + elif not self.config.worker_app: + return FederationRemoteSendQueue(self) + else: + raise Exception("Workers cannot send federation traffic") + + def build_receipts_handler(self): + return ReceiptsHandler(self) + + def remove_pusher(self, app_id, push_key, user_id): + return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) + + def should_send_federation(self): + "Should this server be sending federation traffic directly?" + return self.config.send_federation and ( + not self.config.worker_app + or self.config.worker_app == "synapse.app.federation_sender" + ) + def _make_dependency_method(depname): def _get(hs): diff --git a/synapse/server.pyi b/synapse/server.pyi new file mode 100644 index 0000000000..9570df5537 --- /dev/null +++ b/synapse/server.pyi @@ -0,0 +1,29 @@ +import synapse.api.auth +import synapse.handlers +import synapse.handlers.auth +import synapse.handlers.device +import synapse.handlers.e2e_keys +import synapse.storage +import synapse.state + +class HomeServer(object): + def get_auth(self) -> synapse.api.auth.Auth: + pass + + def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: + pass + + def get_datastore(self) -> synapse.storage.DataStore: + pass + + def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: + pass + + def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: + pass + + def get_handlers(self) -> synapse.handlers.Handlers: + pass + + def get_state_handler(self) -> synapse.state.StateHandler: + pass diff --git a/synapse/state.py b/synapse/state.py index b9a1387520..b4eca0e5d5 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,15 +18,19 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext +from synapse.util.async import Linearizer from collections import namedtuple +from frozendict import frozendict import logging import hashlib +import os logger = logging.getLogger(__name__) @@ -34,15 +38,45 @@ logger = logging.getLogger(__name__) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 1000 -EVICTION_TIMEOUT_SECONDS = 20 +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) + + +SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) +EVICTION_TIMEOUT_SECONDS = 60 * 60 + + +_NEXT_STATE_ID = 1 + + +def _gen_state_id(): + global _NEXT_STATE_ID + s = "X%d" % (_NEXT_STATE_ID,) + _NEXT_STATE_ID += 1 + return s class _StateCacheEntry(object): - def __init__(self, state, state_group, ts): - self.state = state + __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] + + def __init__(self, state, state_group, prev_group=None, delta_ids=None): + self.state = frozendict(state) self.state_group = state_group + self.prev_group = prev_group + self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None + + # The `state_id` is a unique ID we generate that can be used as ID for + # this collection of state. Usually this would be the same as the + # state group, but on worker instances we can't generate a new state + # group each time we resolve state, so we generate a separate one that + # isn't persisted and is used solely for caches. + # `state_id` is either a state_group (and so an int) or a string. This + # ensures we don't accidentally persist a state_id as a stateg_group + if state_group: + self.state_id = state_group + else: + self.state_id = _gen_state_id() + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -55,6 +89,7 @@ class StateHandler(object): # dict of set of event_ids -> _StateCacheEntry. self._state_cache = None + self.resolve_linearizer = Linearizer() def start_caching(self): logger.debug("start_caching") @@ -70,7 +105,8 @@ class StateHandler(object): self._state_cache.start() @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): + def get_current_state(self, room_id, event_type=None, state_key="", + latest_event_ids=None): """ Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. @@ -81,20 +117,38 @@ class StateHandler(object): If `event_type` is specified, then the method returns only the one event (or None) with that `event_type` and `state_key`. - :returns map from (type, state_key) to event + Returns: + map from (type, state_key) to event """ - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - cache = None - if self._state_cache is not None: - cache = self._state_cache.get(frozenset(event_ids), None) + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state - if cache: - cache.ts = self.clock.time_msec() - state = cache.state - else: - res = yield self.resolve_state_groups(room_id, event_ids) - state = res[1] + if event_type: + event_id = state.get((event_type, state_key)) + event = None + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + defer.returnValue(event) + return + + state_map = yield self.store.get_events(state.values(), get_prev_content=False) + state = { + key: state_map[e_id] for key, e_id in state.items() if e_id in state_map + } + + defer.returnValue(state) + + @defer.inlineCallbacks + def get_current_state_ids(self, room_id, event_type=None, state_key="", + latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state if event_type: defer.returnValue(state.get((event_type, state_key))) @@ -103,7 +157,17 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None, outlier=False): + def get_current_user_in_room(self, room_id, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + entry = yield self.resolve_state_groups(room_id, latest_event_ids) + joined_users = yield self.store.get_joined_users_from_state( + room_id, entry.state_id, entry.state + ) + defer.returnValue(joined_users) + + @defer.inlineCallbacks + def compute_event_context(self, event, old_state=None): """ Fills out the context with the `current state` of the graph. The `current state` here is defined to be the state of the event graph just before the event - i.e. it never includes `event` @@ -118,59 +182,88 @@ class StateHandler(object): """ context = EventContext() - if outlier: + if event.internal_metadata.is_outlier(): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - context.current_state = { - (s.type, s.state_key): s for s in old_state + context.prev_state_ids = { + (s.type, s.state_key): s.event_id for s in old_state } + if event.is_state(): + context.current_state_events = dict(context.prev_state_ids) + key = (event.type, event.state_key) + context.current_state_events[key] = event.event_id + else: + context.current_state_events = context.prev_state_ids else: - context.current_state = {} + context.current_state_ids = {} + context.prev_state_ids = {} context.prev_state_events = [] - context.state_group = None + context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: - context.current_state = { - (s.type, s.state_key): s for s in old_state + context.prev_state_ids = { + (s.type, s.state_key): s.event_id for s in old_state } - context.state_group = None + context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state: - replaces = context.current_state[key] - if replaces.event_id != event.event_id: # Paranoia check - event.unsigned["replaces_state"] = replaces.event_id + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] + if replaces != event.event_id: # Paranoia check + event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) if event.is_state(): - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], event_type=event.type, state_key=event.state_key, ) else: - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], ) - group, curr_state, prev_state = ret - - context.current_state = curr_state - context.state_group = group if not event.is_state() else None + curr_state = entry.state + context.prev_state_ids = curr_state if event.is_state(): + context.state_group = self.store.get_next_state_group() + key = (event.type, event.state_key) - if key in context.current_state: - replaces = context.current_state[key] - event.unsigned["replaces_state"] = replaces.event_id + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] + event.unsigned["replaces_state"] = replaces + + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids + if context.delta_ids is not None: + context.delta_ids = dict(context.delta_ids) + context.delta_ids[key] = event.event_id + else: + if entry.state_group is None: + entry.state_group = self.store.get_next_state_group() + entry.state_id = entry.state_group + + context.state_group = entry.state_group + context.current_state_ids = context.prev_state_ids + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids - context.prev_state_events = prev_state + context.prev_state_events = [] defer.returnValue(context) @defer.inlineCallbacks @@ -179,77 +272,118 @@ class StateHandler(object): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. - :returns a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Returns: + a Deferred tuple of (`state_group`, `state`, `prev_state`). + `state_group` is the name of a state group if one and only one is + involved. `state` is a map from (type, state_key) to event, and + `prev_state` is a list of event ids. """ logger.debug("resolve_state_groups event_ids %s", event_ids) - if self._state_cache is not None: - cache = self._state_cache.get(frozenset(event_ids), None) - if cache and cache.state_group: - cache.ts = self.clock.time_msec() - prev_state = cache.state.get((event_type, state_key), None) - if prev_state: - prev_state = prev_state.event_id - prev_states = [prev_state] - else: - prev_states = [] - defer.returnValue( - (cache.state_group, cache.state, prev_states) - ) - - state_groups = yield self.store.get_state_groups( + state_groups_ids = yield self.store.get_state_groups_ids( room_id, event_ids ) logger.debug( "resolve_state_groups state_groups %s", - state_groups.keys() + state_groups_ids.keys() ) - group_names = set(state_groups.keys()) + group_names = frozenset(state_groups_ids.keys()) if len(group_names) == 1: - name, state_list = state_groups.items().pop() - state = { - (e.type, e.state_key): e - for e in state_list - } - prev_state = state.get((event_type, state_key), None) - if prev_state: - prev_state = prev_state.event_id - prev_states = [prev_state] - else: - prev_states = [] + name, state_list = state_groups_ids.items().pop() + + defer.returnValue(_StateCacheEntry( + state=state_list, + state_group=name, + prev_group=name, + delta_ids={}, + )) + with (yield self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: - cache = _StateCacheEntry( - state=state, - state_group=name, - ts=self.clock.time_msec() - ) + cache = self._state_cache.get(group_names, None) + if cache: + defer.returnValue(cache) - self._state_cache[frozenset(event_ids)] = cache + logger.info( + "Resolving state for %s with %d groups", room_id, len(state_groups_ids) + ) - defer.returnValue((name, state, prev_states)) + state = {} + for st in state_groups_ids.values(): + for key, e_id in st.items(): + state.setdefault(key, set()).add(e_id) - new_state, prev_states = self._resolve_events( - state_groups.values(), event_type, state_key - ) + conflicted_state = { + k: list(v) + for k, v in state.items() + if len(v) > 1 + } + + if conflicted_state: + logger.info("Resolving conflicted state for %r", room_id) + state_map = yield self.store.get_events( + [e_id for st in state_groups_ids.values() for e_id in st.values()], + get_prev_content=False + ) + state_sets = [ + [state_map[e_id] for key, e_id in st.items() if e_id in state_map] + for st in state_groups_ids.values() + ] + new_state, _ = self._resolve_events( + state_sets, event_type, state_key + ) + new_state = { + key: e.event_id for key, e in new_state.items() + } + else: + new_state = { + key: e_ids.pop() for key, e_ids in state.items() + } + + state_group = None + new_state_event_ids = frozenset(new_state.values()) + for sg, events in state_groups_ids.items(): + if new_state_event_ids == frozenset(e_id for e_id in events): + state_group = sg + break + if state_group is None: + # Worker instances don't have access to this method, but we want + # to set the state_group on the main instance to increase cache + # hits. + if hasattr(self.store, "get_next_state_group"): + state_group = self.store.get_next_state_group() + + prev_group = None + delta_ids = None + for old_group, old_ids in state_groups_ids.items(): + if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): + n_delta_ids = { + k: v + for k, v in new_state.items() + if old_ids.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids - if self._state_cache is not None: cache = _StateCacheEntry( state=new_state, - state_group=None, - ts=self.clock.time_msec() + state_group=state_group, + prev_group=prev_group, + delta_ids=delta_ids, ) - self._state_cache[frozenset(event_ids)] = cache + if self._state_cache is not None: + self._state_cache[group_names] = cache - defer.returnValue((None, new_state, prev_states)) + defer.returnValue(cache) def resolve_events(self, state_sets, event): + logger.info( + "Resolving state for %s with %d groups", event.room_id, len(state_sets) + ) if event.is_state(): return self._resolve_events( state_sets, event.type, event.state_key @@ -259,52 +393,54 @@ class StateHandler(object): def _resolve_events(self, state_sets, event_type=None, state_key=""): """ - :returns a tuple (new_state, prev_states). new_state is a map - from (type, state_key) to event. prev_states is a list of event_ids. - :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str]) + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + with Measure(self.clock, "state._resolve_events"): + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e + + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + try: + resolved_state = self._resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise - new_state = unconflicted_state - new_state.update(resolved_state) + new_state = unconflicted_state + new_state.update(resolved_state) return new_state, prev_states @@ -369,7 +505,8 @@ class StateHandler(object): try: # FIXME: hs.get_auth() is bad style, but we need to do it to # get around circular deps. - self.hs.get_auth().check(event, auth_events) + # The signatures have already been checked at this point + self.hs.get_auth().check(event, auth_events, do_sig_check=False) prev_event = event except AuthError: return prev_event @@ -381,7 +518,8 @@ class StateHandler(object): try: # FIXME: hs.get_auth() is bad style, but we need to do it to # get around circular deps. - self.hs.get_auth().check(event, auth_events) + # The signatures have already been checked at this point + self.hs.get_auth().check(event, auth_events, do_sig_check=False) return event except AuthError: pass diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 250ba536ea..db146ed348 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,10 +14,12 @@ # limitations under the License. from twisted.internet import defer + +from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) -from ._base import Cache +from ._base import LoggingTransaction from .directory import DirectoryStore from .events import EventsStore from .presence import PresenceStore, UserPresenceState @@ -34,6 +36,7 @@ from .push_rule import PushRuleStore from .media_repository import MediaRepositoryStore from .rejections import RejectionsStore from .event_push_actions import EventPushActionsStore +from .deviceinbox import DeviceInboxStore from .state import StateStore from .signatures import SignatureStore @@ -44,8 +47,11 @@ from .receipts import ReceiptsStore from .search import SearchStore from .tags import TagsStore from .account_data import AccountDataStore +from .openid import OpenIdStore +from .client_ips import ClientIpStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator +from .engines import PostgresEngine from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -57,12 +63,6 @@ import logging logger = logging.getLogger(__name__) -# Number of msec of granularity to store the user IP 'last seen' time. Smaller -# times give more inserts into the database even for readonly API hits -# 120 seconds == 2 minutes -LAST_SEEN_GRANULARITY = 120 * 1000 - - class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, PresenceStore, TransactionStore, @@ -81,29 +81,25 @@ class DataStore(RoomMemberStore, RoomStore, SearchStore, TagsStore, AccountDataStore, - EventPushActionsStore + EventPushActionsStore, + OpenIdStore, + ClientIpStore, + DeviceStore, + DeviceInboxStore, ): def __init__(self, db_conn, hs): self.hs = hs + self._clock = hs.get_clock() self.database_engine = hs.database_engine - cur = db_conn.cursor() - try: - cur.execute("SELECT MIN(stream_ordering) FROM events",) - rows = cur.fetchall() - self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 - self.min_stream_token = min(self.min_stream_token, -1) - finally: - cur.close() - - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", - keylen=4, - ) - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering" + db_conn, "events", "stream_ordering", + extra_tables=[("local_invites", "stream_id")] + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")] ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" @@ -114,11 +110,17 @@ class DataStore(RoomMemberStore, RoomStore, self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_max_stream_id", "stream_id" + ) + self._public_room_id_gen = StreamIdGenerator( + db_conn, "public_room_list_stream", "stream_id" + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_stream_id_gen = ChainedIdGenerator( @@ -129,7 +131,14 @@ class DataStore(RoomMemberStore, RoomStore, extra_tables=[("deleted_pushers", "stream_id")], ) - events_max = self._stream_id_gen.get_max_token() + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = StreamIdGenerator( + db_conn, "cache_invalidation_stream", "stream_id", + ) + else: + self._cache_id_gen = None + + events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -145,18 +154,18 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token() + account_max = self._account_data_id_gen.get_current_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) - self.__presence_on_startup = self._get_active_presence(db_conn) + self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self._get_cache_dict( db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._presence_id_gen.get_max_token(), + max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, @@ -167,7 +176,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_max_token()[0], + max_value=self._push_rules_stream_id_gen.get_current_token()[0], ) self.push_rules_stream_cache = StreamChangeCache( @@ -175,45 +184,51 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=push_rules_prefill, ) - super(DataStore, self).__init__(hs) - - def take_presence_startup_info(self): - active_on_startup = self.__presence_on_startup - self.__presence_on_startup = None - return active_on_startup - - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - 100000" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - } + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + db_conn, "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + # The federation outbox and the local device inbox uses the same + # stream_id generator. + device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + db_conn, "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) - sql = self.database_engine.convert_param_style(sql) + cur = LoggingTransaction( + db_conn.cursor(), + name="_find_stream_orderings_for_times_txn", + database_engine=self.database_engine, + after_callbacks=[] + ) + self._find_stream_orderings_for_times_txn(cur) + cur.close() - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() + self.find_stream_orderings_looping_call = self._clock.looping_call( + self._find_stream_orderings_for_times, 60 * 60 * 1000 + ) - cache = { - row[0]: int(row[1]) - for row in rows - } + self._stream_order_on_start = self.get_room_max_stream_ordering() - if cache: - min_val = min(cache.values()) - else: - min_val = max_value + super(DataStore, self).__init__(hs) - return cache, min_val + def take_presence_startup_info(self): + active_on_startup = self._presence_on_startup + self._presence_on_startup = None + return active_on_startup def _get_active_presence(self, db_conn): """Fetch non-offline presence from the database so that we can register @@ -238,39 +253,6 @@ class DataStore(RoomMemberStore, RoomStore, return [UserPresenceState(**row) for row in rows] @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, ip, user_agent): - now = int(self._clock.time_msec()) - key = (user.to_string(), access_token, ip) - - try: - last_seen = self.client_ip_last_seen.get(key) - except KeyError: - last_seen = None - - # Rate-limited inserts - if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: - defer.returnValue(None) - - self.client_ip_last_seen.prefill(key, now) - - # It's safe not to lock here: a) no unique constraint, - # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely - yield self._simple_upsert( - "user_ips", - keyvalues={ - "user_id": user.to_string(), - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - }, - values={ - "last_seen": now, - }, - desc="insert_client_ip", - lock=False, - ) - - @defer.inlineCallbacks def count_daily_users(self): """ Counts the number of users who used this homeserver in the last 24 hours. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7dc67ecd57..b62c459d8b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,6 +18,8 @@ from synapse.api.errors import StoreError from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache +from synapse.util.caches import intern_dict +from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -26,6 +28,10 @@ from twisted.internet import defer import sys import time import threading +import os + + +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) logger = logging.getLogger(__name__) @@ -79,7 +85,6 @@ class LoggingTransaction(object): sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql = self.database_engine.convert_param_style(sql) - if args: try: sql_logger.debug( @@ -147,8 +152,8 @@ class SQLBaseStore(object): def __init__(self, hs): self.hs = hs - self._db_pool = hs.get_db_pool() self._clock = hs.get_clock() + self._db_pool = hs.get_db_pool() self._previous_txn_total_time = 0 self._current_txn_total_time = 0 @@ -160,10 +165,12 @@ class SQLBaseStore(object): self._txn_perf_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters() - self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, + self._get_event_cache = Cache("*getEvent*", keylen=3, max_entries=hs.config.event_cache_size) - self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000) + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + ) self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] @@ -298,13 +305,14 @@ class SQLBaseStore(object): func, *args, **kwargs ) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) - - for after_callback, after_args in after_callbacks: - after_callback(*after_args) + try: + with PreserveLoggingContext(): + result = yield self._db_pool.runWithConnection( + inner_func, *args, **kwargs + ) + finally: + for after_callback, after_args in after_callbacks: + after_callback(*after_args) defer.returnValue(result) @defer.inlineCallbacks @@ -344,7 +352,7 @@ class SQLBaseStore(object): """ col_headers = list(column[0] for column in cursor.description) results = list( - dict(zip(col_headers, row)) for row in cursor.fetchall() + intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall() ) return results @@ -446,7 +454,9 @@ class SQLBaseStore(object): keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values insertion_values (dict): key/values to use when inserting - Returns: A deferred + Returns: + Deferred(bool): True if a new entry was created, False if an + existing one was updated. """ return self.runInteraction( desc, @@ -491,6 +501,10 @@ class SQLBaseStore(object): ) txn.execute(sql, allvalues.values()) + return True + else: + return False + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to @@ -547,12 +561,17 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + else: + where = "" + sql = ( - "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" + "SELECT %(retcol)s FROM %(table)s %(where)s" ) % { "retcol": retcol, "table": table, - "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + "where": where, } txn.execute(sql, keyvalues.values()) @@ -584,10 +603,13 @@ class SQLBaseStore(object): more rows, returning the result as a list of dicts. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the rows with, - or None to not apply a WHERE clause. - retcols : list of strings giving the names of the columns to return + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( desc, @@ -602,9 +624,11 @@ class SQLBaseStore(object): Args: txn : Transaction object - table : string giving the table name - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -725,10 +749,15 @@ class SQLBaseStore(object): @staticmethod def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - update_sql = "UPDATE %s SET %s WHERE %s" % ( + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( table, ", ".join("%s = ?" % (k,) for k in updatevalues), - " AND ".join("%s = ?" % (k,) for k in keyvalues) + where, ) txn.execute( @@ -794,6 +823,11 @@ class SQLBaseStore(object): if txn.rowcount > 1: raise StoreError(500, "more than one row matched") + def _simple_delete(self, table, keyvalues, desc): + return self.runInteraction( + desc, self._simple_delete_txn, table, keyvalues + ) + @staticmethod def _simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( @@ -803,11 +837,95 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def get_next_stream_id(self): - with self._next_stream_id_lock: - i = self._next_stream_id - self._next_stream_id += 1 - return i + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, + max_value): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - 100000" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + rows = txn.fetchall() + txn.close() + + cache = { + row[0]: int(row[1]) + for row in rows + } + + if cache: + min_val = min(cache.values()) + else: + min_val = max_value + + return cache, min_val + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_after(ctx.__exit__, None, None, None) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + self._simple_insert_txn( + txn, + table="cache_invalidation_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_func.__name__, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + } + ) + + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit,)) + return txn.fetchall() + return self.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 class _RollbackButIsFineException(Exception): diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index faddefe219..3fa226e92d 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -16,6 +16,8 @@ from ._base import SQLBaseStore from twisted.internet import defer +from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks + import ujson as json import logging @@ -24,6 +26,7 @@ logger = logging.getLogger(__name__) class AccountDataStore(SQLBaseStore): + @cached() def get_account_data_for_user(self, user_id): """Get all the client account_data for a user. @@ -60,6 +63,47 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) + @cachedInlineCallbacks(num_args=2) + def get_global_account_data_by_type_for_user(self, data_type, user_id): + """ + Returns: + Deferred: A dict + """ + result = yield self._simple_select_one_onecol( + table="account_data", + keyvalues={ + "user_id": user_id, + "account_data_type": data_type, + }, + retcol="content", + desc="get_global_account_data_by_type_for_user", + allow_none=True, + ) + + if result: + defer.returnValue(json.loads(result)) + else: + defer.returnValue(None) + + @cachedList(cached_method_name="get_global_account_data_by_type_for_user", + num_args=2, list_name="user_ids", inlineCallbacks=True) + def get_global_account_data_by_type_for_users(self, data_type, user_ids): + rows = yield self._simple_select_many_batch( + table="account_data", + column="user_id", + iterable=user_ids, + keyvalues={ + "account_data_type": data_type, + }, + retcols=("user_id", "content",), + desc="get_global_account_data_by_type_for_users", + ) + + defer.returnValue({ + row["user_id"]: json.loads(row["content"]) if row["content"] else None + for row in rows + }) + def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. @@ -94,6 +138,9 @@ class AccountDataStore(SQLBaseStore): A deferred pair of lists of tuples of stream_id int, user_id string, room_id string, type string, and content string. """ + if last_room_id == current_id and last_global_id == current_id: + return defer.succeed(([], [])) + def get_updated_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type, content" @@ -193,6 +240,7 @@ class AccountDataStore(SQLBaseStore): self._account_data_stream_cache.entity_has_changed, user_id, next_id, ) + txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) self._update_max_stream_id(txn, next_id) with self._account_data_id_gen.get_next() as next_id: @@ -200,7 +248,7 @@ class AccountDataStore(SQLBaseStore): "add_room_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -232,6 +280,11 @@ class AccountDataStore(SQLBaseStore): self._account_data_stream_cache.entity_has_changed, user_id, next_id, ) + txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_global_account_data_by_type_for_user.invalidate, + (account_data_type, user_id,) + ) self._update_max_stream_id(txn, next_id) with self._account_data_id_gen.get_next() as next_id: @@ -239,7 +292,7 @@ class AccountDataStore(SQLBaseStore): "add_user_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 371600eebb..514570561f 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -13,16 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib -import yaml import simplejson as json from twisted.internet import defer from synapse.api.constants import Membership -from synapse.appservice import ApplicationService, AppServiceTransaction -from synapse.config._base import ConfigError +from synapse.appservice import AppServiceTransaction +from synapse.config.appservice import load_appservices from synapse.storage.roommember import RoomsForUser -from synapse.types import UserID from ._base import SQLBaseStore @@ -34,13 +31,21 @@ class ApplicationServiceStore(SQLBaseStore): def __init__(self, hs): super(ApplicationServiceStore, self).__init__(hs) self.hostname = hs.hostname - self.services_cache = ApplicationServiceStore.load_appservices( + self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) def get_app_services(self): - return defer.succeed(self.services_cache) + return self.services_cache + + def get_if_app_services_interested_in_user(self, user_id): + """Check if the user is one associated with an app service + """ + for service in self.services_cache: + if service.is_interested_in_user(user_id): + return True + return False def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. @@ -57,8 +62,8 @@ class ApplicationServiceStore(SQLBaseStore): """ for service in self.services_cache: if service.sender == user_id: - return defer.succeed(service) - return defer.succeed(None) + return service + return None def get_app_service_by_token(self, token): """Get the application service with the given appservice token. @@ -70,8 +75,8 @@ class ApplicationServiceStore(SQLBaseStore): """ for service in self.services_cache: if service.token == token: - return defer.succeed(service) - return defer.succeed(None) + return service + return None def get_app_service_rooms(self, service): """Get a list of RoomsForUser for this application service. @@ -144,102 +149,6 @@ class ApplicationServiceStore(SQLBaseStore): return rooms_for_user_matching_user_id - @classmethod - def _load_appservice(cls, hostname, as_info, config_filename): - required_string_fields = [ - "id", "url", "as_token", "hs_token", "sender_localpart" - ] - for field in required_string_fields: - if not isinstance(as_info.get(field), basestring): - raise KeyError("Required string field: '%s' (%s)" % ( - field, config_filename, - )) - - localpart = as_info["sender_localpart"] - if urllib.quote(localpart) != localpart: - raise ValueError( - "sender_localpart needs characters which are not URL encoded." - ) - user = UserID(localpart, hostname) - user_id = user.to_string() - - # namespace checks - if not isinstance(as_info.get("namespaces"), dict): - raise KeyError("Requires 'namespaces' object.") - for ns in ApplicationService.NS_LIST: - # specific namespaces are optional - if ns in as_info["namespaces"]: - # expect a list of dicts with exclusive and regex keys - for regex_obj in as_info["namespaces"][ns]: - if not isinstance(regex_obj, dict): - raise ValueError( - "Expected namespace entry in %s to be an object," - " but got %s", ns, regex_obj - ) - if not isinstance(regex_obj.get("regex"), basestring): - raise ValueError( - "Missing/bad type 'regex' key in %s", regex_obj - ) - if not isinstance(regex_obj.get("exclusive"), bool): - raise ValueError( - "Missing/bad type 'exclusive' key in %s", regex_obj - ) - return ApplicationService( - token=as_info["as_token"], - url=as_info["url"], - namespaces=as_info["namespaces"], - hs_token=as_info["hs_token"], - sender=user_id, - id=as_info["id"], - ) - - @classmethod - def load_appservices(cls, hostname, config_files): - """Returns a list of Application Services from the config files.""" - if not isinstance(config_files, list): - logger.warning( - "Expected %s to be a list of AS config files.", config_files - ) - return [] - - # Dicts of value -> filename - seen_as_tokens = {} - seen_ids = {} - - appservices = [] - - for config_file in config_files: - try: - with open(config_file, 'r') as f: - appservice = ApplicationServiceStore._load_appservice( - hostname, yaml.load(f), config_file - ) - if appservice.id in seen_ids: - raise ConfigError( - "Cannot reuse ID across application services: " - "%s (files: %s, %s)" % ( - appservice.id, config_file, seen_ids[appservice.id], - ) - ) - seen_ids[appservice.id] = config_file - if appservice.token in seen_as_tokens: - raise ConfigError( - "Cannot reuse as_token across application services: " - "%s (files: %s, %s)" % ( - appservice.token, - config_file, - seen_as_tokens[appservice.token], - ) - ) - seen_as_tokens[appservice.token] = config_file - logger.info("Loaded application service: %s", appservice) - appservices.append(appservice) - except Exception as e: - logger.error("Failed to load appservice from '%s'", config_file) - logger.exception(e) - raise - return appservices - class ApplicationServiceTransactionStore(SQLBaseStore): @@ -262,7 +171,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore): ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore - as_list = yield self.get_app_services() + as_list = self.get_app_services() services = [] for res in results: @@ -317,38 +226,37 @@ class ApplicationServiceTransactionStore(SQLBaseStore): Returns: AppServiceTransaction: A new transaction. """ + def _create_appservice_txn(txn): + # work out new txn id (highest txn id for this service += 1) + # The highest id may be the last one sent (in which case it is last_txn) + # or it may be the highest in the txns list (which are waiting to be/are + # being sent) + last_txn_id = self._get_last_txn(txn, service.id) + + txn.execute( + "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", + (service.id,) + ) + highest_txn_id = txn.fetchone()[0] + if highest_txn_id is None: + highest_txn_id = 0 + + new_txn_id = max(highest_txn_id, last_txn_id) + 1 + + # Insert new txn into txn table + event_ids = json.dumps([e.event_id for e in events]) + txn.execute( + "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " + "VALUES(?,?,?)", + (service.id, new_txn_id, event_ids) + ) + return AppServiceTransaction( + service=service, id=new_txn_id, events=events + ) + return self.runInteraction( "create_appservice_txn", - self._create_appservice_txn, - service, events - ) - - def _create_appservice_txn(self, txn, service, events): - # work out new txn id (highest txn id for this service += 1) - # The highest id may be the last one sent (in which case it is last_txn) - # or it may be the highest in the txns list (which are waiting to be/are - # being sent) - last_txn_id = self._get_last_txn(txn, service.id) - - txn.execute( - "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", - (service.id,) - ) - highest_txn_id = txn.fetchone()[0] - if highest_txn_id is None: - highest_txn_id = 0 - - new_txn_id = max(highest_txn_id, last_txn_id) + 1 - - # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) - txn.execute( - "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " - "VALUES(?,?,?)", - (service.id, new_txn_id, event_ids) - ) - return AppServiceTransaction( - service=service, id=new_txn_id, events=events + _create_appservice_txn, ) def complete_appservice_txn(self, txn_id, service): @@ -362,41 +270,41 @@ class ApplicationServiceTransactionStore(SQLBaseStore): A Deferred which resolves if this transaction was stored successfully. """ - return self.runInteraction( - "complete_appservice_txn", - self._complete_appservice_txn, - txn_id, service - ) - - def _complete_appservice_txn(self, txn, txn_id, service): txn_id = int(txn_id) - # Debugging query: Make sure the txn being completed is EXACTLY +1 from - # what was there before. If it isn't, we've got problems (e.g. the AS - # has probably missed some events), so whine loudly but still continue, - # since it shouldn't fail completion of the transaction. - last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: - logger.error( - "appservice: Completing a transaction which has an ID > 1 from " - "the last ID sent to this AS. We've either dropped events or " - "sent it to the AS out of order. FIX ME. last_txn=%s " - "completing_txn=%s service_id=%s", last_txn_id, txn_id, - service.id + def _complete_appservice_txn(txn): + # Debugging query: Make sure the txn being completed is EXACTLY +1 from + # what was there before. If it isn't, we've got problems (e.g. the AS + # has probably missed some events), so whine loudly but still continue, + # since it shouldn't fail completion of the transaction. + last_txn_id = self._get_last_txn(txn, service.id) + if (last_txn_id + 1) != txn_id: + logger.error( + "appservice: Completing a transaction which has an ID > 1 from " + "the last ID sent to this AS. We've either dropped events or " + "sent it to the AS out of order. FIX ME. last_txn=%s " + "completing_txn=%s service_id=%s", last_txn_id, txn_id, + service.id + ) + + # Set current txn_id for AS to 'txn_id' + self._simple_upsert_txn( + txn, "application_services_state", dict(as_id=service.id), + dict(last_txn=txn_id) ) - # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( - txn, "application_services_state", dict(as_id=service.id), - dict(last_txn=txn_id) - ) + # Delete txn + self._simple_delete_txn( + txn, "application_services_txns", + dict(txn_id=txn_id, as_id=service.id) + ) - # Delete txn - self._simple_delete_txn( - txn, "application_services_txns", - dict(txn_id=txn_id, as_id=service.id) + return self.runInteraction( + "complete_appservice_txn", + _complete_appservice_txn, ) + @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. @@ -407,32 +315,37 @@ class ApplicationServiceTransactionStore(SQLBaseStore): A Deferred which resolves to an AppServiceTransaction or None. """ - return self.runInteraction( - "get_oldest_unsent_appservice_txn", - self._get_oldest_unsent_txn, - service - ) + def _get_oldest_unsent_txn(txn): + # Monotonically increasing txn ids, so just select the smallest + # one in the txns table (we delete them when they are sent) + txn.execute( + "SELECT * FROM application_services_txns WHERE as_id=?" + " ORDER BY txn_id ASC LIMIT 1", + (service.id,) + ) + rows = self.cursor_to_dict(txn) + if not rows: + return None - def _get_oldest_unsent_txn(self, txn, service): - # Monotonically increasing txn ids, so just select the smallest - # one in the txns table (we delete them when they are sent) - txn.execute( - "SELECT * FROM application_services_txns WHERE as_id=?" - " ORDER BY txn_id ASC LIMIT 1", - (service.id,) + entry = rows[0] + + return entry + + entry = yield self.runInteraction( + "get_oldest_unsent_appservice_txn", + _get_oldest_unsent_txn, ) - rows = self.cursor_to_dict(txn) - if not rows: - return None - entry = rows[0] + if not entry: + defer.returnValue(None) event_ids = json.loads(entry["event_ids"]) - events = self._get_events_txn(txn, event_ids) - return AppServiceTransaction( + events = yield self._get_events(event_ids) + + defer.returnValue(AppServiceTransaction( service=service, id=entry["txn_id"], events=events - ) + )) def _get_last_txn(self, txn, service_id): txn.execute( @@ -444,3 +357,45 @@ class ApplicationServiceTransactionStore(SQLBaseStore): return 0 else: return int(last_txn_id[0]) # select 'last_txn' col + + def set_appservice_last_pos(self, pos): + def set_appservice_last_pos_txn(txn): + txn.execute( + "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) + ) + return self.runInteraction( + "set_appservice_last_pos", set_appservice_last_pos_txn + ) + + @defer.inlineCallbacks + def get_new_events_for_appservice(self, current_id, limit): + """Get all new evnets""" + + def get_new_events_for_appservice_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " (SELECT stream_ordering FROM appservice_stream_position)" + " < e.stream_ordering" + " AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.runInteraction( + "get_new_events_for_appservice", get_new_events_for_appservice_txn, + ) + + events = yield self._get_events(event_ids) + + defer.returnValue((upper_bound, events)) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 49904046cf..94b2bcc54a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -14,6 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore +from . import engines from twisted.internet import defer @@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore): @defer.inlineCallbacks def start_doing_background_updates(self): - while True: - if self._background_update_timer is not None: - return + assert self._background_update_timer is None, \ + "background updates already running" + + logger.info("Starting background schema updates") + while True: sleep = defer.Deferred() self._background_update_timer = self._clock.call_later( self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None @@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_timer = None try: - result = yield self.do_background_update( + result = yield self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) except: logger.exception("Error doing update") - - if result is None: - logger.info( - "No more background updates to do." - " Unscheduling background update task." - ) - return + else: + if result is None: + logger.info( + "No more background updates to do." + " Unscheduling background update task." + ) + defer.returnValue(None) @defer.inlineCallbacks - def do_background_update(self, desired_duration_ms): - """Does some amount of work on a background update + def do_next_background_update(self, desired_duration_ms): + """Does some amount of work on the next queued background update + Args: desired_duration_ms(float): How long we want to spend updating. @@ -129,17 +133,29 @@ class BackgroundUpdateStore(SQLBaseStore): updates = yield self._simple_select_list( "background_updates", keyvalues=None, - retcols=("update_name",), + retcols=("update_name", "depends_on"), ) + in_flight = set(update["update_name"] for update in updates) for update in updates: - self._background_update_queue.append(update['update_name']) + if update["depends_on"] not in in_flight: + self._background_update_queue.append(update['update_name']) if not self._background_update_queue: + # no work left to do defer.returnValue(None) + # pop from the front, and add back to the back update_name = self._background_update_queue.pop(0) self._background_update_queue.append(update_name) + res = yield self._do_background_update(update_name, desired_duration_ms) + defer.returnValue(res) + + @defer.inlineCallbacks + def _do_background_update(self, update_name, desired_duration_ms): + logger.info("Starting update batch on background update '%s'", + update_name) + update_handler = self._background_update_handlers[update_name] performance = self._background_update_performance.get(update_name) @@ -173,11 +189,12 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info( "Updating %r. Updated %r items in %rms." - " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)", + " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", update_name, items_updated, duration_ms, performance.total_items_per_ms(), performance.average_items_per_ms(), performance.total_item_count, + batch_size, ) performance.update(items_updated, duration_ms) @@ -201,6 +218,70 @@ class BackgroundUpdateStore(SQLBaseStore): """ self._background_update_handlers[update_name] = update_handler + def register_background_index_update(self, update_name, index_name, + table, columns, where_clause=None): + """Helper for store classes to do a background index addition + + To use: + + 1. use a schema delta file to add a background update. Example: + INSERT INTO background_updates (update_name, progress_json) VALUES + ('my_new_index', '{}'); + + 2. In the Store constructor, call this method + + Args: + update_name (str): update_name to register for + index_name (str): name of index to add + table (str): table to add index to + columns (list[str]): columns/expressions to include in index + """ + + # if this is postgres, we add the indexes concurrently. Otherwise + # we fall back to doing it inline + if isinstance(self.database_engine, engines.PostgresEngine): + conc = True + else: + conc = False + # We don't use partial indices on SQLite as it wasn't introduced + # until 3.8, and wheezy has 3.7 + where_clause = None + + sql = ( + "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" + " %(where_clause)s" + ) % { + "conc": "CONCURRENTLY" if conc else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + "where_clause": "WHERE " + where_clause if where_clause else "" + } + + def create_index_concurrently(conn): + conn.rollback() + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + c = conn.cursor() + c.execute(sql) + conn.set_session(autocommit=False) + + def create_index(conn): + c = conn.cursor() + c.execute(sql) + + @defer.inlineCallbacks + def updater(progress, batch_size): + logger.info("Adding index %s to %s", index_name, table) + if conc: + yield self.runWithConnection(create_index_concurrently) + else: + yield self.runWithConnection(create_index) + yield self._end_background_update(update_name) + defer.returnValue(1) + + self.register_background_update_handler(update_name, updater) + def start_background_update(self, update_name, progress): """Starts a background update running. diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py new file mode 100644 index 0000000000..71e5ea112f --- /dev/null +++ b/synapse/storage/client_ips.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from ._base import Cache +from . import background_updates + +logger = logging.getLogger(__name__) + +# Number of msec of granularity to store the user IP 'last seen' time. Smaller +# times give more inserts into the database even for readonly API hits +# 120 seconds == 2 minutes +LAST_SEEN_GRANULARITY = 120 * 1000 + + +class ClientIpStore(background_updates.BackgroundUpdateStore): + def __init__(self, hs): + self.client_ip_last_seen = Cache( + name="client_ip_last_seen", + keylen=4, + ) + + super(ClientIpStore, self).__init__(hs) + + self.register_background_index_update( + "user_ips_device_index", + index_name="user_ips_device_id", + table="user_ips", + columns=["user_id", "device_id", "last_seen"], + ) + + @defer.inlineCallbacks + def insert_client_ip(self, user, access_token, ip, user_agent, device_id): + now = int(self._clock.time_msec()) + key = (user.to_string(), access_token, ip) + + try: + last_seen = self.client_ip_last_seen.get(key) + except KeyError: + last_seen = None + + # Rate-limited inserts + if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: + defer.returnValue(None) + + self.client_ip_last_seen.prefill(key, now) + + # It's safe not to lock here: a) no unique constraint, + # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely + yield self._simple_upsert( + "user_ips", + keyvalues={ + "user_id": user.to_string(), + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": now, + }, + desc="insert_client_ip", + lock=False, + ) + + @defer.inlineCallbacks + def get_last_client_ip_by_device(self, devices): + """For each device_id listed, give the user_ip it was last seen on + + Args: + devices (iterable[(str, str)]): list of (user_id, device_id) pairs + + Returns: + defer.Deferred: resolves to a dict, where the keys + are (user_id, device_id) tuples. The values are also dicts, with + keys giving the column names + """ + + res = yield self.runInteraction( + "get_last_client_ip_by_device", + self._get_last_client_ip_by_device_txn, + retcols=( + "user_id", + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ), + devices=devices + ) + + ret = {(d["user_id"], d["device_id"]): d for d in res} + defer.returnValue(ret) + + @classmethod + def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): + where_clauses = [] + bindings = [] + for (user_id, device_id) in devices: + if device_id is None: + where_clauses.append("(user_id = ? AND device_id IS NULL)") + bindings.extend((user_id, )) + else: + where_clauses.append("(user_id = ? AND device_id = ?)") + bindings.extend((user_id, device_id)) + + inner_select = ( + "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " + "WHERE %(where)s " + "GROUP BY user_id, device_id" + ) % { + "where": " OR ".join(where_clauses), + } + + sql = ( + "SELECT %(retcols)s FROM user_ips " + "JOIN (%(inner_select)s) ips ON" + " user_ips.last_seen = ips.mls AND" + " user_ips.user_id = ips.user_id AND" + " (user_ips.device_id = ips.device_id OR" + " (user_ips.device_id IS NULL AND ips.device_id IS NULL)" + " )" + ) % { + "retcols": ",".join("user_ips." + c for c in retcols), + "inner_select": inner_select, + } + + txn.execute(sql, bindings) + return cls.cursor_to_dict(txn) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py new file mode 100644 index 0000000000..87398d60bc --- /dev/null +++ b/synapse/storage/deviceinbox.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import ujson + +from twisted.internet import defer + +from ._base import SQLBaseStore + + +logger = logging.getLogger(__name__) + + +class DeviceInboxStore(SQLBaseStore): + + @defer.inlineCallbacks + def add_messages_to_device_inbox(self, local_messages_by_user_then_device, + remote_messages_by_destination): + """Used to send messages from this server. + + Args: + sender_user_id(str): The ID of the user sending these messages. + local_messages_by_user_and_device(dict): + Dictionary of user_id to device_id to message. + remote_messages_by_destination(dict): + Dictionary of destination server_name to the EDU JSON to send. + Returns: + A deferred stream_id that resolves when the messages have been + inserted. + """ + + def add_messages_txn(txn, now_ms, stream_id): + # Add the local messages directly to the local inbox. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + # Add the remote messages to the federation outbox. + # We'll send them to a remote server when we next send a + # federation transaction to that destination. + sql = ( + "INSERT INTO device_federation_outbox" + " (destination, stream_id, queued_ts, messages_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for destination, edu in remote_messages_by_destination.items(): + edu_json = ujson.dumps(edu) + rows.append((destination, stream_id, now_ms, edu_json)) + txn.executemany(sql, rows) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.runInteraction( + "add_messages_to_device_inbox", + add_messages_txn, + now_ms, + stream_id, + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed( + user_id, stream_id + ) + for destination in remote_messages_by_destination.keys(): + self._device_federation_outbox_stream_cache.entity_has_changed( + destination, stream_id + ) + + defer.returnValue(self._device_inbox_id_gen.get_current_token()) + + @defer.inlineCallbacks + def add_messages_from_remote_to_device_inbox( + self, origin, message_id, local_messages_by_user_then_device + ): + def add_messages_txn(txn, now_ms, stream_id): + # Check if we've already inserted a matching message_id for that + # origin. This can happen if the origin doesn't receive our + # acknowledgement from the first time we received the message. + already_inserted = self._simple_select_one_txn( + txn, table="device_federation_inbox", + keyvalues={"origin": origin, "message_id": message_id}, + retcols=("message_id",), + allow_none=True, + ) + if already_inserted is not None: + return + + # Add an entry for this message_id so that we know we've processed + # it. + self._simple_insert_txn( + txn, table="device_federation_inbox", + values={ + "origin": origin, + "message_id": message_id, + "received_ts": now_ms, + }, + ) + + # Add the messages to the approriate local device inboxes so that + # they'll be sent to the devices when they next sync. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.runInteraction( + "add_messages_from_remote_to_device_inbox", + add_messages_txn, + now_ms, + stream_id, + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed( + user_id, stream_id + ) + + defer.returnValue(stream_id) + + def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, + messages_by_user_then_device): + sql = ( + "UPDATE device_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(sql, (stream_id, stream_id)) + + local_by_user_then_device = {} + for user_id, messages_by_device in messages_by_user_then_device.items(): + messages_json_for_user = {} + devices = messages_by_device.keys() + if len(devices) == 1 and devices[0] == "*": + # Handle wildcard device_ids. + sql = ( + "SELECT device_id FROM devices" + " WHERE user_id = ?" + ) + txn.execute(sql, (user_id,)) + message_json = ujson.dumps(messages_by_device["*"]) + for row in txn.fetchall(): + # Add the message for all devices for this user on this + # server. + device = row[0] + messages_json_for_user[device] = message_json + else: + if not devices: + continue + sql = ( + "SELECT device_id FROM devices" + " WHERE user_id = ? AND device_id IN (" + + ",".join("?" * len(devices)) + + ")" + ) + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. + txn.execute(sql, [user_id] + devices) + for row in txn.fetchall(): + # Only insert into the local inbox if the device exists on + # this server + device = row[0] + message_json = ujson.dumps(messages_by_device[device]) + messages_json_for_user[device] = message_json + + if messages_json_for_user: + local_by_user_then_device[user_id] = messages_json_for_user + + if not local_by_user_then_device: + return + + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in local_by_user_then_device.items(): + for device_id, message_json in messages_by_device.items(): + rows.append((user_id, device_id, stream_id, message_json)) + + txn.executemany(sql, rows) + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_stream_id + ) + if not has_changed: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + user_id, device_id, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn.fetchall(): + stream_pos = row[0] + messages.append(ujson.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn, + ) + + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + + return self.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT stream_id, user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT stream_id, destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn.fetchall()) + + return rows + + return self.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() + + def get_new_device_msgs_for_remote( + self, destination, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int): The last position of the device message stream + that the server sent up to. + current_stream_id(int): The current position of the device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + + has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( + destination, last_stream_id + ) + if not has_changed or last_stream_id == current_stream_id: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + destination, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn.fetchall(): + stream_pos = row[0] + messages.append(ujson.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_device_msgs_for_remote", + get_new_messages_for_remote_destination_txn, + ) + + def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.runInteraction( + "delete_device_msgs_for_remote", + delete_messages_for_remote_destination_txn + ) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py new file mode 100644 index 0000000000..17920d4480 --- /dev/null +++ b/synapse/storage/devices.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class DeviceStore(SQLBaseStore): + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name, + ignore_if_known=True): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device + ignore_if_known (bool): ignore integrity errors which mean the + device is already known + Returns: + defer.Deferred + Raises: + StoreError: if ignore_if_known is False and the device was already + known + """ + try: + yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=ignore_if_known, + ) + except Exception as e: + logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" + " display_name=%s(%r) failed: %s", + type(device_id).__name__, device_id, + type(user_id).__name__, user_id, + type(initial_device_display_name).__name__, + initial_device_display_name, e) + raise StoreError(500, "Problem storing device.") + + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a dict containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) + + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to delete + Returns: + defer.Deferred + """ + return self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + + def update_device(self, user_id, device_id, new_display_name=None): + """Update a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to update + new_display_name (str|None): new displayname for device; None + to leave unchanged + Raises: + StoreError: if the device is not found + Returns: + defer.Deferred + """ + updates = {} + if new_display_name is not None: + updates["display_name"] = new_display_name + if not updates: + return defer.succeed(None) + return self._simple_update_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues=updates, + desc="update_device", + ) + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """Retrieve all of a user's registered devices. + + Args: + user_id (str): + Returns: + defer.Deferred: resolves to a dict from device_id to a dict + containing "device_id", "user_id" and "display_name" for each + device. + """ + devices = yield self._simple_select_list( + table="devices", + keyvalues={"user_id": user_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user" + ) + + defer.returnValue({d["device_id"]: d for d in devices}) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 012a0b414a..9caaf81f2c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore): Returns: Deferred """ - try: - yield self._simple_insert( + def alias_txn(txn): + self._simple_insert_txn( + txn, "room_aliases", { "room_alias": room_alias.to_string(), "room_id": room_id, "creator": creator, }, - desc="create_room_alias_association", - ) - except self.database_engine.module.IntegrityError: - raise SynapseError( - 409, "Room alias %s already exists" % room_alias.to_string() ) - for server in servers: - # TODO(erikj): Fix this to bulk insert - yield self._simple_insert( - "room_alias_servers", - { + self._simple_insert_many_txn( + txn, + table="room_alias_servers", + values=[{ "room_alias": room_alias.to_string(), "server": server, - }, - desc="create_room_alias_association", + } for server in servers], ) - self.get_aliases_for_room.invalidate((room_id,)) + + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) + ) + + try: + ret = yield self.runInteraction( + "create_room_alias_association", alias_txn + ) + except self.database_engine.module.IntegrityError: + raise SynapseError( + 409, "Room alias %s already exists" % room_alias.to_string() + ) + defer.returnValue(ret) def get_room_alias_creator(self, room_alias): return self._simple_select_one_onecol( @@ -155,7 +162,7 @@ class DirectoryStore(SQLBaseStore): return room_id - @cached() + @cached(max_entries=5000) def get_aliases_for_room(self, room_id): return self._simple_select_onecol( "room_aliases", diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2e89066515..385d607056 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections + +import twisted.internet.defer from ._base import SQLBaseStore @@ -36,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore): query_list(list): List of pairs of user_ids and device_ids. Returns: Dict mapping from user-id to dict mapping from device_id to - key json byte strings. + dict containing "key_json", "device_display_name". """ - def _get_e2e_device_keys(txn): - result = {} - for user_id, device_id in query_list: - user_result = result.setdefault(user_id, {}) - keyvalues = {"user_id": user_id} - if device_id: - keyvalues["device_id"] = device_id - rows = self._simple_select_list_txn( - txn, table="e2e_device_keys_json", - keyvalues=keyvalues, - retcols=["device_id", "key_json"] - ) - for row in rows: - user_result[row["device_id"]] = row["key_json"] - return result - return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + if not query_list: + return {} + + return self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + ) + + def _get_e2e_device_keys_txn(self, txn, query_list): + query_clauses = [] + query_params = [] + + for (user_id, device_id) in query_list: + query_clause = "k.user_id = ?" + query_params.append(user_id) + + if device_id: + query_clause += " AND k.device_id = ?" + query_params.append(device_id) + + query_clauses.append(query_clause) + + sql = ( + "SELECT k.user_id, k.device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM e2e_device_keys_json k" + " LEFT JOIN devices d ON d.user_id = k.user_id" + " AND d.device_id = k.device_id" + " WHERE %s" + ) % ( + " OR ".join("(" + q + ")" for q in query_clauses) + ) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + result = collections.defaultdict(dict) + for row in rows: + result[row["user_id"]][row["device_id"]] = row + + return result def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def _add_e2e_one_time_keys(txn): @@ -123,3 +151,16 @@ class EndToEndKeyStore(SQLBaseStore): return self.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) + + @twisted.internet.defer.inlineCallbacks + def delete_e2e_keys_by_device(self, user_id, device_id): + yield self._simple_delete( + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_e2e_device_keys_by_device" + ) + yield self._simple_delete( + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_e2e_one_time_keys_by_device" + ) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a48230b93f..338b495611 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,13 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(config): - name = config.database_config["name"] +def create_engine(database_config): + name = database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module, config=config) + return engine_class(module, database_config) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a09685b4df..a6ae79dfad 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,18 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import prepare_database - from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module, config): + def __init__(self, database_module, database_config): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) - self.config = config + self.synchronous_commit = database_config.get("synchronous_commit", True) def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -43,12 +41,19 @@ class PostgresEngine(object): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) - - def prepare_database(self, db_conn): - prepare_database(db_conn, self, config=self.config) + # Asynchronous commit, don't wait for the server to call fsync before + # ending the transaction. + # https://www.postgresql.org/docs/current/static/wal-async-commit.html + if not self.synchronous_commit: + cursor = db_conn.cursor() + cursor.execute("SET synchronous_commit TO OFF") + cursor.close() def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): + # https://www.postgresql.org/docs/current/static/errcodes-appendix.html + # "40001" serialization_failure + # "40P01" deadlock_detected return error.pgcode in ["40001", "40P01"] return False diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 522b905949..755c9a1f07 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import ( - prepare_database, prepare_sqlite3_database -) +from synapse.storage.prepare_database import prepare_database import struct @@ -23,9 +21,8 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module, config): + def __init__(self, database_module, database_config): self.module = database_module - self.config = config def check_database(self, txn): pass @@ -34,13 +31,9 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - self.prepare_database(db_conn) + prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) - def prepare_database(self, db_conn): - prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): return False diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 3489315e0d..53feaa1960 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -16,6 +16,7 @@ from twisted.internet import defer from ._base import SQLBaseStore +from synapse.api.errors import StoreError from synapse.util.caches.descriptors import cached from unpaddedbase64 import encode_base64 @@ -36,6 +37,13 @@ class EventFederationStore(SQLBaseStore): and backfilling from another server respectively. """ + def __init__(self, hs): + super(EventFederationStore, self).__init__(hs) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + def get_auth_chain(self, event_ids): return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) @@ -163,6 +171,22 @@ class EventFederationStore(SQLBaseStore): room_id, ) + @defer.inlineCallbacks + def get_max_depth_of_events(self, event_ids): + sql = ( + "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" + ) % (",".join(["?"] * len(event_ids)),) + + rows = yield self._execute( + "get_max_depth_of_events", None, + sql, *event_ids + ) + + if rows: + defer.returnValue(rows[0][0]) + else: + defer.returnValue(1) + def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, @@ -254,6 +278,37 @@ class EventFederationStore(SQLBaseStore): ] ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + max_stream_ord = max( + ev.internal_metadata.stream_ordering for ev in events + ) + new_extrem = {} + for room_id in events_by_room: + event_ids = self._simple_select_onecol_txn( + txn, + table="event_forward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + ) + new_extrem[room_id] = event_ids + + self._simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_ord, + } + for room_id, extrem_evs in new_extrem.items() + for event_id in extrem_evs + ] + ) + query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" @@ -289,6 +344,75 @@ class EventFederationStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) + def get_forward_extremeties_for_room(self, room_id, stream_ordering): + # We want to make the cache more effective, so we clamp to the last + # change before the given ordering. + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + + # We don't always have a full stream_to_exterm_id table, e.g. after + # the upgrade that introduced it, so we make sure we never ask for a + # try and pin to a stream_ordering from before a restart + last_change = max(self._stream_order_on_start, last_change) + + if last_change > self.stream_ordering_month_ago: + stream_ordering = min(last_change, stream_ordering) + + return self._get_forward_extremeties_for_room(room_id, stream_ordering) + + @cached(max_entries=5000, num_args=2) + def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + """ + + if stream_ordering <= self.stream_ordering_month_ago: + raise StoreError(400, "stream_ordering too old") + + sql = (""" + SELECT event_id FROM stream_ordering_to_exterm + INNER JOIN ( + SELECT room_id, MAX(stream_ordering) AS stream_ordering + FROM stream_ordering_to_exterm + WHERE stream_ordering <= ? GROUP BY room_id + ) AS rms USING (room_id, stream_ordering) + WHERE room_id = ? + """) + + def get_forward_extremeties_for_room_txn(txn): + txn.execute(sql, (stream_ordering, room_id)) + rows = txn.fetchall() + return [event_id for event_id, in rows] + + return self.runInteraction( + "get_forward_extremeties_for_room", + get_forward_extremeties_for_room_txn + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = (""" + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """) + txn.execute( + sql, + (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) + ) + return self.runInteraction( + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn + ) + def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 5820539a92..7de3e8c58c 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -16,6 +16,8 @@ from ._base import SQLBaseStore from twisted.internet import defer from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.types import RoomStreamToken +from .stream import lower_bound import logging import ujson as json @@ -24,10 +26,32 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, hs): + self.stream_ordering_month_ago = None + super(EventPushActionsStore, self).__init__(hs) + + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ - :param event: the event set actions for - :param tuples: list of tuples of (user_id, actions) + Args: + event: the event set actions for + tuples: list of tuples of (user_id, actions) """ values = [] for uid, actions in tuples: @@ -49,7 +73,7 @@ class EventPushActionsStore(SQLBaseStore): ) self._simple_insert_many_txn(txn, "event_push_actions", values) - @cachedInlineCallbacks(num_args=3, lru=True, tree=True) + @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): @@ -68,30 +92,45 @@ class EventPushActionsStore(SQLBaseStore): stream_ordering = results[0][0] topological_ordering = results[0][1] + token = RoomStreamToken( + topological_ordering, stream_ordering + ) + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 sql = ( - "SELECT sum(notif), sum(highlight)" + "SELECT count(*)" " FROM event_push_actions ea" " WHERE" " user_id = ?" " AND room_id = ?" - " AND (" - " topological_ordering > ?" - " OR (topological_ordering = ? AND stream_ordering > ?)" - ")" - ) - txn.execute(sql, ( - user_id, room_id, - topological_ordering, topological_ordering, stream_ordering - )) + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) row = txn.fetchone() - if row: - return { - "notify_count": row[0] or 0, - "highlight_count": row[1] or 0, - } - else: - return {"notify_count": 0, "highlight_count": 0} + notify_count = row[0] if row else 0 + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return { + "notify_count": notify_count, + "highlight_count": highlight_count, + } ret = yield self.runInteraction( "get_unread_event_push_actions_by_room", @@ -99,6 +138,304 @@ class EventPushActionsStore(SQLBaseStore): ) defer.returnValue(ret) + @defer.inlineCallbacks + def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + def f(txn): + sql = ( + "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" + " stream_ordering >= ? AND stream_ordering <= ?" + ) + txn.execute(sql, (min_stream_ordering, max_stream_ordering)) + return [r[0] for r in txn.fetchall()] + ret = yield self.runInteraction("get_push_action_users_in_range", f) + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_unread_push_actions_for_user_in_range_for_http( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the httppusher. + + Args: + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. + Returns: + A promise which resolves to a list of dicts with the keys "event_id", + "room_id", "stream_ordering", "actions". + The list will be ordered by ascending stream_ordering. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the next + # push actions + def get_after_receipt(txn): + # find rooms that have a read receipt in them and return the next + # push actions + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" + " FROM (" + " SELECT room_id," + " MAX(topological_ordering) as topological_ordering," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " WHERE" + " ep.room_id = rl.room_id" + " AND (" + " ep.topological_ordering > rl.topological_ordering" + " OR (" + " ep.topological_ordering = rl.topological_ordering" + " AND ep.stream_ordering > rl.stream_ordering" + " )" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + after_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " e.received_ts" + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + no_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + ) + + notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": json.loads(row[3]), + } for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by stream_ordering, oldest first. + notifs.sort(key=lambda r: r['stream_ordering']) + + # Take only up to the limit. We have to stop at the limit because + # one of the subqueries may have hit the limit. + defer.returnValue(notifs[:limit]) + + @defer.inlineCallbacks + def get_unread_push_actions_for_user_in_range_for_email( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the emailpusher + + Args: + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. + Returns: + A promise which resolves to a list of dicts with the keys "event_id", + "room_id", "stream_ordering", "actions", "received_ts". + The list will be ordered by descending received_ts. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the most recent + # push actions + def get_after_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " e.received_ts" + " FROM (" + " SELECT room_id," + " MAX(topological_ordering) as topological_ordering," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id = rl.room_id" + " AND (" + " ep.topological_ordering > rl.topological_ordering" + " OR (" + " ep.topological_ordering = rl.topological_ordering" + " AND ep.stream_ordering > rl.stream_ordering" + " )" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + after_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " e.received_ts" + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + no_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt + ) + + # Make a list of dicts from the two sets of results. + notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": json.loads(row[3]), + "received_ts": row[4], + } for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by received_ts (most recent first) + notifs.sort(key=lambda r: -(r['received_ts'] or 0)) + + # Now return the first `limit` + defer.returnValue(notifs[:limit]) + + @defer.inlineCallbacks + def get_push_actions_for_user(self, user_id, before=None, limit=50, + only_highlight=False): + def f(txn): + before_clause = "" + if before: + before_clause = "AND epa.stream_ordering < ?" + args = [user_id, before, limit] + else: + args = [user_id, limit] + + if only_highlight: + if len(before_clause) > 0: + before_clause += " " + before_clause += "AND epa.highlight = 1" + + # NB. This assumes event_ids are globally unique since + # it makes the query easier to index + sql = ( + "SELECT epa.event_id, epa.room_id," + " epa.stream_ordering, epa.topological_ordering," + " epa.actions, epa.profile_tag, e.received_ts" + " FROM event_push_actions epa, events e" + " WHERE epa.event_id = e.event_id" + " AND epa.user_id = ? %s" + " ORDER BY epa.stream_ordering DESC" + " LIMIT ?" + % (before_clause,) + ) + txn.execute(sql, args) + return self.cursor_to_dict(txn) + + push_actions = yield self.runInteraction( + "get_push_actions_for_user", f + ) + for pa in push_actions: + pa["actions"] = json.loads(pa["actions"]) + defer.returnValue(push_actions) + + @defer.inlineCallbacks + def get_time_of_last_push_action_before(self, stream_ordering): + def f(txn): + sql = ( + "SELECT e.received_ts" + " FROM event_push_actions AS ep" + " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" + " WHERE ep.stream_ordering > ?" + " ORDER BY ep.stream_ordering ASC" + " LIMIT 1" + ) + txn.execute(sql, (stream_ordering,)) + return txn.fetchone() + result = yield self.runInteraction("get_time_of_last_push_action_before", f) + defer.returnValue(result[0] if result else None) + + @defer.inlineCallbacks + def get_latest_push_action_stream_ordering(self): + def f(txn): + txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") + return txn.fetchone() + result = yield self.runInteraction( + "get_latest_push_action_stream_ordering", f + ) + defer.returnValue(result[0] or 0) + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here txn.call_after( @@ -110,6 +447,93 @@ class EventPushActionsStore(SQLBaseStore): (room_id, event_id) ) + def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, + topological_ordering): + """ + Purges old, stale push actions for a user and room before a given + topological_ordering + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + topological_ordering: The lowest topological ordering which will + not be deleted. + """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id, ) + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " topological_ordering < ? AND stream_ordering < ?", + (user_id, room_id, topological_ordering, self.stream_ordering_month_ago) + ) + + @defer.inlineCallbacks + def _find_stream_orderings_for_times(self): + yield self.runInteraction( + "_find_stream_orderings_for_times", + self._find_stream_orderings_for_times_txn + ) + + def _find_stream_orderings_for_times_txn(self, txn): + logger.info("Searching for stream ordering 1 month ago") + self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 month ago: it's %d", + self.stream_ordering_month_ago + ) + + def _find_first_stream_ordering_after_ts_txn(self, txn, ts): + """ + Find the stream_ordering of the first event that was received after + a given timestamp. This is relatively slow as there is no index on + received_ts but we can then use this to delete push actions before + this. + + received_ts must necessarily be in the same order as stream_ordering + and stream_ordering is indexed, so we manually binary search using + stream_ordering + """ + txn.execute("SELECT MAX(stream_ordering) FROM events") + max_stream_ordering = txn.fetchone()[0] + + if max_stream_ordering is None: + return 0 + + range_start = 0 + range_end = max_stream_ordering + + sql = ( + "SELECT received_ts FROM events" + " WHERE stream_ordering > ?" + " ORDER BY stream_ordering" + " LIMIT 1" + ) + + while range_end - range_start > 1: + middle = int((range_end + range_start) / 2) + txn.execute(sql, (middle,)) + middle_ts = txn.fetchone()[0] + if ts > middle_ts: + range_start = middle + else: + range_end = middle + + return range_end + def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 552e7ca35b..ecb79c07ef 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,12 +19,22 @@ from twisted.internet import defer, reactor from synapse.events import FrozenEvent, USE_FROZEN_DICTS from synapse.events.utils import prune_event -from synapse.util.logcontext import preserve_fn, PreserveLoggingContext +from synapse.util.async import ObservableDeferred +from synapse.util.logcontext import ( + preserve_fn, PreserveLoggingContext, preserve_context_over_deferred +) from synapse.util.logutils import log_function +from synapse.util.metrics import Measure from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError from canonicaljson import encode_canonical_json -from contextlib import contextmanager +from collections import deque, namedtuple, OrderedDict +from functools import wraps + +import synapse +import synapse.metrics + import logging import math @@ -33,6 +43,10 @@ import ujson as json logger = logging.getLogger(__name__) +metrics = synapse.metrics.get_metrics_for(__name__) +persist_event_counter = metrics.register_counter("persisted_events") + + def encode_json(json_object): if USE_FROZEN_DICTS: # ujson doesn't like frozen_dicts @@ -40,6 +54,7 @@ def encode_json(json_object): else: return json.dumps(json_object, ensure_ascii=False) + # These values are used in the `enqueus_event` and `_do_fetch` methods to # control how we batch/bulk fetch events from the database. # The values are plucked out of thing air to make initial sync run faster @@ -50,37 +65,225 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events +class _EventPeristenceQueue(object): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( + "events_and_contexts", "current_state", "backfilled", "deferred", + )) + + def __init__(self): + self._event_persist_queues = {} + self._currently_persisting_rooms = set() + + def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): + """Add events to the queue, with the given persist_event options. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + if queue: + end_item = queue[-1] + if end_item.current_state or current_state: + # We perist events with current_state set to True one at a time + pass + if end_item.backfilled == backfilled: + end_item.events_and_contexts.extend(events_and_contexts) + return end_item.deferred.observe() + + deferred = ObservableDeferred(defer.Deferred()) + + queue.append(self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + current_state=current_state, + deferred=deferred, + )) + + return deferred.observe() + + def handle_queue(self, room_id, per_item_callback): + """Attempts to handle the queue for a room if not already being handled. + + The given callback will be invoked with for each item in the queue,1 + of type _EventPersistQueueItem. The per_item_callback will continuously + be called with new items, unless the queue becomnes empty. The return + value of the function will be given to the deferreds waiting on the item, + exceptions will be passed to the deferres as well. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + @defer.inlineCallbacks + def handle_queue_loop(): + try: + queue = self._get_drainining_queue(room_id) + for item in queue: + try: + ret = yield per_item_callback(item) + item.deferred.callback(ret) + except Exception as e: + item.deferred.errback(e) + finally: + queue = self._event_persist_queues.pop(room_id, None) + if queue: + self._event_persist_queues[room_id] = queue + self._currently_persisting_rooms.discard(room_id) + + preserve_fn(handle_queue_loop)() + + def _get_drainining_queue(self, room_id): + queue = self._event_persist_queues.setdefault(room_id, deque()) + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +def _retry_on_integrity_error(func): + """Wraps a database function so that it gets retried on IntegrityError, + with `delete_existing=True` passed in. + + Args: + func: function that returns a Deferred and accepts a `delete_existing` arg + """ + @wraps(func) + @defer.inlineCallbacks + def f(self, *args, **kwargs): + try: + res = yield func(self, *args, **kwargs) + except self.database_engine.module.IntegrityError: + logger.exception("IntegrityError, retrying.") + res = yield func(self, *args, delete_existing=True, **kwargs) + defer.returnValue(res) + + return f + + class EventsStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" def __init__(self, hs): super(EventsStore, self).__init__(hs) + self._clock = hs.get_clock() self.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) + self.register_background_update_handler( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + self._background_reindex_fields_sender, + ) + + self.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + + self._event_persist_queue = _EventPeristenceQueue() + + def persist_events(self, events_and_contexts, backfilled=False): + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled: ? + """ + partitioned = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + deferreds = [] + for room_id, evs_ctxs in partitioned.items(): + d = preserve_fn(self._event_persist_queue.add_to_queue)( + room_id, evs_ctxs, + backfilled=backfilled, + current_state=None, + ) + deferreds.append(d) + + for room_id in partitioned.keys(): + self._maybe_start_persisting(room_id) + + return preserve_context_over_deferred( + defer.gatherResults(deferreds, consumeErrors=True) + ) + + @defer.inlineCallbacks + @log_function + def persist_event(self, event, context, current_state=None, backfilled=False): + deferred = self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], + backfilled=backfilled, + current_state=current_state, + ) + + self._maybe_start_persisting(event.room_id) + + yield preserve_context_over_deferred(deferred) + + max_persisted_id = yield self._stream_id_gen.get_current_token() + defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) + + def _maybe_start_persisting(self, room_id): + @defer.inlineCallbacks + def persisting_queue(item): + if item.current_state: + for event, context in item.events_and_contexts: + # There should only ever be one item in + # events_and_contexts when current_state is + # not None + yield self._persist_event( + event, context, + current_state=item.current_state, + backfilled=item.backfilled, + ) + else: + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) + self._event_persist_queue.handle_queue(room_id, persisting_queue) + + @_retry_on_integrity_error @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False, - is_new_state=True): + def _persist_events(self, events_and_contexts, backfilled=False, + delete_existing=False): if not events_and_contexts: return if backfilled: - start = self.min_stream_token - 1 - self.min_stream_token -= len(events_and_contexts) + 1 - stream_orderings = range(start, self.min_stream_token, -1) - - @contextmanager - def stream_ordering_manager(): - yield stream_orderings - stream_ordering_manager = stream_ordering_manager() + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) else: stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) with stream_ordering_manager as stream_orderings: - for (event, _), stream in zip(events_and_contexts, stream_orderings): + for (event, context), stream, in zip( + events_and_contexts, stream_orderings + ): event.internal_metadata.stream_ordering = stream chunks = [ @@ -96,44 +299,31 @@ class EventsStore(SQLBaseStore): self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, - is_new_state=is_new_state, + delete_existing=delete_existing, ) + persist_event_counter.inc_by(len(chunk)) + @_retry_on_integrity_error @defer.inlineCallbacks @log_function - def persist_event(self, event, context, backfilled=False, - is_new_state=True, current_state=None): - stream_ordering = None - if backfilled: - self.min_stream_token -= 1 - stream_ordering = self.min_stream_token - - if stream_ordering is None: - stream_ordering_manager = self._stream_id_gen.get_next() - else: - @contextmanager - def stream_ordering_manager(): - yield stream_ordering - stream_ordering_manager = stream_ordering_manager() - + def _persist_event(self, event, context, current_state=None, backfilled=False, + delete_existing=False): try: - with stream_ordering_manager as stream_ordering: + with self._stream_id_gen.get_next() as stream_ordering: event.internal_metadata.stream_ordering = stream_ordering yield self.runInteraction( "persist_event", self._persist_event_txn, event=event, context=context, - backfilled=backfilled, - is_new_state=is_new_state, current_state=current_state, + backfilled=backfilled, + delete_existing=delete_existing, ) + persist_event_counter.inc() except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token() - defer.returnValue((stream_ordering, max_persisted_id)) - @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -161,21 +351,53 @@ class EventsStore(SQLBaseStore): ) if not events and not allow_none: - raise RuntimeError("Could not find event %s" % (event_id,)) + raise SynapseError(404, "Could not find event %s" % (event_id,)) defer.returnValue(events[0] if events else None) + @defer.inlineCallbacks + def get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + defer.returnValue({e.event_id: e for e in events}) + @log_function - def _persist_event_txn(self, txn, event, context, backfilled, - is_new_state=True, current_state=None): + def _persist_event_txn(self, txn, event, context, current_state, backfilled=False, + delete_existing=False): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: - txn.call_after(self.get_current_state_for_key.invalidate_all) + txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) - txn.call_after(self.get_room_name_and_aliases, event.room_id) + + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + stream_order = event.internal_metadata.stream_ordering + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": stream_order} + ) self._simple_delete_txn( txn, @@ -199,12 +421,38 @@ class EventsStore(SQLBaseStore): txn, [(event, context)], backfilled=backfilled, - is_new_state=is_new_state, + delete_existing=delete_existing, ) @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - is_new_state=True): + delete_existing=False): + """Insert some number of room events into the necessary database tables. + + Rejected events are only inserted into the events table, the events_json table, + and the rejections table. Things reading from those table will need to check + whether the event was rejected. + + If delete_existing is True then existing events will be purged from the + database before insertion. This is useful when retrying due to IntegrityError. + """ + # Ensure that we don't have the same event twice. + # Pick the earliest non-outlier if there is one, else the earliest one. + new_events_and_contexts = OrderedDict() + for event, context in events_and_contexts: + prev_event_context = new_events_and_contexts.get(event.event_id) + if prev_event_context: + if not event.internal_metadata.is_outlier(): + if prev_event_context[0].internal_metadata.is_outlier(): + # To ensure correct ordering we pop, as OrderedDict is + # ordered by first insertion. + new_events_and_contexts.pop(event.event_id, None) + new_events_and_contexts[event.event_id] = (event, context) + else: + new_events_and_contexts[event.event_id] = (event, context) + + events_and_contexts = new_events_and_contexts.values() + depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -215,21 +463,11 @@ class EventsStore(SQLBaseStore): event.room_id, event.internal_metadata.stream_ordering, ) - if not event.internal_metadata.is_outlier(): + if not event.internal_metadata.is_outlier() and not context.rejected: depth_updates[event.room_id] = max( event.depth, depth_updates.get(event.room_id, event.depth) ) - if context.push_actions: - self._set_push_actions_for_event_and_users_txn( - txn, event, context.push_actions - ) - - if event.type == EventTypes.Redaction and event.redacts is not None: - self._remove_push_actions_for_event_id_txn( - txn, event.room_id, event.redacts - ) - for room_id, depth in depth_updates.items(): self._update_min_depth_for_room_txn(txn, room_id, depth) @@ -239,30 +477,21 @@ class EventsStore(SQLBaseStore): ), [event.event_id for event, _ in events_and_contexts] ) + have_persisted = { event_id: outlier for event_id, outlier in txn.fetchall() } - event_map = {} to_remove = set() for event, context in events_and_contexts: - # Handle the case of the list including the same event multiple - # times. The tricky thing here is when they differ by whether - # they are an outlier. - if event.event_id in event_map: - other = event_map[event.event_id] - - if not other.internal_metadata.is_outlier(): + if context.rejected: + # If the event is rejected then we don't care if the event + # was an outlier or not. + if event.event_id in have_persisted: + # If we have already seen the event then ignore it. to_remove.add(event) - continue - elif not event.internal_metadata.is_outlier(): - to_remove.add(event) - continue - else: - to_remove.add(other) - - event_map[event.event_id] = event + continue if event.event_id not in have_persisted: continue @@ -271,9 +500,17 @@ class EventsStore(SQLBaseStore): outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: - self._store_state_groups_txn( - txn, event, context, - ) + # We received a copy of an event that we had already stored as + # an outlier in the database. We now have some state at that + # so we need to update the state_groups table with that state. + + # insert into the state_group, state_groups_state and + # event_to_state_groups tables. + try: + self._store_mult_state_groups_txn(txn, ((event, context),)) + except Exception: + logger.exception("") + raise metadata_json = encode_json( event.internal_metadata.get_dict() @@ -288,6 +525,20 @@ class EventsStore(SQLBaseStore): (metadata_json, event.event_id,) ) + # Add an entry to the ex_outlier_stream table to replicate the + # change in outlier status to our workers. + stream_order = event.internal_metadata.stream_ordering + state_group_id = context.state_group + self._simple_insert_txn( + txn, + table="ex_outlier_stream", + values={ + "event_stream_ordering": stream_order, + "event_id": event.event_id, + "state_group": state_group_id, + } + ) + sql = ( "UPDATE events SET outlier = ?" " WHERE event_id = ?" @@ -297,49 +548,21 @@ class EventsStore(SQLBaseStore): (False, event.event_id,) ) + # Update the event_backward_extremities table now that this + # event isn't an outlier any more. self._update_extremeties(txn, [event]) - events_and_contexts = filter( - lambda ec: ec[0] not in to_remove, - events_and_contexts - ) + events_and_contexts = [ + ec for ec in events_and_contexts if ec[0] not in to_remove + ] if not events_and_contexts: + # Make sure we don't pass an empty list to functions that expect to + # be storing at least one element. return - self._store_mult_state_groups_txn(txn, [ - (event, context) - for event, context in events_and_contexts - if not event.internal_metadata.is_outlier() - ]) - - self._handle_mult_prev_events( - txn, - events=[event for event, _ in events_and_contexts], - ) - - for event, _ in events_and_contexts: - if event.type == EventTypes.Name: - self._store_room_name_txn(txn, event) - elif event.type == EventTypes.Topic: - self._store_room_topic_txn(txn, event) - elif event.type == EventTypes.Message: - self._store_room_message_txn(txn, event) - elif event.type == EventTypes.Redaction: - self._store_redaction(txn, event) - elif event.type == EventTypes.RoomHistoryVisibility: - self._store_history_visibility_txn(txn, event) - elif event.type == EventTypes.GuestAccess: - self._store_guest_access_txn(txn, event) - - self._store_room_members_txn( - txn, - [ - event - for event, _ in events_and_contexts - if event.type == EventTypes.Member - ] - ) + # From this point onwards the events are only events that we haven't + # seen before. def event_dict(event): return { @@ -351,6 +574,43 @@ class EventsStore(SQLBaseStore): ] } + if delete_existing: + # For paranoia reasons, we go and delete all the existing entries + # for these events so we can reinsert them. + # This gets around any problems with some tables already having + # entries. + + logger.info("Deleting existing") + + for table in ( + "events", + "event_auth", + "event_json", + "event_content_hashes", + "event_destinations", + "event_edge_hashes", + "event_edges", + "event_forward_extremities", + "event_push_actions", + "event_reference_hashes", + "event_search", + "event_signatures", + "event_to_state_groups", + "guest_access", + "history_visibility", + "local_invites", + "room_names", + "state_events", + "rejections", + "redactions", + "room_memberships", + "topics" + ): + txn.executemany( + "DELETE FROM %s WHERE event_id = ?" % (table,), + [(ev.event_id,) for ev, _ in events_and_contexts] + ) + self._simple_insert_many_txn( txn, table="event_json", @@ -382,15 +642,52 @@ class EventsStore(SQLBaseStore): "outlier": event.internal_metadata.is_outlier(), "content": encode_json(event.content).decode("UTF-8"), "origin_server_ts": int(event.origin_server_ts), + "received_ts": self._clock.time_msec(), + "sender": event.sender, + "contains_url": ( + "url" in event.content + and isinstance(event.content["url"], basestring) + ), } for event, _ in events_and_contexts ], ) - if context.rejected: - self._store_rejections_txn( - txn, event.event_id, context.rejected - ) + # Remove the rejected events from the list now that we've added them + # to the events table and the events_json table. + to_remove = set() + for event, context in events_and_contexts: + if context.rejected: + # Insert the event_id into the rejections table + self._store_rejections_txn( + txn, event.event_id, context.rejected + ) + to_remove.add(event) + + events_and_contexts = [ + ec for ec in events_and_contexts if ec[0] not in to_remove + ] + + if not events_and_contexts: + # Make sure we don't pass an empty list to functions that expect to + # be storing at least one element. + return + + # From this point onwards the events are only ones that weren't rejected. + + for event, context in events_and_contexts: + # Insert all the push actions into the event_push_actions table. + if context.push_actions: + self._set_push_actions_for_event_and_users_txn( + txn, event, context.push_actions + ) + + if event.type == EventTypes.Redaction and event.redacts is not None: + # Remove the entries in the event_push_actions table for the + # redacted event. + self._remove_push_actions_for_event_id_txn( + txn, event.room_id, event.redacts + ) self._simple_insert_many_txn( txn, @@ -406,14 +703,56 @@ class EventsStore(SQLBaseStore): ], ) + # Insert into the state_groups, state_groups_state, and + # event_to_state_groups tables. + self._store_mult_state_groups_txn(txn, events_and_contexts) + + # Update the event_forward_extremities, event_backward_extremities and + # event_edges tables. + self._handle_mult_prev_events( + txn, + events=[event for event, _ in events_and_contexts], + ) + + for event, _ in events_and_contexts: + if event.type == EventTypes.Name: + # Insert into the room_names and event_search tables. + self._store_room_name_txn(txn, event) + elif event.type == EventTypes.Topic: + # Insert into the topics table and event_search table. + self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Message: + # Insert into the event_search table. + self._store_room_message_txn(txn, event) + elif event.type == EventTypes.Redaction: + # Insert into the redactions table. + self._store_redaction(txn, event) + elif event.type == EventTypes.RoomHistoryVisibility: + # Insert into the event_search table. + self._store_history_visibility_txn(txn, event) + elif event.type == EventTypes.GuestAccess: + # Insert into the event_search table. + self._store_guest_access_txn(txn, event) + + # Insert into the room_memberships table. + self._store_room_members_txn( + txn, + [ + event + for event, _ in events_and_contexts + if event.type == EventTypes.Member + ], + backfilled=backfilled, + ) + + # Insert event_reference_hashes table. self._store_event_reference_hashes_txn( txn, [event for event, _ in events_and_contexts] ) - state_events_and_contexts = filter( - lambda i: i[0].is_state(), - events_and_contexts, - ) + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] state_values = [] for event, context in state_events_and_contexts: @@ -451,35 +790,78 @@ class EventsStore(SQLBaseStore): ], ) - if is_new_state: - for event, _ in state_events_and_contexts: - if not context.rejected: - txn.call_after( - self.get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) + # Prefill the event cache + self._add_to_cache(txn, events_and_contexts) - if event.type in [EventTypes.Name, EventTypes.Aliases]: - txn.call_after( - self.get_room_name_and_aliases.invalidate, - (event.room_id,) - ) + if backfilled: + # Backfilled events come before the current state so we don't need + # to update the current state table + return - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) + for event, _ in state_events_and_contexts: + if event.internal_metadata.is_outlier(): + # Outlier events shouldn't clobber the current state. + continue + + txn.call_after( + self._get_current_state_for_key.invalidate, + (event.room_id, event.type, event.state_key,) + ) + + self._simple_upsert_txn( + txn, + "current_state_events", + keyvalues={ + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + values={ + "event_id": event.event_id, + } + ) return + def _add_to_cache(self, txn, events_and_contexts): + to_prefill = [] + + rows = [] + N = 200 + for i in range(0, len(events_and_contexts), N): + ev_map = { + e[0].event_id: e[0] + for e in events_and_contexts[i:i + N] + } + if not ev_map: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM events as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"] * len(ev_map)),) + + txn.execute(sql, ev_map.keys()) + rows = self.cursor_to_dict(txn) + for row in rows: + event = ev_map[row["event_id"]] + if not row["rejects"] and not row["redacts"]: + to_prefill.append(_EventCacheEntry( + event=event, + redacted_event=None, + )) + + def prefill(): + for cache_entry in to_prefill: + self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) + txn.call_after(prefill) + def _store_redaction(self, txn, event): # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) @@ -488,6 +870,22 @@ class EventsStore(SQLBaseStore): (event.event_id, event.redacts) ) + @defer.inlineCallbacks + def have_events_in_timeline(self, event_ids): + """Given a list of event ids, check if we have already processed and + stored them as non outliers. + """ + rows = yield self._simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) + + defer.returnValue(set(r["event_id"] for r in rows)) + def have_events(self, event_ids): """Given a list of event ids, check if we have already processed them. @@ -526,105 +924,68 @@ class EventsStore(SQLBaseStore): if not event_ids: defer.returnValue([]) - event_map = self._get_events_from_cache( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_map] - - if not missing_events_ids: - defer.returnValue([ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ]) - - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - event_map.update(missing_events) - - defer.returnValue([ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ]) + event_id_list = event_ids + event_ids = set(event_ids) - def _get_events_txn(self, txn, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not event_ids: - return [] - - event_map = self._get_events_from_cache( + event_entry_map = self._get_events_from_cache( event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) - missing_events_ids = [e for e in event_ids if e not in event_map] + missing_events_ids = [e for e in event_ids if e not in event_entry_map] - if not missing_events_ids: - return [ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ] + if missing_events_ids: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + allow_rejected=allow_rejected, + ) - missing_events = self._fetch_events_txn( - txn, - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) + event_entry_map.update(missing_events) - event_map.update(missing_events) + events = [] + for event_id in event_id_list: + entry = event_entry_map.get(event_id, None) + if not entry: + continue - return [ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ] + if allow_rejected or not entry.event.rejected_reason: + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event - def _invalidate_get_event_cache(self, event_id): - for check_redacted in (False, True): - for get_prev_content in (False, True): - self._get_event_cache.invalidate( - (event_id, check_redacted, get_prev_content) - ) + events.append(event) - def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False): + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender - events = self._get_events_txn( - txn, [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) + defer.returnValue(events) - return events[0] if events else None + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, check_redacted, get_prev_content, - allow_rejected): + def _get_events_from_cache(self, events, allow_rejected): event_map = {} for event_id in events: - try: - ret = self._get_event_cache.get( - (event_id, check_redacted, get_prev_content,) - ) + ret = self._get_event_cache.get((event_id,), None) + if not ret: + continue - if allow_rejected or not ret.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - except KeyError: - pass + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None return event_map @@ -695,8 +1056,7 @@ class EventsStore(SQLBaseStore): reactor.callFromThread(fire, event_list) @defer.inlineCallbacks - def _enqueue_events(self, events, check_redacted=True, - get_prev_content=False, allow_rejected=False): + def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -730,21 +1090,19 @@ class EventsStore(SQLBaseStore): if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] - res = yield defer.gatherResults( + res = yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self._get_event_from_row)( row["internal_metadata"], row["json"], row["redacts"], - check_redacted=check_redacted, - get_prev_content=get_prev_content, rejected_reason=row["rejects"], ) for row in rows ], consumeErrors=True - ) + )) defer.returnValue({ - e.event_id: e + e.event.event_id: e for e in res if e }) @@ -774,157 +1132,60 @@ class EventsStore(SQLBaseStore): return rows - def _fetch_events_txn(self, txn, events, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not events: - return {} - - rows = self._fetch_event_rows( - txn, events, - ) - - if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] - - res = [ - self._get_event_from_row_txn( - txn, - row["internal_metadata"], row["json"], row["redacts"], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - rejected_reason=row["rejects"], - ) - for row in rows - ] - - return { - r.event_id: r - for r in res - } - @defer.inlineCallbacks def _get_event_from_row(self, internal_metadata, js, redacted, - check_redacted=True, get_prev_content=False, rejected_reason=None): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if rejected_reason: - rejected_reason = yield self._simple_select_one_onecol( - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - desc="_get_event_from_row", - ) - - ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - if check_redacted and redacted: - ev = prune_event(ev) - - redaction_id = yield self._simple_select_one_onecol( - table="redactions", - keyvalues={"redacts": ev.event_id}, - retcol="event_id", - desc="_get_event_from_row", - ) - - ev.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = yield self.get_event( - redaction_id, - check_redacted=False, - allow_none=True, - ) - - if because: - # It's fine to do add the event directly, since get_pdu_json - # will serialise this field correctly - ev.unsigned["redacted_because"] = because - - if get_prev_content and "replaces_state" in ev.unsigned: - prev = yield self.get_event( - ev.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - ev.unsigned["prev_content"] = prev.content - ev.unsigned["prev_sender"] = prev.sender - - self._get_event_cache.prefill( - (ev.event_id, check_redacted, get_prev_content), ev - ) - - defer.returnValue(ev) - - def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, - check_redacted=True, get_prev_content=False, - rejected_reason=None): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) + with Measure(self._clock, "_get_event_from_row"): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + desc="_get_event_from_row_rejected_reason", + ) - if rejected_reason: - rejected_reason = self._simple_select_one_onecol_txn( - txn, - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", + original_ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, ) - ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - if check_redacted and redacted: - ev = prune_event(ev) + redacted_event = None + if redacted: + redacted_event = prune_event(original_ev) - redaction_id = self._simple_select_one_onecol_txn( - txn, - table="redactions", - keyvalues={"redacts": ev.event_id}, - retcol="event_id", - ) + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": redacted_event.event_id}, + retcol="event_id", + desc="_get_event_from_row_redactions", + ) - ev.unsigned["redacted_by"] = redaction_id - # Get the redaction event. + redacted_event.unsigned["redacted_by"] = redaction_id + # Get the redaction event. - because = self._get_event_txn( - txn, - redaction_id, - check_redacted=False - ) + because = yield self.get_event( + redaction_id, + check_redacted=False, + allow_none=True, + ) - if because: - ev.unsigned["redacted_because"] = because + if because: + # It's fine to do add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = because - if get_prev_content and "replaces_state" in ev.unsigned: - prev = self._get_event_txn( - txn, - ev.unsigned["replaces_state"], - get_prev_content=False, + cache_entry = _EventCacheEntry( + event=original_ev, + redacted_event=redacted_event, ) - if prev: - ev.unsigned["prev_content"] = prev.content - ev.unsigned["prev_sender"] = prev.sender - self._get_event_cache.prefill( - (ev.event_id, check_redacted, get_prev_content), ev - ) - - return ev + self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - def _parse_events_txn(self, txn, rows): - event_ids = [r["event_id"] for r in rows] - - return self._get_events_txn(txn, event_ids) + defer.returnValue(cache_entry) @defer.inlineCallbacks def count_daily_messages(self): @@ -998,16 +1259,17 @@ class EventsStore(SQLBaseStore): defer.returnValue(ret) @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): + def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) INSERT_CLUMP_SIZE = 1000 - def reindex_search_txn(txn): + def reindex_txn(txn): sql = ( - "SELECT stream_ordering, event_id FROM events" + "SELECT stream_ordering, event_id, json FROM events" + " INNER JOIN event_json USING (event_id)" " WHERE ? <= stream_ordering AND stream_ordering < ?" " ORDER BY stream_ordering DESC" " LIMIT ?" @@ -1020,28 +1282,31 @@ class EventsStore(SQLBaseStore): return 0 min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - events = self._get_events_txn(txn, event_ids) - rows = [] - for event in events: + update_rows = [] + for row in rows: try: - event_id = event.event_id - origin_server_ts = event.origin_server_ts + event_id = row[1] + event_json = json.loads(row[2]) + sender = event_json["sender"] + content = event_json["content"] + + contains_url = "url" in content + if contains_url: + contains_url &= isinstance(content["url"], basestring) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. continue - rows.append((origin_server_ts, event_id)) + update_rows.append((sender, contains_url, event_id)) sql = ( - "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" ) - for index in range(0, len(rows), INSERT_CLUMP_SIZE): - clump = rows[index:index + INSERT_CLUMP_SIZE] + for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): + clump = update_rows[index:index + INSERT_CLUMP_SIZE] txn.executemany(sql, clump) progress = { @@ -1051,12 +1316,94 @@ class EventsStore(SQLBaseStore): } self._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) result = yield self.runInteraction( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + rows_to_update = [] + + chunks = [ + event_ids[i:i + 100] + for i in xrange(0, len(event_ids), 100) + ] + for chunk in chunks: + ev_rows = self._simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ) + + for row in ev_rows: + event_id = row["event_id"] + event_json = json.loads(row["json"]) + try: + origin_server_ts = event_json["origin_server_ts"] + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows_to_update.append((origin_server_ts, event_id)) + + sql = ( + "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + ) + + for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): + clump = rows_to_update[index:index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows_to_update) + } + + self._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows_to_update) + + result = yield self.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) @@ -1067,45 +1414,324 @@ class EventsStore(SQLBaseStore): def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" - - # TODO: Fix race with the persit_event txn by using one of the - # stream id managers - return -self.min_stream_token + return -self._backfill_id_gen.get_current_token() def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): """Get all the new events that have arrived at the server either as new events or as backfilled events""" + have_backfill_events = last_backfill_id != current_backfill_id + have_forward_events = last_forward_id != current_forward_id + + if not have_backfill_events and not have_forward_events: + return defer.succeed(AllNewEventsResult([], [], [], [], [])) + def get_all_new_events_txn(txn): sql = ( - "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" " FROM events as e" " JOIN event_json as ej" " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " LEFT JOIN event_to_state_groups as eg" + " ON e.event_id = eg.event_id" " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" ) - if last_forward_id != current_forward_id: + if have_forward_events: txn.execute(sql, (last_forward_id, current_forward_id, limit)) new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT event_stream_ordering FROM current_state_resets" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering ASC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + state_resets = txn.fetchall() + + sql = ( + "SELECT event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() else: new_forward_events = [] + state_resets = [] + forward_ex_outliers = [] sql = ( - "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json," + " eg.state_group" " FROM events as e" " JOIN event_json as ej" " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " LEFT JOIN event_to_state_groups as eg" + " ON e.event_id = eg.event_id" " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" " ORDER BY e.stream_ordering DESC" " LIMIT ?" ) - if last_backfill_id != current_backfill_id: + if have_backfill_events: txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() else: new_backfill_events = [] + backward_ex_outliers = [] - return (new_forward_events, new_backfill_events) + return AllNewEventsResult( + new_forward_events, new_backfill_events, + forward_ex_outliers, backward_ex_outliers, + state_resets, + ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) + + def delete_old_state(self, room_id, topological_ordering): + return self.runInteraction( + "delete_old_state", + self._delete_old_state_txn, room_id, topological_ordering + ) + + def _delete_old_state_txn(self, txn, room_id, topological_ordering): + """Deletes old room state + """ + + # Tables that should be pruned: + # event_auth + # event_backward_extremities + # event_content_hashes + # event_destinations + # event_edge_hashes + # event_edges + # event_forward_extremities + # event_json + # event_push_actions + # event_reference_hashes + # event_search + # event_signatures + # event_to_state_groups + # events + # rejections + # room_depth + # state_groups + # state_groups_state + + # First ensure that we're not about to delete all the forward extremeties + txn.execute( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "AND e.room_id = f.room_id " + "WHERE f.room_id = ?", + (room_id,) + ) + rows = txn.fetchall() + max_depth = max(row[0] for row in rows) + + if max_depth <= topological_ordering: + # We need to ensure we don't delete all the events from the datanase + # otherwise we wouldn't be able to send any events (due to not + # having any backwards extremeties) + raise SynapseError( + 400, "topological_ordering is greater than forward extremeties" + ) + + txn.execute( + "SELECT event_id, state_key FROM events" + " LEFT JOIN state_events USING (room_id, event_id)" + " WHERE room_id = ? AND topological_ordering < ?", + (room_id, topological_ordering,) + ) + event_rows = txn.fetchall() + + for event_id, state_key in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + + # We calculate the new entries for the backward extremeties by finding + # all events that point to events that are to be purged + txn.execute( + "SELECT DISTINCT e.event_id FROM events as e" + " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id" + " INNER JOIN events as e2 ON e2.event_id = ed.event_id" + " WHERE e.room_id = ? AND e.topological_ordering < ?" + " AND e2.topological_ordering >= ?", + (room_id, topological_ordering, topological_ordering) + ) + new_backwards_extrems = txn.fetchall() + + txn.execute( + "DELETE FROM event_backward_extremities WHERE room_id = ?", + (room_id,) + ) + + # Update backward extremeties + txn.executemany( + "INSERT INTO event_backward_extremities (room_id, event_id)" + " VALUES (?, ?)", + [ + (room_id, event_id) for event_id, in new_backwards_extrems + ] + ) + + # Get all state groups that are only referenced by events that are + # to be deleted. + txn.execute( + "SELECT state_group FROM event_to_state_groups" + " INNER JOIN events USING (event_id)" + " WHERE state_group IN (" + " SELECT DISTINCT state_group FROM events" + " INNER JOIN event_to_state_groups USING (event_id)" + " WHERE room_id = ? AND topological_ordering < ?" + " )" + " GROUP BY state_group HAVING MAX(topological_ordering) < ?", + (room_id, topological_ordering, topological_ordering) + ) + + state_rows = txn.fetchall() + state_groups_to_delete = [sg for sg, in state_rows] + + # Now we get all the state groups that rely on these state groups + new_state_edges = [] + chunks = [ + state_groups_to_delete[i:i + 100] + for i in xrange(0, len(state_groups_to_delete), 100) + ] + for chunk in chunks: + rows = self._simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=chunk, + retcols=["state_group"], + keyvalues={}, + ) + new_state_edges.extend(row["state_group"] for row in rows) + + # Now we turn the state groups that reference to-be-deleted state groups + # to non delta versions. + for new_state_edge in new_state_edges: + curr_state = self._get_state_groups_from_groups_txn( + txn, [new_state_edge], types=None + ) + curr_state = curr_state[new_state_edge] + + self._simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": new_state_edge, + } + ) + + self._simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": new_state_edge, + } + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": new_state_edge, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in curr_state.items() + ], + ) + + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + state_rows + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + state_rows + ) + # Delete all non-state + txn.executemany( + "DELETE FROM event_to_state_groups WHERE event_id = ?", + [(event_id,) for event_id, _ in event_rows] + ) + + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (topological_ordering, room_id,) + ) + + # Delete all remote non-state events + to_delete = [ + (event_id,) for event_id, state_key in event_rows + if state_key is None and not self.hs.is_mine_id(event_id) + ] + for table in ( + "events", + "event_json", + "event_auth", + "event_content_hashes", + "event_destinations", + "event_edge_hashes", + "event_edges", + "event_forward_extremities", + "event_push_actions", + "event_reference_hashes", + "event_search", + "event_signatures", + "rejections", + ): + txn.executemany( + "DELETE FROM %s WHERE event_id = ?" % (table,), + to_delete + ) + + txn.executemany( + "DELETE FROM events WHERE event_id = ?", + to_delete + ) + # Mark all state and own events as outliers + txn.executemany( + "UPDATE events SET outlier = ?" + " WHERE event_id = ?", + [ + (True, event_id,) for event_id, state_key in event_rows + if state_key is not None or self.hs.is_mine_id(event_id) + ] + ) + + +AllNewEventsResult = namedtuple("AllNewEventsResult", [ + "new_forward_events", "new_backfill_events", + "forward_ex_outliers", "backward_ex_outliers", + "state_resets" +]) diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 5248736816..a2ccc66ea7 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -16,6 +16,7 @@ from twisted.internet import defer from ._base import SQLBaseStore +from synapse.api.errors import SynapseError, Codes from synapse.util.caches.descriptors import cachedInlineCallbacks import simplejson as json @@ -24,6 +25,13 @@ import simplejson as json class FilteringStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_user_filter(self, user_localpart, filter_id): + # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail + # with a coherent error message rather than 500 M_UNKNOWN. + try: + int(filter_id) + except ValueError: + raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) + def_json = yield self._simple_select_one_onecol( table="user_filters", keyvalues={ diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index a495a8a7d9..86b37b9ddd 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -22,6 +22,10 @@ import OpenSSL from signedjson.key import decode_verify_key_bytes import hashlib +import logging + +logger = logging.getLogger(__name__) + class KeyStore(SQLBaseStore): """Persistence for signature verification keys and tls X.509 certificates @@ -74,22 +78,22 @@ class KeyStore(SQLBaseStore): ) @cachedInlineCallbacks() - def get_all_server_verify_keys(self, server_name): - rows = yield self._simple_select_list( + def _get_server_verify_key(self, server_name, key_id): + verify_key_bytes = yield self._simple_select_one_onecol( table="server_signature_keys", keyvalues={ "server_name": server_name, + "key_id": key_id, }, - retcols=["key_id", "verify_key"], - desc="get_all_server_verify_keys", + retcol="verify_key", + desc="_get_server_verify_key", + allow_none=True, ) - defer.returnValue({ - row["key_id"]: decode_verify_key_bytes( - row["key_id"], str(row["verify_key"]) - ) - for row in rows - }) + if verify_key_bytes: + defer.returnValue(decode_verify_key_bytes( + key_id, str(verify_key_bytes) + )) @defer.inlineCallbacks def get_server_verify_keys(self, server_name, key_ids): @@ -101,12 +105,12 @@ class KeyStore(SQLBaseStore): Returns: (list of VerifyKey): The verification keys. """ - keys = yield self.get_all_server_verify_keys(server_name) - defer.returnValue({ - k: keys[k] - for k in key_ids - if k in keys and keys[k] - }) + keys = {} + for key_id in key_ids: + key = yield self._get_server_verify_key(server_name, key_id) + if key: + keys[key_id] = key + defer.returnValue(keys) @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, @@ -133,8 +137,6 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) - self.get_all_server_verify_keys.invalidate((server_name,)) - def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): """Stores the JSON bytes for a set of keys from a server diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 9d3ba32478..4c0f82353d 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -25,7 +25,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: - None if the meia_id doesn't exist. + None if the media_id doesn't exist. """ return self._simple_select_one( "local_media_repository", @@ -50,6 +50,61 @@ class MediaRepositoryStore(SQLBaseStore): desc="store_local_media", ) + def get_url_cache(self, url, ts): + """Get the media_id and ts for a cached URL as of the given timestamp + Returns: + None if the URL isn't cached. + """ + def get_url_cache_txn(txn): + # get the most recently cached result (relative to the given ts) + sql = ( + "SELECT response_code, etag, expires, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts <= ?" + " ORDER BY download_ts DESC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + # ...or if we've requested a timestamp older than the oldest + # copy in the cache, return the oldest copy (if any) + sql = ( + "SELECT response_code, etag, expires, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts > ?" + " ORDER BY download_ts ASC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + return None + + return dict(zip(( + 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + ), row)) + + return self.runInteraction( + "get_url_cache", get_url_cache_txn + ) + + def store_url_cache(self, url, response_code, etag, expires, og, media_id, + download_ts): + return self._simple_insert( + "local_media_repository_url_cache", + { + "url": url, + "response_code": response_code, + "etag": etag, + "expires": expires, + "og": og, + "media_id": media_id, + "download_ts": download_ts, + }, + desc="store_url_cache", + ) + def get_local_media_thumbnails(self, media_id): return self._simple_select_list( "local_media_repository_thumbnails", @@ -102,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore): "created_ts": time_now_ms, "upload_name": upload_name, "filesystem_id": filesystem_id, + "last_access_ts": time_now_ms, }, desc="store_cached_remote_media", ) + def update_cached_last_access_time(self, origin_id_tuples, time_ts): + def update_cache_txn(txn): + sql = ( + "UPDATE remote_media_cache SET last_access_ts = ?" + " WHERE media_origin = ? AND media_id = ?" + ) + + txn.executemany(sql, ( + (time_ts, media_origin, media_id) + for media_origin, media_id in origin_id_tuples + )) + + return self.runInteraction("update_cached_last_access_time", update_cache_txn) + def get_remote_media_thumbnails(self, origin, media_id): return self._simple_select_list( "remote_media_cache_thumbnails", @@ -135,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore): }, desc="store_remote_media_thumbnail", ) + + def get_remote_media_before(self, before_ts): + sql = ( + "SELECT media_origin, media_id, filesystem_id" + " FROM remote_media_cache" + " WHERE last_access_ts < ?" + ) + + return self._execute( + "get_remote_media_before", self.cursor_to_dict, sql, before_ts + ) + + def delete_remote_media(self, media_origin, media_id): + def delete_remote_media_txn(txn): + self._simple_delete_txn( + txn, + "remote_media_cache", + keyvalues={ + "media_origin": media_origin, "media_id": media_id + }, + ) + self._simple_delete_txn( + txn, + "remote_media_cache_thumbnails", + keyvalues={ + "media_origin": media_origin, "media_id": media_id + }, + ) + return self.runInteraction("delete_remote_media", delete_remote_media_txn) diff --git a/synapse/storage/openid.py b/synapse/storage/openid.py new file mode 100644 index 0000000000..5dabb607bd --- /dev/null +++ b/synapse/storage/openid.py @@ -0,0 +1,32 @@ +from ._base import SQLBaseStore + + +class OpenIdStore(SQLBaseStore): + def insert_open_id_token(self, token, ts_valid_until_ms, user_id): + return self._simple_insert( + table="open_id_tokens", + values={ + "token": token, + "ts_valid_until_ms": ts_valid_until_ms, + "user_id": user_id, + }, + desc="insert_open_id_token" + ) + + def get_user_id_for_open_id_token(self, token, ts_now_ms): + def get_user_id_for_token_txn(txn): + sql = ( + "SELECT user_id FROM open_id_tokens" + " WHERE token = ? AND ? <= ts_valid_until_ms" + ) + + txn.execute(sql, (token, ts_now_ms)) + + rows = txn.fetchall() + if not rows: + return None + else: + return rows[0][0] + return self.runInteraction( + "get_user_id_for_token", get_user_id_for_token_txn + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 3f29aad1e8..e46ae6502e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,23 +25,11 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 30 +SCHEMA_VERSION = 39 dir_path = os.path.abspath(os.path.dirname(__file__)) -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - class PrepareDatabaseException(Exception): pass @@ -53,6 +41,9 @@ class UpgradeDatabaseException(PrepareDatabaseException): def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. + + If `config` is None then prepare_database will assert that no upgrade is + necessary, *or* will create a fresh database if the database is empty. """ try: cur = db_conn.cursor() @@ -60,13 +51,18 @@ def prepare_database(db_conn, database_engine, config): if version_info: user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine, config - ) - else: - _setup_new_database(cur, database_engine, config) - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + if config is None: + if user_version != SCHEMA_VERSION: + # If we don't pass in a config file then we are expecting to + # have already upgraded the DB. + raise UpgradeDatabaseException("Database needs to be upgraded") + else: + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine, config + ) + else: + _setup_new_database(cur, database_engine) cur.close() db_conn.commit() @@ -75,7 +71,7 @@ def prepare_database(db_conn, database_engine, config): raise -def _setup_new_database(cur, database_engine, config): +def _setup_new_database(cur, database_engine): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,12 +144,13 @@ def _setup_new_database(cur, database_engine, config): applied_delta_files=[], upgraded=False, database_engine=database_engine, - config=config, + config=None, + is_empty=True, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine, config): + upgraded, database_engine, config, is_empty=False): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -245,8 +242,10 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module = imp.load_source( module_name, absolute_path, python_file ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine, config=config) + logger.info("Running script %s", relative_path) + module.run_create(cur, database_engine) + if not is_empty: + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package @@ -254,7 +253,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, pass elif ext == ".sql": # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) + logger.info("Applying schema %s", relative_path) executescript(cur, absolute_path) else: # Not a valid delta file. @@ -361,36 +360,3 @@ def _get_or_create_schema_state(txn, database_engine): return current_version, applied_deltas, upgraded return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4cec31e316..7460f98a1f 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -37,6 +37,13 @@ class UserPresenceState(namedtuple("UserPresenceState", status_msg (str): User set status message. """ + def as_dict(self): + return dict(self._asdict()) + + @staticmethod + def from_dict(d): + return UserPresenceState(**d) + def copy_and_replace(self, **kwargs): return self._replace(**kwargs) @@ -68,7 +75,9 @@ class PresenceStore(SQLBaseStore): self._update_presence_txn, stream_orderings, presence_states, ) - defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + defer.returnValue(( + stream_orderings[-1], self._presence_id_gen.get_current_token() + )) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): @@ -116,6 +125,9 @@ class PresenceStore(SQLBaseStore): ) def get_all_presence_updates(self, last_id, current_id): + if last_id == current_id: + return defer.succeed([]) + def get_all_presence_updates_txn(txn): sql = ( "SELECT stream_id, user_id, state, last_active_ts," @@ -147,6 +159,7 @@ class PresenceStore(SQLBaseStore): "status_msg", "currently_active", ), + desc="get_presence_for_users", ) for row in rows: @@ -155,7 +168,7 @@ class PresenceStore(SQLBaseStore): defer.returnValue([UserPresenceState(**row) for row in rows]) def get_current_presence_token(self): - return self._presence_id_gen.get_max_token() + return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( @@ -174,16 +187,6 @@ class PresenceStore(SQLBaseStore): desc="disallow_presence_visible", ) - def is_presence_visible(self, observed_localpart, observer_userid): - return self._simple_select_one( - table="presence_allow_inbound", - keyvalues={"observed_user_id": observed_localpart, - "observer_user_id": observer_userid}, - retcols=["observed_user_id"], - allow_none=True, - desc="is_presence_visible", - ) - def add_presence_list_pending(self, observer_localpart, observed_userid): return self._simple_insert( table="presence_list", @@ -193,18 +196,30 @@ class PresenceStore(SQLBaseStore): desc="add_presence_list_pending", ) - @defer.inlineCallbacks def set_presence_list_accepted(self, observer_localpart, observed_userid): - result = yield self._simple_update_one( - table="presence_list", - keyvalues={"user_id": observer_localpart, - "observed_user_id": observed_userid}, - updatevalues={"accepted": True}, - desc="set_presence_list_accepted", + def update_presence_list_txn(txn): + result = self._simple_update_one_txn( + txn, + table="presence_list", + keyvalues={ + "user_id": observer_localpart, + "observed_user_id": observed_userid + }, + updatevalues={"accepted": True}, + ) + + self._invalidate_cache_and_stream( + txn, self.get_presence_list_accepted, (observer_localpart,) + ) + self._invalidate_cache_and_stream( + txn, self.get_presence_list_observers_accepted, (observed_userid,) + ) + + return result + + return self.runInteraction( + "set_presence_list_accepted", update_presence_list_txn, ) - self.get_presence_list_accepted.invalidate((observer_localpart,)) - self.get_presence_list_observers_accepted.invalidate((observed_userid,)) - defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): if accepted: diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9dbad2fd5f..cbec255966 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -14,7 +14,8 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.push.baserules import list_with_base_rules from twisted.internet import defer import logging @@ -23,6 +24,29 @@ import simplejson as json logger = logging.getLogger(__name__) +def _load_rules(rawrules, enabled_map): + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule['rule_id'] + if rule_id in enabled_map: + if rule.get('enabled', True) != bool(enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule['enabled'] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + class PushRuleStore(SQLBaseStore): @cachedInlineCallbacks() def get_push_rules_for_user(self, user_id): @@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore): key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) ) - defer.returnValue(rows) + enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + + rules = _load_rules(rows, enabled_map) + + defer.returnValue(rules) @cachedInlineCallbacks() def get_push_rules_enabled_for_user(self, user_id): @@ -60,12 +88,16 @@ class PushRuleStore(SQLBaseStore): r['rule_id']: False if r['enabled'] == 0 else True for r in results }) - @defer.inlineCallbacks + @cachedList(cached_method_name="get_push_rules_for_user", + list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): if not user_ids: defer.returnValue({}) - results = {} + results = { + user_id: [] + for user_id in user_ids + } rows = yield self._simple_select_many_batch( table="push_rules", @@ -75,18 +107,101 @@ class PushRuleStore(SQLBaseStore): desc="bulk_get_push_rules", ) - rows.sort(key=lambda e: (-e["priority_class"], -e["priority"])) + rows.sort( + key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) + ) for row in rows: results.setdefault(row['user_name'], []).append(row) + + enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + + for user_id, rules in results.items(): + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}) + ) + defer.returnValue(results) - @defer.inlineCallbacks + def bulk_get_push_rules_for_room(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._bulk_get_push_rules_for_room( + event.room_id, state_group, context.current_state_ids, event=event + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, + cache_context, event=None): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + # We also will want to generate notifs for other people in the room so + # their unread countss are correct in the event stream, but to avoid + # generating them for bot / AS users etc, we only do so for people who've + # sent a read receipt into the room. + + users_in_room = yield self._get_joined_users_from_context( + room_id, state_group, current_state_ids, + on_invalidate=cache_context.invalidate, + event=event, + ) + + # We ignore app service users for now. This is so that we don't fill + # up the `get_if_users_have_pushers` cache with AS entries that we + # know don't have pushers, nor even read receipts. + local_users_in_room = set( + u for u in users_in_room + if self.hs.is_mine_id(u) + and not self.get_if_app_services_interested_in_user(u) + ) + + # users in the room who have pushers need to get push rules run because + # that's how their pushers work + if_users_with_pushers = yield self.get_if_users_have_pushers( + local_users_in_room, + on_invalidate=cache_context.invalidate, + ) + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + ) + + users_with_receipts = yield self.get_users_with_read_receipts_in_room( + room_id, on_invalidate=cache_context.invalidate, + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in local_users_in_room: + user_ids.add(uid) + + rules_by_user = yield self.bulk_get_push_rules( + user_ids, on_invalidate=cache_context.invalidate, + ) + + rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} + + defer.returnValue(rules_by_user) + + @cachedList(cached_method_name="get_push_rules_enabled_for_user", + list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: defer.returnValue({}) - results = {} + results = { + user_id: {} + for user_id in user_ids + } rows = yield self._simple_select_many_batch( table="push_rules_enable", @@ -96,7 +211,8 @@ class PushRuleStore(SQLBaseStore): desc="bulk_get_push_rules_enabled", ) for row in rows: - results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] + enabled = bool(row['enabled']) + results.setdefault(row['user_name'], {})[row['rule_id']] = enabled defer.returnValue(results) @defer.inlineCallbacks @@ -374,6 +490,9 @@ class PushRuleStore(SQLBaseStore): def get_all_push_rule_updates(self, last_id, current_id, limit): """Get all the push rules changes that have happend on the server""" + if last_id == current_id: + return defer.succeed([]) + def get_all_push_rule_updates_txn(txn): sql = ( "SELECT stream_id, event_stream_ordering, user_id, rule_id," @@ -392,7 +511,7 @@ class PushRuleStore(SQLBaseStore): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_max_token() + return self._push_rules_stream_id_gen.get_current_token() def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 87b2ac5773..8cc9f0353b 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -18,6 +18,8 @@ from twisted.internet import defer from canonicaljson import encode_canonical_json +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList + import logging import simplejson as json import types @@ -48,23 +50,46 @@ class PusherStore(SQLBaseStore): return rows @defer.inlineCallbacks - def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - def r(txn): - sql = ( - "SELECT * FROM pushers" - " WHERE app_id = ? AND pushkey = ?" - ) + def user_has_pusher(self, user_id): + ret = yield self._simple_select_one_onecol( + "pushers", {"user_name": user_id}, "id", allow_none=True + ) + defer.returnValue(ret is not None) - txn.execute(sql, (app_id, pushkey,)) - rows = self.cursor_to_dict(txn) + def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): + return self.get_pushers_by({ + "app_id": app_id, + "pushkey": pushkey, + }) - return self._decode_pushers_rows(rows) + def get_pushers_by_user_id(self, user_id): + return self.get_pushers_by({ + "user_name": user_id, + }) - rows = yield self.runInteraction( - "get_pushers_by_app_id_and_pushkey", r + @defer.inlineCallbacks + def get_pushers_by(self, keyvalues): + ret = yield self._simple_select_list( + "pushers", keyvalues, + [ + "id", + "user_name", + "access_token", + "profile_tag", + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "ts", + "lang", + "data", + "last_stream_ordering", + "last_success", + "failing_since", + ], desc="get_pushers_by" ) - - defer.returnValue(rows) + defer.returnValue(self._decode_pushers_rows(ret)) @defer.inlineCallbacks def get_all_pushers(self): @@ -78,9 +103,12 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) def get_pushers_stream_token(self): - return self._pushers_id_gen.get_max_token() + return self._pushers_id_gen.get_current_token() def get_all_updated_pushers(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed(([], [])) + def get_all_updated_pushers_txn(txn): sql = ( "SELECT id, user_name, access_token, profile_tag, kind," @@ -107,35 +135,67 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers", get_all_updated_pushers_txn ) + @cachedInlineCallbacks(num_args=1, max_entries=15000) + def get_if_user_has_pusher(self, user_id): + # This only exists for the cachedList decorator + raise NotImplementedError() + + @cachedList(cached_method_name="get_if_user_has_pusher", + list_name="user_ids", num_args=1, inlineCallbacks=True) + def get_if_users_have_pushers(self, user_ids): + rows = yield self._simple_select_many_batch( + table='pushers', + column='user_name', + iterable=user_ids, + retcols=['user_name'], + desc='get_if_users_have_pushers' + ) + + result = {user_id: False for user_id in user_ids} + result.update({r['user_name']: True for r in rows}) + + defer.returnValue(result) + @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data, profile_tag=""): + pushkey, pushkey_ts, lang, data, last_stream_ordering, + profile_tag=""): with self._pushers_id_gen.get_next() as stream_id: - yield self._simple_upsert( - "pushers", - dict( - app_id=app_id, - pushkey=pushkey, - user_name=user_id, - ), - dict( - access_token=access_token, - kind=kind, - app_display_name=app_display_name, - device_display_name=device_display_name, - ts=pushkey_ts, - lang=lang, - data=encode_canonical_json(data), - profile_tag=profile_tag, - id=stream_id, - ), - desc="add_pusher", - ) + def f(txn): + newly_inserted = self._simple_upsert_txn( + txn, + "pushers", + { + "app_id": app_id, + "pushkey": pushkey, + "user_name": user_id, + }, + { + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": encode_canonical_json(data), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + ) + if newly_inserted: + # get_if_user_has_pusher only cares if the user has + # at least *one* pusher. + txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + + yield self.runInteraction("add_pusher", f) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_txn(txn, stream_id): + txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + self._simple_delete_one_txn( txn, "pushers", @@ -147,28 +207,35 @@ class PusherStore(SQLBaseStore): {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, {"stream_id": stream_id}, ) + with self._pushers_id_gen.get_next() as stream_id: yield self.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) @defer.inlineCallbacks - def update_pusher_last_token(self, app_id, pushkey, user_id, last_token): + def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id, + last_stream_ordering): yield self._simple_update_one( "pushers", {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'last_token': last_token}, - desc="update_pusher_last_token", + {'last_stream_ordering': last_stream_ordering}, + desc="update_pusher_last_stream_ordering", ) @defer.inlineCallbacks - def update_pusher_last_token_and_success(self, app_id, pushkey, user_id, - last_token, last_success): + def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey, + user_id, + last_stream_ordering, + last_success): yield self._simple_update_one( "pushers", {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'last_token': last_token, 'last_success': last_success}, - desc="update_pusher_last_token_and_success", + { + 'last_stream_ordering': last_stream_ordering, + 'last_success': last_success + }, + desc="update_pusher_last_stream_ordering_and_success", ) @defer.inlineCallbacks @@ -180,3 +247,30 @@ class PusherStore(SQLBaseStore): {'failing_since': failing_since}, desc="update_pusher_failing_since", ) + + @defer.inlineCallbacks + def get_throttle_params_by_room(self, pusher_id): + res = yield self._simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room" + ) + + params_by_room = {} + for row in res: + params_by_room[row["room_id"]] = { + "last_sent_ts": row["last_sent_ts"], + "throttle_ms": row["throttle_ms"] + } + + defer.returnValue(params_by_room) + + @defer.inlineCallbacks + def set_throttle_params(self, pusher_id, room_id, params): + yield self._simple_upsert( + "pusher_throttle", + {"pusher": pusher_id, "room_id": room_id}, + params, + desc="set_throttle_params" + ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index dbc074d6b5..9747a04a9a 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,9 +31,29 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() ) + @cachedInlineCallbacks() + def get_users_with_read_receipts_in_room(self, room_id): + receipts = yield self.get_receipts_for_room(room_id, "m.read") + defer.returnValue(set(r['user_id'] for r in receipts)) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns an ObservableDeferred + res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) + + if res and res.called and user_id in res.result: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -62,18 +82,42 @@ class ReceiptsStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): + rows = yield self._simple_select_list( + table="receipts_linearized", + keyvalues={ + "user_id": user_id, + "receipt_type": receipt_type, + }, + retcols=("room_id", "event_id"), + desc="get_receipts_for_user", + ) + + defer.returnValue({row["room_id"]: row["event_id"] for row in rows}) + + @defer.inlineCallbacks + def get_receipts_for_user_with_orderings(self, user_id, receipt_type): def f(txn): sql = ( - "SELECT room_id,event_id " - "FROM receipts_linearized " - "WHERE user_id = ? AND receipt_type = ? " + "SELECT rl.room_id, rl.event_id," + " e.topological_ordering, e.stream_ordering" + " FROM receipts_linearized AS rl" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE rl.room_id = e.room_id" + " AND rl.event_id = e.event_id" + " AND user_id = ?" ) - txn.execute(sql, (user_id, receipt_type)) + txn.execute(sql, (user_id,)) return txn.fetchall() - - defer.returnValue(dict( - (yield self.runInteraction("get_receipts_for_user", f)) - )) + rows = yield self.runInteraction( + "get_receipts_for_user_with_orderings", f + ) + defer.returnValue({ + row[0]: { + "event_id": row[1], + "topological_ordering": row[2], + "stream_ordering": row[3], + } for row in rows + }) @defer.inlineCallbacks def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): @@ -101,7 +145,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, max_entries=5000) + @cachedInlineCallbacks(num_args=3, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. @@ -161,8 +205,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) @@ -222,7 +266,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token() + return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -230,10 +274,14 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_room.invalidate, (room_id, receipt_type) ) txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, receipt_type, user_id, + ) + txn.call_after( self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) + txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after( self._receipts_stream_cache.entity_has_changed, @@ -245,6 +293,17 @@ class ReceiptsStore(SQLBaseStore): (user_id, room_id, receipt_type) ) + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True + ) + + topological_ordering = int(res["topological_ordering"]) if res else None + stream_ordering = int(res["stream_ordering"]) if res else None + # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts sql = ( @@ -256,16 +315,7 @@ class ReceiptsStore(SQLBaseStore): txn.execute(sql, (room_id, receipt_type, user_id)) results = txn.fetchall() - if results: - res = self._simple_select_one_txn( - txn, - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - ) - topological_ordering = int(res["topological_ordering"]) - stream_ordering = int(res["stream_ordering"]) - + if results and topological_ordering: for to, so, _ in results: if int(to) > topological_ordering: return False @@ -295,6 +345,14 @@ class ReceiptsStore(SQLBaseStore): } ) + if receipt_type == "m.read" and topological_ordering: + self._remove_old_push_actions_before_txn( + txn, + room_id=room_id, + user_id=user_id, + topological_ordering=topological_ordering, + ) + return True @defer.inlineCallbacks @@ -347,7 +405,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) @@ -365,10 +423,14 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_room.invalidate, (room_id, receipt_type) ) txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, receipt_type, user_id, + ) + txn.call_after( self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) + txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) self._simple_delete_txn( txn, @@ -391,16 +453,22 @@ class ReceiptsStore(SQLBaseStore): } ) - def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + def get_all_updated_receipts_txn(txn): sql = ( "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" " FROM receipts_linearized" " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" - " LIMIT ?" ) - txn.execute(sql, (last_id, current_id, limit)) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) return txn.fetchall() return self.runInteraction( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd4eb88a92..983a8ec52b 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,25 +18,40 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes +from synapse.storage import background_updates +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList - -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,51 +62,34 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) - @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): - """Adds a refresh token for the given user. - - Args: - user_id (str): The user ID. - token (str): The new refresh token to add. - Raises: - StoreError if there was a problem adding this. - """ - next_id = self._refresh_tokens_id_gen.get_next() - - yield self._simple_insert( - "refresh_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token - }, - desc="add_refresh_token_to_user", - ) - - @defer.inlineCallbacks - def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False, appservice_id=None): + def register(self, user_id, token=None, password_hash=None, + was_guest=False, make_guest=False, appservice_id=None, + create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. + token (str): The desired access token to use for this user. If this + is not None, the given access token is associated with the user + id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, false to add a regular user account. appservice_id (str): The ID of the appservice registering the user. + create_profile_with_localpart (str): Optionally create a profile for + the given localpart. Raises: StoreError if the user_id could not be registered. """ - yield self.runInteraction( + return self.runInteraction( "register", self._register, user_id, @@ -99,9 +97,10 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, + admin ) - self.is_guest.invalidate((user_id,)) def _register( self, @@ -111,7 +110,9 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, + admin, ): now = int(self.clock.time()) @@ -119,29 +120,48 @@ class RegistrationStore(SQLBaseStore): try: if was_guest: - txn.execute("UPDATE users SET" - " password_hash = ?," - " upgrade_ts = ?," - " is_guest = ?" - " WHERE name = ?", - [password_hash, now, 1 if make_guest else 0, user_id]) + # Ensure that the guest user actually exists + # ``allow_none=False`` makes this raise an exception + # if the row isn't in the database. + self._simple_select_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + "is_guest": 1, + }, + retcols=("name",), + allow_none=False, + ) + + self._simple_update_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + "is_guest": 1, + }, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + } + ) else: - txn.execute("INSERT INTO users " - "(" - " name," - " password_hash," - " creation_ts," - " is_guest," - " appservice_id" - ") " - "VALUES (?,?,?,?,?)", - [ - user_id, - password_hash, - now, - 1 if make_guest else 0, - appservice_id, - ]) + self._simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + } + ) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE @@ -156,6 +176,18 @@ class RegistrationStore(SQLBaseStore): (next_id, user_id, token,) ) + if create_profile_with_localpart: + txn.execute( + "INSERT INTO profiles(user_id) VALUES (?)", + (create_profile_with_localpart,) + ) + + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user_id,) + ) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( table="users", @@ -181,48 +213,88 @@ class RegistrationStore(SQLBaseStore): return self.runInteraction("get_users_by_id_case_insensitive", f) - @defer.inlineCallbacks def user_set_password_hash(self, user_id, password_hash): """ NB. This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be pointless. Use flush_user separately. """ - yield self._simple_update_one('users', { - 'name': user_id - }, { - 'password_hash': password_hash - }) + def user_set_password_hash_txn(txn): + self._simple_update_one_txn( + txn, + 'users', { + 'name': user_id + }, + { + 'password_hash': password_hash + } + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user_id,) + ) + return self.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): - def f(txn): - sql = "SELECT token FROM access_tokens WHERE user_id = ?" - clauses = [user_id] + def user_delete_access_tokens(self, user_id, except_token_id=None, + device_id=None, + delete_refresh_tokens=False): + """ + Invalidate access/refresh tokens belonging to a user - if except_token_ids: - sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_token_ids]), + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str): list of access_tokens IDs which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + delete_refresh_tokens (bool): True to delete refresh tokens as + well as access tokens. + Returns: + defer.Deferred: + """ + def f(txn): + keyvalues = { + "user_id": user_id, + } + if device_id is not None: + keyvalues["device_id"] = device_id + + if delete_refresh_tokens: + self._simple_delete_txn( + txn, + table="refresh_tokens", + keyvalues=keyvalues, ) - clauses += except_token_ids - txn.execute(sql, clauses) + items = keyvalues.items() + where_clause = " AND ".join(k + " = ?" for k, _ in items) + values = [v for _, v in items] + if except_token_id: + where_clause += " AND id != ?" + values.append(except_token_id) - rows = txn.fetchall() - - n = 100 - chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] - for chunk in chunks: - for row in chunk: - txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + txn.execute( + "SELECT token FROM access_tokens WHERE %s" % where_clause, + values + ) + rows = self.cursor_to_dict(txn) - txn.execute( - "DELETE FROM access_tokens WHERE token in (%s)" % ( - ",".join(["?" for _ in chunk]), - ), [r[0] for r in chunk] + for row in rows: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (row["token"],) ) - yield self.runInteraction("user_delete_access_tokens", f) + txn.execute( + "DELETE FROM access_tokens WHERE %s" % where_clause, + values + ) + + yield self.runInteraction( + "user_delete_access_tokens", f, + ) def delete_access_token(self, access_token): def f(txn): @@ -234,7 +306,9 @@ class RegistrationStore(SQLBaseStore): }, ) - txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (access_token,) + ) return self.runInteraction("delete_access_token", f) @@ -245,9 +319,8 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and the ID of their access token. - Raises: - StoreError if no user was found. + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( "get_user_by_access_token", @@ -255,46 +328,6 @@ class RegistrationStore(SQLBaseStore): token ) - def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. - - Doing so invalidates the old refresh token - refresh tokens are single - use. - - Args: - token (str): The refresh token of a user. - token_generator (fn: str -> str): Function which, when given a - user ID, returns a unique refresh token for that user. This - function must never return the same value twice. - Returns: - tuple of (user_id, refresh_token) - Raises: - StoreError if no user was found with that refresh token. - """ - return self.runInteraction( - "exchange_refresh_token", - self._exchange_refresh_token, - refresh_token, - token_generator - ) - - def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" - txn.execute(sql, (old_token,)) - rows = self.cursor_to_dict(txn) - if not rows: - raise StoreError(403, "Did not recognize refresh token") - user_id = rows[0]["user_id"] - - # TODO(danielwh): Maybe perform a validation on the macaroon that - # macaroon.user_id == user_id. - - new_token = token_generator(user_id) - sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" - txn.execute(sql, (new_token, old_token,)) - - return user_id, new_token - @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( @@ -319,29 +352,10 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, - inlineCallbacks=True) - def are_guests(self, user_ids): - sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( - ",".join("?" for _ in user_ids), - ) - - rows = yield self._execute( - "are_guests", self.cursor_to_dict, sql, *user_ids - ) - - result = {user_id: False for user_id in user_ids} - - result.update({ - row["name"]: bool(row["is_guest"]) - for row in rows - }) - - defer.returnValue(result) - def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" @@ -390,6 +404,15 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(ret['user_id']) defer.returnValue(None) + def user_delete_threepids(self, user_id): + return self._simple_delete( + "user_threepids", + keyvalues={ + "user_id": user_id, + }, + desc="user_delete_threepids", + ) + @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" @@ -458,12 +481,15 @@ class RegistrationStore(SQLBaseStore): """ Gets the 3pid's guest access token if exists, else saves access_token. - :param medium (str): Medium of the 3pid. Must be "email". - :param address (str): 3pid address. - :param access_token (str): The access token to persist if none is - already persisted. - :param inviter_user_id (str): User ID of the inviter. - :return (deferred str): Whichever access token is persisted at the end + Args: + medium (str): Medium of the 3pid. Must be "email". + address (str): 3pid address. + access_token (str): The access token to persist if none is + already persisted. + inviter_user_id (str): User ID of the inviter. + + Returns: + deferred str: Whichever access token is persisted at the end of this function call. """ def insert(txn): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 46ab38a313..11813b44f6 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -18,11 +18,11 @@ from twisted.internet import defer from synapse.api.errors import StoreError from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks from .engines import PostgresEngine, Sqlite3Engine import collections import logging +import ujson as json logger = logging.getLogger(__name__) @@ -48,15 +48,31 @@ class RoomStore(SQLBaseStore): StoreError if the room could not be stored. """ try: - yield self._simple_insert( - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - }, - desc="store_room", - ) + def store_room_txn(txn, next_id): + self._simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + }, + ) + if is_public: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + } + ) + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction( + "store_room_txn", + store_room_txn, next_id, + ) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -77,6 +93,46 @@ class RoomStore(SQLBaseStore): allow_none=True, ) + @defer.inlineCallbacks + def set_room_is_public(self, room_id, is_public): + def set_room_is_public_txn(txn, next_id): + self._simple_update_one_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + ) + + entries = self._simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={"room_id": room_id}, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + } + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction( + "set_room_is_public", + set_room_is_public_txn, next_id, + ) + def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", @@ -161,47 +217,112 @@ class RoomStore(SQLBaseStore): def _store_event_search_txn(self, txn, event, key, value): if isinstance(self.database_engine, PostgresEngine): sql = ( - "INSERT INTO event_search (event_id, room_id, key, vector)" - " VALUES (?,?,?,to_tsvector('english', ?))" + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + txn.execute( + sql, + ( + event.event_id, event.room_id, key, value, + event.internal_metadata.stream_ordering, + event.origin_server_ts, + ) ) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( "INSERT INTO event_search (event_id, room_id, key, value)" " VALUES (?,?,?,?)" ) + txn.execute(sql, (event.event_id, event.room_id, key, value,)) else: # This should be unreachable. raise Exception("Unrecognized database engine") - txn.execute(sql, (event.event_id, event.room_id, key, value,)) + def add_event_report(self, room_id, event_id, user_id, reason, content, + received_ts): + next_id = self._event_reports_id_gen.get_next() + return self._simple_insert( + table="event_reports", + values={ + "id": next_id, + "received_ts": received_ts, + "room_id": room_id, + "event_id": event_id, + "user_id": user_id, + "reason": reason, + "content": json.dumps(content), + }, + desc="add_event_report" + ) - @cachedInlineCallbacks() - def get_room_name_and_aliases(self, room_id): - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events " - "WHERE room_id = ? " + def get_current_public_room_stream_id(self): + return self._public_room_id_gen.get_current_token() + + def get_public_room_ids_at_stream_id(self, stream_id): + return self.runInteraction( + "get_public_room_ids_at_stream_id", + self.get_public_room_ids_at_stream_id_txn, stream_id + ) + + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id): + return { + rm + for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items() + if vis + } + + def get_published_at_stream_id_txn(self, txn, stream_id): + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream + INNER JOIN ( + SELECT room_id, max(stream_id) AS stream_id + FROM public_room_list_stream + WHERE stream_id <= ? + GROUP BY room_id + ) grouped USING (room_id, stream_id) + """) + + txn.execute(sql, (stream_id,)) + return dict(txn.fetchall()) + + def get_public_room_changes(self, prev_stream_id, new_stream_id): + def get_public_room_changes_txn(txn): + then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id) + + now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id) + + now_rooms_visible = set( + rm for rm, vis in now_rooms_dict.items() if vis + ) + now_rooms_not_visible = set( + rm for rm, vis in now_rooms_dict.items() if not vis ) - sql += " AND ((type = 'm.room.name' AND state_key = '')" - sql += " OR type = 'm.room.aliases')" + newly_visible = now_rooms_visible - then_rooms + newly_unpublished = now_rooms_not_visible & then_rooms - txn.execute(sql, (room_id,)) - results = self.cursor_to_dict(txn) + return newly_visible, newly_unpublished - return self._parse_events_txn(txn, results) + return self.runInteraction( + "get_public_room_changes", get_public_room_changes_txn + ) - events = yield self.runInteraction("get_room_name_and_aliases", f) + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = (""" + SELECT stream_id, room_id, visibility FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """) - name = None - aliases = [] + txn.execute(sql, (prev_id, current_id, limit,)) + return txn.fetchall() - for e in events: - if e.type == 'm.room.name': - if 'name' in e.content: - name = e.content['name'] - elif e.type == 'm.room.aliases': - if 'aliases' in e.content: - aliases.extend(e.content['aliases']) + if prev_id == current_id: + return defer.succeed([]) - defer.returnValue((name, aliases)) + return self.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 0cd89260f2..866d64e679 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -20,8 +20,8 @@ from collections import namedtuple from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.api.constants import Membership -from synapse.types import UserID +from synapse.api.constants import Membership, EventTypes +from synapse.types import get_domain_from_id import logging @@ -36,7 +36,7 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): - def _store_room_members_txn(self, txn, events): + def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ self._simple_insert_many_txn( @@ -56,33 +56,70 @@ class RoomMemberStore(SQLBaseStore): for event in events: txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) - def get_room_member(self, user_id, room_id): - """Retrieve the current state of a room member. + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. + # The only current event that can also be an outlier is if its an + # invite that has come in across federation. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_invite_from_remote() + ) + is_mine = self.hs.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + } + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + )) - Args: - user_id (str): The member's user ID. - room_id (str): The room the member is in. - Returns: - Deferred: Results in a MembershipEvent or None. - """ - return self.runInteraction( - "get_room_member", - self._get_members_events_txn, - room_id, - user_id=user_id, - ).addCallback( - self._get_events - ).addCallback( - lambda events: events[0] if events else None + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" ) + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + @cached(max_entries=5000) def get_users_in_room(self, room_id): def f(txn): @@ -96,51 +133,36 @@ class RoomMemberStore(SQLBaseStore): return [r["user_id"] for r in rows] return self.runInteraction("get_users_in_room", f) - def get_room_members(self, room_id, membership=None): - """Retrieve the current room member list for a room. - - Args: - room_id (str): The room to get the list of members. - membership (synapse.api.constants.Membership): The filter to apply - to this list, or None to return all members with some state - associated with this room. - Returns: - list of namedtuples representing the members in this room. - """ - return self.runInteraction( - "get_room_members", - self._get_members_events_txn, - room_id, - membership=membership, - ).addCallback(self._get_events) - @cached() - def get_invites_for_user(self, user_id): - """ Get all the invite events for a user + def get_invited_rooms_for_user(self, user_id): + """ Get all the rooms the user is invited to Args: user_id (str): The user ID. Returns: - A deferred list of event objects. + A deferred list of RoomsForUser. """ return self.get_rooms_for_user_where_membership_is( user_id, [Membership.INVITE] - ).addCallback(lambda invites: self._get_events([ - invite.event_id for invite in invites - ])) + ) + + @defer.inlineCallbacks + def get_invite_for_user_in_room(self, user_id, room_id): + """Gets the invite for the given user and room - def get_leave_and_ban_events_for_user(self, user_id): - """ Get all the leave events for a user Args: - user_id (str): The user ID. + user_id (str) + room_id (str) + Returns: - A deferred list of event objects. + Deferred: Resolves to either a RoomsForUser or None if no invite was + found. """ - return self.get_rooms_for_user_where_membership_is( - user_id, (Membership.LEAVE, Membership.BAN) - ).addCallback(lambda leaves: self._get_events([ - leave.event_id for leave in leaves - ])) + invites = yield self.get_invited_rooms_for_user(user_id) + for invite in invites: + if invite.room_id == room_id: + defer.returnValue(invite) + defer.returnValue(None) def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user @@ -165,57 +187,55 @@ class RoomMemberStore(SQLBaseStore): def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, membership_list): - where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( - " OR ".join(["membership = ?" for _ in membership_list]), - ) - args = [user_id] - args.extend(membership_list) + do_invite = Membership.INVITE in membership_list + membership_list = [m for m in membership_list if m != Membership.INVITE] - sql = ( - "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" - " FROM current_state_events as c" - " INNER JOIN room_memberships as m" - " ON m.event_id = c.event_id" - " INNER JOIN events as e" - " ON e.event_id = c.event_id" - " AND m.room_id = c.room_id" - " AND m.user_id = c.state_key" - " WHERE %s" - ) % (where_clause,) - - txn.execute(sql, args) - return [ - RoomsForUser(**r) for r in self.cursor_to_dict(txn) - ] + results = [] + if membership_list: + where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( + " OR ".join(["membership = ?" for _ in membership_list]), + ) - @cached(max_entries=5000) - def get_joined_hosts_for_room(self, room_id): - return self.runInteraction( - "get_joined_hosts_for_room", - self._get_joined_hosts_for_room_txn, - room_id, - ) + args = [user_id] + args.extend(membership_list) - def _get_joined_hosts_for_room_txn(self, txn, room_id): - rows = self._get_members_rows_txn( - txn, - room_id, membership=Membership.JOIN - ) + sql = ( + "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" + " FROM current_state_events as c" + " INNER JOIN room_memberships as m" + " ON m.event_id = c.event_id" + " INNER JOIN events as e" + " ON e.event_id = c.event_id" + " AND m.room_id = c.room_id" + " AND m.user_id = c.state_key" + " WHERE %s" + ) % (where_clause,) + + txn.execute(sql, args) + results = [ + RoomsForUser(**r) for r in self.cursor_to_dict(txn) + ] - joined_domains = set( - UserID.from_string(r["user_id"]).domain - for r in rows - ) + if do_invite: + sql = ( + "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" + " FROM local_invites as i" + " INNER JOIN events as e USING (event_id)" + " WHERE invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) - return joined_domains + txn.execute(sql, (user_id,)) + results.extend(RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) for r in self.cursor_to_dict(txn)) - def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): - rows = self._get_members_rows_txn( - txn, - room_id, membership, user_id, - ) - return [r["event_id"] for r in rows] + return results def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): where_clause = "c.room_id = ?" @@ -251,7 +271,6 @@ class RoomMemberStore(SQLBaseStore): user_id, membership_list=[Membership.JOIN], ) - @defer.inlineCallbacks def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -266,10 +285,13 @@ class RoomMemberStore(SQLBaseStore): " room_id = ?" ) txn.execute(sql, (user_id, room_id)) - yield self.runInteraction("forget_membership", f) - self.was_forgotten_at.invalidate_all() - self.who_forgot_in_room.invalidate_all() - self.did_forget.invalidate((user_id, room_id)) + + txn.call_after(self.was_forgotten_at.invalidate_all) + txn.call_after(self.did_forget.invalidate, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.who_forgot_in_room, (room_id,) + ) + return self.runInteraction("forget_membership", f) @cachedInlineCallbacks(num_args=2) def did_forget(self, user_id, room_id): @@ -297,7 +319,8 @@ class RoomMemberStore(SQLBaseStore): @cachedInlineCallbacks(num_args=3) def was_forgotten_at(self, user_id, room_id, event_id): - """Returns whether user_id has elected to discard history for room_id at event_id. + """Returns whether user_id has elected to discard history for room_id at + event_id. event_id must be a membership event.""" def f(txn): @@ -330,3 +353,98 @@ class RoomMemberStore(SQLBaseStore): }, desc="who_forgot" ) + + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + event.room_id, state_group, context.current_state_ids, event=event, + ) + + def get_joined_users_from_state(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids, + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, + cache_context, event=None): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + member_event_ids = [ + e_id + for key, e_id in current_state_ids.iteritems() + if key[0] == EventTypes.Member + ] + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=['user_id'], + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=500, + desc="_get_joined_users_from_context", + ) + + users_in_room = set(row["user_id"] for row in rows) + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room.add(event.state_key) + + defer.returnValue(users_in_room) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._is_host_joined( + room_id, host, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 5c40a77757..8755bb2e49 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, *args, **kwargs): +def run_create(cur, *args, **kwargs): cur.execute("SELECT id, regex FROM application_services_regex") for row in cur.fetchall(): try: @@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs): "UPDATE application_services_regex SET regex=? WHERE id=?", (new_regex, row[0]) ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 29164732af..147496a38b 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") cur.execute(""" CREATE TABLE IF NOT EXISTS pushers2 ( @@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/schema/delta/22/receipts_index.sql index 7bc061dff6..bfc0b3bcaa 100644 --- a/synapse/storage/schema/delta/22/receipts_index.sql +++ b/synapse/storage/schema/delta/22/receipts_index.sql @@ -13,6 +13,10 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index d3ff2b1779..4269ac69ad 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -43,7 +43,7 @@ SQLITE_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): for statement in get_statements(POSTGRES_TABLE.splitlines()): cur.execute(statement) @@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_search", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index f8c16391a2..71b12a2731 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -27,7 +27,7 @@ ALTER_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): for statement in get_statements(ALTER_TABLE.splitlines()): cur.execute(statement) @@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_origin_server_ts", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql index 200c35e6e2..36609475f1 100644 --- a/synapse/storage/schema/delta/28/events_room_stream.sql +++ b/synapse/storage/schema/delta/28/events_room_stream.sql @@ -13,4 +13,8 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX events_room_stream on events(room_id, stream_ordering); diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql index ba62a974a4..6c1fd68c5b 100644 --- a/synapse/storage/schema/delta/28/public_roms_index.sql +++ b/synapse/storage/schema/delta/28/public_roms_index.sql @@ -13,4 +13,8 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX public_room_index on rooms(is_public); diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/schema/delta/28/receipts_user_id_index.sql index 452a1b3c6c..cb84c69baa 100644 --- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql +++ b/synapse/storage/schema/delta/28/receipts_user_id_index.sql @@ -13,6 +13,10 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql index 7e7b09820a..84b21cf813 100644 --- a/synapse/storage/schema/delta/29/push_actions.sql +++ b/synapse/storage/schema/delta/29/push_actions.sql @@ -26,6 +26,10 @@ UPDATE event_push_actions SET stream_ordering = ( UPDATE event_push_actions SET notif = 1, highlight = 0; +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4f6e9dd540..5b7d8d1ab5 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from synapse.storage.appservice import ApplicationServiceStore +from synapse.config.appservice import load_appservices logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, config, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") @@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): # Maybe we already added the column? Hope so... pass + +def run_upgrade(cur, database_engine, config, *args, **kwargs): cur.execute("SELECT name FROM users") rows = cur.fetchall() @@ -36,7 +38,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): logger.warning("Could not get app_service_config_files from config") pass - appservices = ApplicationServiceStore.load_appservices( + appservices = load_appservices( config.server_name, config_files ) diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/schema/delta/30/public_rooms.sql new file mode 100644 index 0000000000..f09db4faa6 --- /dev/null +++ b/synapse/storage/schema/delta/30/public_rooms.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* This release removes the restriction that published rooms must have an alias, + * so we go back and ensure the only 'public' rooms are ones with an alias. + * We use (1 = 0) and (1 = 1) so that it works in both postgres and sqlite + */ +UPDATE rooms SET is_public = (1 = 0) WHERE is_public = (1 = 1) AND room_id not in ( + SELECT room_id FROM room_aliases +); diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/schema/delta/30/state_stream.sql new file mode 100644 index 0000000000..706fe1dcf4 --- /dev/null +++ b/synapse/storage/schema/delta/30/state_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * The positions in the event stream_ordering when the current_state was + * replaced by the state at the event. + */ + +CREATE TABLE IF NOT EXISTS current_state_resets( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL +); + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. + */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql new file mode 100644 index 0000000000..2c57846d5a --- /dev/null +++ b/synapse/storage/schema/delta/31/invites.sql @@ -0,0 +1,42 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE local_invites( + stream_id BIGINT NOT NULL, + inviter TEXT NOT NULL, + invitee TEXT NOT NULL, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + locally_rejected TEXT, + replaced_by TEXT +); + +-- Insert all invites for local users into new `invites` table +INSERT INTO local_invites SELECT + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by + FROM events + NATURAL JOIN current_state_events + NATURAL JOIN room_memberships + WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql new file mode 100644 index 0000000000..9efb4280eb --- /dev/null +++ b/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql @@ -0,0 +1,27 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE local_media_repository_url_cache( + url TEXT, -- the URL being cached + response_code INTEGER, -- the HTTP response code of this download attempt + etag TEXT, -- the etag header of this response + expires INTEGER, -- the number of ms this response was valid for + og TEXT, -- cache of the OG metadata of this URL as JSON + media_id TEXT, -- the media_id, if any, of the URL's content in the repo + download_ts BIGINT -- the timestamp of this download attempt +); + +CREATE INDEX local_media_repository_url_cache_by_url_download_ts + ON local_media_repository_url_cache(url, download_ts); diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py new file mode 100644 index 0000000000..93367fa09e --- /dev/null +++ b/synapse/storage/schema/delta/31/pushers.py @@ -0,0 +1,79 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Change the last_token to last_stream_ordering now that pushers no longer +# listen on an event stream but instead select out of the event_push_actions +# table. + + +import logging + +logger = logging.getLogger(__name__) + + +def token_to_stream_ordering(token): + return int(token[1:].split('_')[0]) + + +def run_create(cur, database_engine, *args, **kwargs): + logger.info("Porting pushers table, delta 31...") + cur.execute(""" + CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) + ) + """) + cur.execute("""SELECT + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + FROM pushers + """) + count = 0 + for row in cur.fetchall(): + row = list(row) + row[12] = token_to_stream_ordering(row[12]) + cur.execute(database_engine.convert_param_style(""" + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), + row + ) + count += 1 + cur.execute("DROP TABLE pushers") + cur.execute("ALTER TABLE pushers2 RENAME TO pushers") + logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/schema/delta/31/pushers_index.sql new file mode 100644 index 0000000000..a82add88fd --- /dev/null +++ b/synapse/storage/schema/delta/31/pushers_index.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ + CREATE INDEX event_push_actions_stream_ordering on event_push_actions( + stream_ordering, user_id + ); diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py new file mode 100644 index 0000000000..470ae0c005 --- /dev/null +++ b/synapse/storage/schema/delta/31/search_update.py @@ -0,0 +1,65 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +import logging +import ujson + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = """ +ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT; +ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT; +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + "have_added_indexes": False, + } + progress_json = ujson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_search_order", progress_json)) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/schema/delta/32/events.sql new file mode 100644 index 0000000000..1dd0f9e170 --- /dev/null +++ b/synapse/storage/schema/delta/32/events.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN received_ts BIGINT; diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/schema/delta/32/openid.sql new file mode 100644 index 0000000000..36f37b11c8 --- /dev/null +++ b/synapse/storage/schema/delta/32/openid.sql @@ -0,0 +1,9 @@ + +CREATE TABLE open_id_tokens ( + token TEXT NOT NULL PRIMARY KEY, + ts_valid_until_ms bigint NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (token) +); + +CREATE index open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/schema/delta/32/pusher_throttle.sql new file mode 100644 index 0000000000..d86d30c13c --- /dev/null +++ b/synapse/storage/schema/delta/32/pusher_throttle.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE pusher_throttle( + pusher BIGINT NOT NULL, + room_id TEXT NOT NULL, + last_sent_ts BIGINT, + throttle_ms BIGINT, + PRIMARY KEY (pusher, room_id) +); diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/schema/delta/32/remove_indices.sql new file mode 100644 index 0000000000..f859be46a6 --- /dev/null +++ b/synapse/storage/schema/delta/32/remove_indices.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream +DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room +DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room +DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT + +DROP INDEX IF EXISTS event_destinations_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS event_content_hashes_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS event_edge_hashes_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS room_hosts_room_id; -- Prefix of UNIQUE CONSTRAINT + +-- The following indices were unused +DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id; +DROP INDEX IF EXISTS evauth_edges_auth_id; +DROP INDEX IF EXISTS presence_stream_state; diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/schema/delta/32/reports.sql new file mode 100644 index 0000000000..d13609776f --- /dev/null +++ b/synapse/storage/schema/delta/32/reports.sql @@ -0,0 +1,25 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE event_reports( + id BIGINT NOT NULL PRIMARY KEY, + received_ts BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + reason TEXT, + content TEXT +); diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql new file mode 100644 index 0000000000..aa4a3b9f2f --- /dev/null +++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- make sure that we have a device record for each set of E2E keys, so that the +-- user can delete them if they like. +INSERT INTO devices + SELECT user_id, device_id, NULL FROM e2e_device_keys_json; diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql new file mode 100644 index 0000000000..6671573398 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a previous version of the "devices_for_e2e_keys" delta set all the device +-- names to "unknown device". This wasn't terribly helpful +UPDATE devices + SET display_name = NULL + WHERE display_name = 'unknown device'; diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py new file mode 100644 index 0000000000..83066cccc9 --- /dev/null +++ b/synapse/storage/schema/delta/33/event_fields.py @@ -0,0 +1,60 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.prepare_database import get_statements + +import logging +import ujson + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = """ +ALTER TABLE events ADD COLUMN sender TEXT; +ALTER TABLE events ADD COLUMN contains_url BOOLEAN; +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = ujson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_fields_sender_url", progress_json)) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql new file mode 100644 index 0000000000..290bd6da86 --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql new file mode 100644 index 0000000000..bb225dafbf --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('refresh_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py new file mode 100644 index 0000000000..55ae43f395 --- /dev/null +++ b/synapse/storage/schema/delta/33/remote_media_ts.py @@ -0,0 +1,31 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + + +ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" + + +def run_create(cur, database_engine, *args, **kwargs): + cur.execute(ALTER_TABLE) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + cur.execute( + database_engine.convert_param_style( + "UPDATE remote_media_cache SET last_access_ts = ?" + ), + (int(time.time() * 1000),) + ) diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql new file mode 100644 index 0000000000..473f75a78e --- /dev/null +++ b/synapse/storage/schema/delta/33/user_ips_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_ips_device_index', '{}'); diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/schema/delta/34/appservice_stream.sql new file mode 100644 index 0000000000..69e16eda0f --- /dev/null +++ b/synapse/storage/schema/delta/34/appservice_stream.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS appservice_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT, + CHECK (Lock='X') +); + +INSERT INTO appservice_stream_position (stream_ordering) + SELECT COALESCE(MAX(stream_ordering), 0) FROM events; diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py new file mode 100644 index 0000000000..3b63a1562d --- /dev/null +++ b/synapse/storage/schema/delta/34/cache_stream.py @@ -0,0 +1,46 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + + +# This stream is used to notify replication slaves that some caches have +# been invalidated that they cannot infer from the other streams. +CREATE_TABLE = """ +CREATE TABLE cache_invalidation_stream ( + stream_id BIGINT, + cache_func TEXT, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(CREATE_TABLE.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/schema/delta/34/device_inbox.sql new file mode 100644 index 0000000000..e68844c74a --- /dev/null +++ b/synapse/storage/schema/delta/34/device_inbox.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_inbox ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, + message_json TEXT NOT NULL -- {"type":, "sender":, "content",} +); + +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/schema/delta/34/push_display_name_rename.sql new file mode 100644 index 0000000000..0d9fe1a99a --- /dev/null +++ b/synapse/storage/schema/delta/34/push_display_name_rename.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name'; +UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; + +DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name'; +UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py new file mode 100644 index 0000000000..033144341c --- /dev/null +++ b/synapse/storage/schema/delta/34/received_txn_purge.py @@ -0,0 +1,32 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute("TRUNCATE received_transactions") + else: + cur.execute("DELETE FROM received_transactions") + + cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)") + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py new file mode 100644 index 0000000000..81948e3431 --- /dev/null +++ b/synapse/storage/schema/delta/34/sent_txn_purge.py @@ -0,0 +1,32 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute("TRUNCATE sent_transactions") + else: + cur.execute("DELETE FROM sent_transactions") + + cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)") + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/schema/delta/35/add_state_index.sql new file mode 100644 index 0000000000..0fce26345b --- /dev/null +++ b/synapse/storage/schema/delta/35/add_state_index.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE background_updates ADD COLUMN depends_on TEXT; + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/schema/delta/35/contains_url.sql new file mode 100644 index 0000000000..6cd123027b --- /dev/null +++ b/synapse/storage/schema/delta/35/contains_url.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('event_contains_url_index', '{}'); diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/schema/delta/35/device_outbox.sql new file mode 100644 index 0000000000..17e6c43105 --- /dev/null +++ b/synapse/storage/schema/delta/35/device_outbox.sql @@ -0,0 +1,39 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS device_federation_outbox; +CREATE TABLE device_federation_outbox ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + queued_ts BIGINT NOT NULL, + messages_json TEXT NOT NULL +); + + +DROP INDEX IF EXISTS device_federation_outbox_destination_id; +CREATE INDEX device_federation_outbox_destination_id + ON device_federation_outbox(destination, stream_id); + + +DROP TABLE IF EXISTS device_federation_inbox; +CREATE TABLE device_federation_inbox ( + origin TEXT NOT NULL, + message_id TEXT NOT NULL, + received_ts BIGINT NOT NULL +); + +DROP INDEX IF EXISTS device_federation_inbox_sender_id; +CREATE INDEX device_federation_inbox_sender_id + ON device_federation_inbox(origin, message_id); diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/schema/delta/35/device_stream_id.sql new file mode 100644 index 0000000000..7ab7d942e2 --- /dev/null +++ b/synapse/storage/schema/delta/35/device_stream_id.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_max_stream_id ( + stream_id BIGINT NOT NULL +); + +INSERT INTO device_max_stream_id (stream_id) + SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox; diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/schema/delta/35/event_push_actions_index.sql new file mode 100644 index 0000000000..2e836d8e9c --- /dev/null +++ b/synapse/storage/schema/delta/35/event_push_actions_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('epa_highlight_index', '{}'); diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/schema/delta/35/public_room_list_change_stream.sql new file mode 100644 index 0000000000..dd2bf2e28a --- /dev/null +++ b/synapse/storage/schema/delta/35/public_room_list_change_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE public_room_list_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + visibility BOOLEAN NOT NULL +); + +INSERT INTO public_room_list_stream (stream_id, room_id, visibility) + SELECT 1, room_id, is_public FROM rooms + WHERE is_public = CAST(1 AS BOOLEAN); + +CREATE INDEX public_room_list_stream_idx on public_room_list_stream( + stream_id +); + +CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( + room_id, stream_id +); diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/schema/delta/35/state.sql new file mode 100644 index 0000000000..0f1fa68a89 --- /dev/null +++ b/synapse/storage/schema/delta/35/state.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_group_edges( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/schema/delta/35/state_dedupe.sql new file mode 100644 index 0000000000..97e5067ef4 --- /dev/null +++ b/synapse/storage/schema/delta/35/state_dedupe.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/schema/delta/35/stream_order_to_extrem.sql new file mode 100644 index 0000000000..2b945d8a57 --- /dev/null +++ b/synapse/storage/schema/delta/35/stream_order_to_extrem.sql @@ -0,0 +1,37 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE stream_ordering_to_exterm ( + stream_ordering BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) + SELECT stream_ordering, room_id, event_id FROM event_forward_extremities + INNER JOIN ( + SELECT room_id, max(stream_ordering) as stream_ordering FROM events + INNER JOIN event_forward_extremities USING (room_id, event_id) + GROUP BY room_id + ) AS rms USING (room_id); + +CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( + stream_ordering +); + +CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( + room_id, stream_ordering +); diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/schema/delta/36/readd_public_rooms.sql new file mode 100644 index 0000000000..90d8fd18f9 --- /dev/null +++ b/synapse/storage/schema/delta/36/readd_public_rooms.sql @@ -0,0 +1,26 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted +INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) + SELECT + (SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering, + room_id, + event_id + FROM event_forward_extremities AS e + WHERE NOT EXISTS ( + SELECT room_id FROM stream_ordering_to_exterm AS s + WHERE s.room_id = e.room_id + ); diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py new file mode 100644 index 0000000000..784f3b348f --- /dev/null +++ b/synapse/storage/schema/delta/37/remove_auth_idx.py @@ -0,0 +1,81 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + +DROP_INDICES = """ +-- We only ever query based on event_id +DROP INDEX IF EXISTS state_events_room_id; +DROP INDEX IF EXISTS state_events_type; +DROP INDEX IF EXISTS state_events_state_key; + +-- room_id is indexed elsewhere +DROP INDEX IF EXISTS current_state_events_room_id; +DROP INDEX IF EXISTS current_state_events_state_key; +DROP INDEX IF EXISTS current_state_events_type; + +DROP INDEX IF EXISTS transactions_have_ref; + +-- (topological_ordering, stream_ordering, room_id) seems like a strange index, +-- and is used incredibly rarely. +DROP INDEX IF EXISTS events_order_topo_stream_room; + +DROP INDEX IF EXISTS event_search_ev_idx; +""" + +POSTGRES_DROP_CONSTRAINT = """ +ALTER TABLE event_auth DROP CONSTRAINT IF EXISTS event_auth_event_id_auth_id_room_id_key; +""" + +SQLITE_DROP_CONSTRAINT = """ +DROP INDEX IF EXISTS evauth_edges_id; + +CREATE TABLE IF NOT EXISTS event_auth_new( + event_id TEXT NOT NULL, + auth_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +INSERT INTO event_auth_new + SELECT event_id, auth_id, room_id + FROM event_auth; + +DROP TABLE event_auth; + +ALTER TABLE event_auth_new RENAME TO event_auth; + +CREATE INDEX evauth_edges_id ON event_auth(event_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(DROP_INDICES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + drop_constraint = POSTGRES_DROP_CONSTRAINT + else: + drop_constraint = SQLITE_DROP_CONSTRAINT + + for statement in get_statements(drop_constraint.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/schema/delta/37/user_threepids.sql new file mode 100644 index 0000000000..cf7a90dd10 --- /dev/null +++ b/synapse/storage/schema/delta/37/user_threepids.sql @@ -0,0 +1,52 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Update any email addresses that were stored with mixed case into all + * lowercase + */ + + -- There may be "duplicate" emails (with different case) already in the table, + -- so we find them and move all but the most recently used account. + UPDATE user_threepids + SET medium = 'email_old' + WHERE medium = 'email' + AND address IN ( + -- We select all the addresses that are linked to the user_id that is NOT + -- the most recently created. + SELECT u.address + FROM + user_threepids AS u, + -- `duplicate_addresses` is a table of all the email addresses that + -- appear multiple times and when the binding was created + ( + SELECT lower(u1.address) AS address, max(u1.added_at) AS max_ts + FROM user_threepids AS u1 + INNER JOIN user_threepids AS u2 ON u1.medium = u2.medium AND lower(u1.address) = lower(u2.address) AND u1.address != u2.address + WHERE u1.medium = 'email' AND u2.medium = 'email' + GROUP BY lower(u1.address) + ) AS duplicate_addresses + WHERE + lower(u.address) = duplicate_addresses.address + AND u.added_at != max_ts -- NOT the most recently created + ); + + +-- This update is now safe since we've removed the duplicate addresses. +UPDATE user_threepids SET address = LOWER(address) WHERE medium = 'email'; + + +/* Add an index for the select we do on passwored reset */ +CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql new file mode 100644 index 0000000000..f090a7b75a --- /dev/null +++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_postgres_gist', '{}'); diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql new file mode 100644 index 0000000000..00be801e90 --- /dev/null +++ b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/schema/delta/39/event_push_index.sql new file mode 100644 index 0000000000..de2ad93e5c --- /dev/null +++ b/synapse/storage/schema/delta/39/event_push_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_push_actions_highlights_index', '{}'); diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/schema/delta/39/federation_out_position.sql new file mode 100644 index 0000000000..5af814290b --- /dev/null +++ b/synapse/storage/schema/delta/39/federation_out_position.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + CREATE TABLE federation_stream_position( + type TEXT NOT NULL, + stream_id INTEGER NOT NULL + ); + + INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); + INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 59ac7f424c..8f2b3c4435 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -21,6 +21,7 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine import logging import re +import ujson as json logger = logging.getLogger(__name__) @@ -29,12 +30,22 @@ logger = logging.getLogger(__name__) class SearchStore(BackgroundUpdateStore): EVENT_SEARCH_UPDATE_NAME = "event_search" + EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" + EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" def __init__(self, hs): super(SearchStore, self).__init__(hs) self.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) + self.register_background_update_handler( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, + self._background_reindex_search_order + ) + self.register_background_update_handler( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME, + self._background_reindex_gist_search + ) @defer.inlineCallbacks def _background_reindex_search(self, progress, batch_size): @@ -47,7 +58,7 @@ class SearchStore(BackgroundUpdateStore): def reindex_search_txn(txn): sql = ( - "SELECT stream_ordering, event_id FROM events" + "SELECT stream_ordering, event_id, room_id, type, content FROM events" " WHERE ? <= stream_ordering AND stream_ordering < ?" " AND (%s)" " ORDER BY stream_ordering DESC" @@ -56,28 +67,30 @@ class SearchStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = txn.fetchall() + rows = self.cursor_to_dict(txn) if not rows: return 0 - min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - events = self._get_events_txn(txn, event_ids) + min_stream_id = rows[-1]["stream_ordering"] event_search_rows = [] - for event in events: + for row in rows: try: - event_id = event.event_id - room_id = event.room_id - content = event.content - if event.type == "m.room.message": + event_id = row["event_id"] + room_id = row["room_id"] + etype = row["type"] + try: + content = json.loads(row["content"]) + except: + continue + + if etype == "m.room.message": key = "content.body" value = content["body"] - elif event.type == "m.room.topic": + elif etype == "m.room.topic": key = "content.topic" value = content["topic"] - elif event.type == "m.room.name": + elif etype == "m.room.name": key = "content.name" value = content["name"] except (KeyError, AttributeError): @@ -132,6 +145,104 @@ class SearchStore(BackgroundUpdateStore): defer.returnValue(result) @defer.inlineCallbacks + def _background_reindex_gist_search(self, progress, batch_size): + def create_index(conn): + conn.rollback() + conn.set_session(autocommit=True) + c = conn.cursor() + + c.execute( + "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist" + " ON event_search USING GIST (vector)" + ) + + c.execute("DROP INDEX event_search_fts_idx") + + conn.set_session(autocommit=False) + + if isinstance(self.database_engine, PostgresEngine): + yield self.runWithConnection(create_index) + + yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + defer.returnValue(1) + + @defer.inlineCallbacks + def _background_reindex_search_order(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + have_added_index = progress['have_added_indexes'] + + if not have_added_index: + def create_index(conn): + conn.rollback() + conn.set_session(autocommit=True) + c = conn.cursor() + + # We create with NULLS FIRST so that when we search *backwards* + # we get the ones with non null origin_server_ts *first* + c.execute( + "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" + "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + ) + c.execute( + "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" + "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + ) + conn.set_session(autocommit=False) + + yield self.runWithConnection(create_index) + + pg = dict(progress) + pg["have_added_indexes"] = True + + yield self.runInteraction( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, + self._background_update_progress_txn, + self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, + ) + + def reindex_search_txn(txn): + sql = ( + "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," + " origin_server_ts = e.origin_server_ts" + " FROM events AS e" + " WHERE e.event_id = es.event_id" + " AND ? <= e.stream_ordering AND e.stream_ordering < ?" + " RETURNING es.stream_ordering" + ) + + min_stream_id = max_stream_id - batch_size + txn.execute(sql, (min_stream_id, max_stream_id)) + rows = txn.fetchall() + + if min_stream_id < target_min_stream_id: + # We've recached the end. + return len(rows), False + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + "have_added_indexes": True, + } + + self._background_update_progress_txn( + txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress + ) + + return len(rows), True + + num_rows, finished = yield self.runInteraction( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn + ) + + if not finished: + yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) + + defer.returnValue(num_rows) + + @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. @@ -310,7 +421,6 @@ class SearchStore(BackgroundUpdateStore): "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," " origin_server_ts, stream_ordering, room_id, event_id" " FROM event_search" - " NATURAL JOIN events" " WHERE vector @@ to_tsquery('english', ?) AND " ) args = [search_query, search_query] + args @@ -355,7 +465,15 @@ class SearchStore(BackgroundUpdateStore): # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. - sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" + if isinstance(self.database_engine, PostgresEngine): + sql += ( + " ORDER BY origin_server_ts DESC NULLS LAST," + " stream_ordering DESC NULLS LAST LIMIT ?" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" + else: + raise Exception("Unrecognized database engine") args.append(limit) diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index b10f2a5787..e1dca927d7 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -19,17 +19,24 @@ from ._base import SQLBaseStore from unpaddedbase64 import encode_base64 from synapse.crypto.event_signing import compute_event_reference_hash +from synapse.util.caches.descriptors import cached, cachedList class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" + @cached() + def get_event_reference_hash(self, event_id): + return self._get_event_reference_hashes_txn(event_id) + + @cachedList(cached_method_name="get_event_reference_hash", + list_name="event_ids", num_args=1) def get_event_reference_hashes(self, event_ids): def f(txn): - return [ - self._get_event_reference_hashes_txn(txn, ev) - for ev in event_ids - ] + return { + event_id: self._get_event_reference_hashes_txn(txn, event_id) + for event_id in event_ids + } return self.runInteraction( "get_event_reference_hashes", @@ -41,15 +48,15 @@ class SignatureStore(SQLBaseStore): hashes = yield self.get_event_reference_hashes( event_ids ) - hashes = [ - { + hashes = { + e_id: { k: encode_base64(v) for k, v in h.items() if k == "sha256" } - for h in hashes - ] + for e_id, h in hashes.items() + } - defer.returnValue(zip(event_ids, hashes)) + defer.returnValue(hashes.items()) def _get_event_reference_hashes_txn(self, txn, event_id): """Get all the hashes for a given PDU. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 8ed8a21b0a..23e7ad9922 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,9 +14,9 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import ( - cached, cachedInlineCallbacks, cachedList -) +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches import intern_string +from synapse.storage.engines import PostgresEngine from twisted.internet import defer @@ -25,6 +25,9 @@ import logging logger = logging.getLogger(__name__) +MAX_STATE_DELTA_HOPS = 100 + + class StateStore(SQLBaseStore): """ Keeps track of the state at a given event. @@ -44,12 +47,22 @@ class StateStore(SQLBaseStore): * `state_groups_state`: Maps state group to state events. """ - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - The return value is a dict mapping group names to lists of events. - """ + def __init__(self, hs): + super(StateStore, self).__init__(hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + + @defer.inlineCallbacks + def get_state_groups_ids(self, room_id, event_ids): if not event_ids: defer.returnValue({}) @@ -60,69 +73,165 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups) + defer.returnValue(group_to_state) + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + + The return value is a dict mapping group names to lists of events. + """ + if not event_ids: + defer.returnValue({}) + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.get_events( + [ + ev_id for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False + ) + defer.returnValue({ - group: state_map.values() - for group, state_map in group_to_state.items() + group: [ + state_event_map[v] for v in event_id_map.values() if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() }) - def _store_state_groups_txn(self, txn, event, context): - return self._store_mult_state_groups_txn(txn, [(event, context)]) + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups WHERE id = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: - if context.current_state is None: + if event.internal_metadata.is_outlier(): continue - if context.state_group is not None: - state_groups[event.event_id] = context.state_group + if context.current_state_ids is None: continue - state_events = dict(context.current_state) + state_groups[event.event_id] = context.state_group - if event.is_state(): - state_events[(event.type, event.state_key)] = event + if self._have_persisted_state_group_txn(txn, context.state_group): + continue - state_group = self._state_groups_id_gen.get_next() self._simple_insert_txn( txn, table="state_groups", values={ - "id": state_group, + "id": context.state_group, "room_id": event.room_id, "event_id": event.event_id, }, ) - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": state.room_id, - "type": state.type, - "state_key": state.state_key, - "event_id": state.event_id, - } - for state in state_events.values() - ], - ) - state_groups[event.event_id] = state_group + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if context.prev_group: + potential_hops = self._count_state_group_hops_txn( + txn, context.prev_group + ) + if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": context.state_group, + "prev_state_group": context.prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.delta_ids.items() + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.current_state_ids.items() + ], + ) self._simple_insert_many_txn( txn, table="event_to_state_groups", values=[ { - "state_group": state_groups[event.event_id], - "event_id": event.event_id, + "state_group": state_group_id, + "event_id": event_id, } - for event, context in events_and_contexts - if context.current_state is not None + for event_id, state_group_id in state_groups.items() ], ) + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): if event_type and state_key is not None: @@ -155,8 +264,14 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cachedInlineCallbacks(num_args=3) + @defer.inlineCallbacks def get_current_state_for_key(self, room_id, event_type, state_key): + event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) + events = yield self._get_events(event_ids, get_prev_content=False) + defer.returnValue(events) + + @cached(num_args=3) + def _get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = ( "SELECT event_id FROM current_state_events" @@ -167,14 +282,92 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) results = txn.fetchall() return [r[0] for r in results] - event_ids = yield self.runInteraction("get_current_state_for_key", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) + return self.runInteraction("get_current_state_for_key", f) + + @cached(num_args=2, max_entries=1000) + def _get_state_group_from_group(self, group, types): + raise NotImplementedError() + @cachedList(cached_method_name="_get_state_group_from_group", + list_name="groups", num_args=2, inlineCallbacks=True) def _get_state_groups_from_groups(self, groups, types): - """Returns dictionary state_group -> state event ids + """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ - def f(txn, groups): + results = {} + + chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] + for chunk in chunks: + res = yield self.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, chunk, types, + ) + results.update(res) + + defer.returnValue(results) + + def _get_state_groups_from_groups_txn(self, txn, groups, types=None): + results = {group: {} for group in groups} + if types is not None: + types = list(set(types)) # deduplicate types list + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + %s + """) + + # Turns out that postgres doesn't like doing a list of OR's and + # is about 1000x slower, so we just issue a query for each specific + # type seperately. + if types: + clause_to_args = [ + ( + "AND type = ? AND state_key = ?", + (etype, state_key) + ) + for etype, state_key in types + ] + else: + # If types is None we fetch all the state, and so just use an + # empty where clause with no extra args. + clause_to_args = [("", [])] + + for where_clause, where_args in clause_to_args: + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql % (where_clause,), args) + rows = self.cursor_to_dict(txn) + for row in rows: + key = (row["type"], row["state_key"]) + results[group][key] = row["event_id"] + else: if types is not None: where_clause = "AND (%s)" % ( " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), @@ -182,32 +375,47 @@ class StateStore(SQLBaseStore): else: where_clause = "" - sql = ( - "SELECT state_group, event_id FROM state_groups_state WHERE" - " state_group IN (%s) %s" % ( - ",".join("?" for _ in groups), - where_clause, - ) - ) - - args = list(groups) - if types is not None: - args.extend([i for typ in types for i in typ]) - - txn.execute(sql, args) - rows = self.cursor_to_dict(txn) - - results = {} - for row in rows: - results.setdefault(row["state_group"], []).append(row["event_id"]) - return results + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indicies (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + if types: + args.extend(i for typ in types for i in typ) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? %s" % (where_clause,), + args + ) + rows = txn.fetchall() + results[group].update({ + (typ, state_key): event_id + for typ, state_key, event_id in rows + if (typ, state_key) not in results[group] + }) + + # If the lengths match then we must have all the types, + # so no need to go walk further down the tree. + if types is not None and len(results[group]) == len(types): + break + + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) - chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] - for chunk in chunks: - return self.runInteraction( - "_get_state_groups_from_groups", - f, chunk - ) + return results @defer.inlineCallbacks def get_state_for_events(self, event_ids, types): @@ -232,6 +440,31 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups, types) + state_event_map = yield self.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, types): + event_to_groups = yield self._get_state_group_for_events( + event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + event_to_state = { event_id: group_to_state[group] for event_id, group in event_to_groups.items() @@ -244,16 +477,36 @@ class StateStore(SQLBaseStore): """ Get the state dict corresponding to a particular event - :param str event_id: event whose state should be returned - :param list[(str, str)]|None types: List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key - :return: a deferred dict from (type, state_key) -> state_event + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) - @cached(num_args=2, lru=True, max_entries=10000) + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, types=None): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], types) + defer.returnValue(state_map[event_id]) + + @cached(num_args=2, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", @@ -265,8 +518,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ @@ -296,7 +549,7 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict = self._state_group_cache.get(group) + is_all, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() @@ -308,7 +561,7 @@ class StateStore(SQLBaseStore): if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict: + if (typ, state_key) not in state_dict_ids: missing_types.add((typ, state_key)) sentinel = object() @@ -326,7 +579,7 @@ class StateStore(SQLBaseStore): got_all = not (missing_types or types is None) return { - k: v for k, v in state_dict.items() + k: v for k, v in state_dict_ids.items() if include(k[0], k[1]) }, missing_types, got_all @@ -340,8 +593,9 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict = self._state_group_cache.get(group) - return state_dict, is_all + is_all, state_dict_ids = self._state_group_cache.get(group) + + return state_dict_ids, is_all @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): @@ -350,88 +604,256 @@ class StateStore(SQLBaseStore): a `state_key` of None matches all state_keys. If `types` is None then all events are returned. """ + if types: + types = frozenset(types) results = {} missing_groups = [] if types is not None: for group in set(groups): - state_dict, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( group, types ) - results[group] = state_dict + results[group] = state_dict_ids if not got_all: missing_groups.append(group) else: for group in set(groups): - state_dict, got_all = self._get_all_state_from_cache( + state_dict_ids, got_all = self._get_all_state_from_cache( group ) - results[group] = state_dict + + results[group] = state_dict_ids if not got_all: missing_groups.append(group) - if not missing_groups: - defer.returnValue({ - group: { - type_tuple: event - for type_tuple, event in state.items() - if event - } - for group, state in results.items() - }) + if missing_groups: + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + group_to_state_dict = yield self._get_state_groups_from_groups( + missing_groups, types + ) - group_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types - ) + # Now we want to update the cache with all the things we fetched + # from the database. + for group, group_state_dict in group_to_state_dict.items(): + if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. + state_dict = { + (intern_string(etype), intern_string(state_key)): None + for (etype, state_key) in types + } + state_dict.update(results[group]) + results[group] = state_dict + else: + state_dict = results[group] + + state_dict.update({ + (intern_string(k[0]), intern_string(k[1])): v + for k, v in group_state_dict.items() + }) + + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) - state_events = yield self._get_events( - [e_id for l in group_state_dict.values() for e_id in l], - get_prev_content=False + # Remove all the entries with None values. The None values were just + # used for bookkeeping in the cache. + for group, state_dict in results.items(): + results[group] = { + key: event_id + for key, event_id in state_dict.items() + if event_id + } + + defer.returnValue(results) + + def get_next_state_group(self): + return self._state_groups_id_gen.get_next() + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self._execute( + "_background_deduplicate_state", None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in xrange(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group,) + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,) + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id,) + ) + prev_group, = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn( + txn, prev_group + ) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group], types=None + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group], types=None + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value for key, value in curr_state.items() + if prev_state.get(key, None) != value + } + + self._simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + } + ) + + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + } + ) + + self._simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + } + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_state.items() + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) - state_events = {e.event_id: e for e in state_events} + if finished: + yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME) - # Now we want to update the cache with all the things we fetched - # from the database. - for group, state_ids in group_state_dict.items(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = {key: None for key in types} - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] + defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) - for event_id in state_ids: + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) try: - state_event = state_events[event_id] - state_dict[(state_event.type, state_event.state_key)] = state_event - except KeyError: - # Hmm. So we do don't have that state event? Interesting. - logger.warn( - "Can't find state event %r for state group %r", - event_id, group, + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute( + "DROP INDEX IF EXISTS state_groups_state_id" ) + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute( + "DROP INDEX IF EXISTS state_groups_state_id" + ) - self._state_group_cache.update( - cache_seq_num, - key=group, - value=state_dict, - full=(types is None), - ) + yield self.runWithConnection(reindex_txn) - # Remove all the entries with None values. The None values were just - # used for bookkeeping in the cache. - for group, state_dict in results.items(): - results[group] = { - key: event for key, event in state_dict.items() if event - } + yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - defer.returnValue(results) + defer.returnValue(1) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 7f4a827528..7fa63b58a7 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -36,10 +36,11 @@ what sort order was used: from twisted.internet import defer from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.storage.engines import PostgresEngine, Sqlite3Engine import logging @@ -54,26 +55,92 @@ _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" -def lower_bound(token): +def lower_bound(token, engine, inclusive=False): + inclusive = "=" if inclusive else "" if token.topological is None: - return "(%d < %s)" % (token.stream, "stream_ordering") + return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") else: - return "(%d < %s OR (%d = %s AND %d < %s))" % ( + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well + # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we + # use the later form when running against postgres. + return "((%d,%d) <%s (%s,%s))" % ( + token.topological, token.stream, inclusive, + "topological_ordering", "stream_ordering", + ) + return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( token.topological, "topological_ordering", token.topological, "topological_ordering", - token.stream, "stream_ordering", + token.stream, inclusive, "stream_ordering", ) -def upper_bound(token): +def upper_bound(token, engine, inclusive=True): + inclusive = "=" if inclusive else "" if token.topological is None: - return "(%d >= %s)" % (token.stream, "stream_ordering") + return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") else: - return "(%d > %s OR (%d = %s AND %d >= %s))" % ( + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well + # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we + # use the later form when running against postgres. + return "((%d,%d) >%s (%s,%s))" % ( + token.topological, token.stream, inclusive, + "topological_ordering", "stream_ordering", + ) + return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( token.topological, "topological_ordering", token.topological, "topological_ordering", - token.stream, "stream_ordering", + token.stream, inclusive, "stream_ordering", + ) + + +def filter_to_clause(event_filter): + # NB: This may create SQL clauses that don't optimise well (and we don't + # have indices on all possible clauses). E.g. it may create + # "room_id == X AND room_id != X", which postgres doesn't optimise. + + if not event_filter: + return "", [] + + clauses = [] + args = [] + + if event_filter.types: + clauses.append( + "(%s)" % " OR ".join("type = ?" for _ in event_filter.types) + ) + args.extend(event_filter.types) + + for typ in event_filter.not_types: + clauses.append("type != ?") + args.append(typ) + + if event_filter.senders: + clauses.append( + "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders) + ) + args.extend(event_filter.senders) + + for sender in event_filter.not_senders: + clauses.append("sender != ?") + args.append(sender) + + if event_filter.rooms: + clauses.append( + "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms) ) + args.extend(event_filter.rooms) + + for room_id in event_filter.not_rooms: + clauses.append("room_id != ?") + args.append(room_id) + + if event_filter.contains_url: + clauses.append("contains_url = ?") + args.append(event_filter.contains_url) + + return " AND ".join(clauses), args class StreamStore(SQLBaseStore): @@ -132,29 +199,25 @@ class StreamStore(SQLBaseStore): return True return False - ret = self._get_events_txn( - txn, - # apply the filter on the room id list - [ - r["event_id"] for r in rows - if app_service_interested(r) - ], - get_prev_content=True - ) + return [r for r in rows if app_service_interested(r)] - self._set_before_and_after(ret, rows) + rows = yield self.runInteraction("get_appservice_room_stream", f) - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key + ret = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) - return ret, key + self._set_before_and_after(ret, rows, topo_order=from_id is None) - results = yield self.runInteraction("get_appservice_room_stream", f) - defer.returnValue(results) + if rows: + key = "s%d" % max(r["stream_ordering"] for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = to_key + + defer.returnValue((ret, key)) @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, @@ -171,12 +234,12 @@ class StreamStore(SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): - res = yield defer.gatherResults([ + res = yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(self.get_room_events_stream_for_room)( room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids - ]) + ])) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) @@ -303,117 +366,37 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) - def get_room_events_stream( - self, - user_id, - from_key, - to_key, - limit=0, - is_guest=False, - room_ids=None - ): - room_ids = room_ids or [] - room_ids = [r for r in room_ids] - if is_guest: - current_room_membership_sql = ( - "SELECT c.room_id FROM history_visibility AS h" - " INNER JOIN current_state_events AS c" - " ON h.event_id = c.event_id" - " WHERE c.room_id IN (%s)" - " AND h.history_visibility = 'world_readable'" % ( - ",".join(map(lambda _: "?", room_ids)) - ) - ) - current_room_membership_args = room_ids - else: - current_room_membership_sql = ( - "SELECT m.room_id FROM room_memberships as m " - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id AND c.state_key = m.user_id" - " WHERE m.user_id = ? AND m.membership = 'join'" - ) - current_room_membership_args = [user_id] - - # We also want to get any membership events about that user, e.g. - # invites or leave notifications. - membership_sql = ( - "SELECT m.event_id FROM room_memberships as m " - "INNER JOIN current_state_events as c ON m.event_id = c.event_id " - "WHERE m.user_id = ? " - ) - membership_args = [user_id] - - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - return defer.succeed(([], to_key)) - - sql = ( - "SELECT e.event_id, e.stream_ordering FROM events AS e WHERE " - "(e.outlier = ? AND (room_id IN (%(current)s)) OR " - "(event_id IN (%(invites)s))) " - "AND e.stream_ordering > ? AND e.stream_ordering <= ? " - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "current": current_room_membership_sql, - "invites": membership_sql, - "limit": limit - } - - def f(txn): - args = ([False] + current_room_membership_args + membership_args + - [from_id.stream, to_id.stream]) - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - ret = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(ret, rows) - - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key - - return ret, key - - return self.runInteraction("get_room_events_stream", f) - @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, - direction='b', limit=-1): + direction='b', limit=-1, event_filter=None): # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] if direction == 'b': order = "DESC" - bounds = upper_bound(RoomStreamToken.parse(from_key)) + bounds = upper_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) if to_key: - bounds = "%s AND %s" % ( - bounds, lower_bound(RoomStreamToken.parse(to_key)) - ) + bounds = "%s AND %s" % (bounds, lower_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) else: order = "ASC" - bounds = lower_bound(RoomStreamToken.parse(from_key)) + bounds = lower_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) if to_key: - bounds = "%s AND %s" % ( - bounds, upper_bound(RoomStreamToken.parse(to_key)) - ) + bounds = "%s AND %s" % (bounds, upper_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) if int(limit) > 0: args.append(int(limit)) @@ -465,9 +448,25 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) - @cachedInlineCallbacks(num_args=4) + @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): + rows, token = yield self.get_recent_event_ids_for_room( + room_id, limit, end_token, from_token + ) + + logger.debug("stream before") + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) + logger.debug("stream after") + + self._set_before_and_after(events, rows) + defer.returnValue((events, token)) + + @cached(num_args=4) + def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None): end_token = RoomStreamToken.parse_stream_token(end_token) if from_token is None: @@ -517,32 +516,31 @@ class StreamStore(SQLBaseStore): return rows, token - rows, token = yield self.runInteraction( + return self.runInteraction( "get_recent_events_for_room", get_recent_events_for_room_txn ) - logger.debug("stream before") - events = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True - ) - logger.debug("stream after") - - self._set_before_and_after(events, rows) - - defer.returnValue((events, token)) - @defer.inlineCallbacks - def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token() - if direction != 'b': + def get_room_events_max_id(self, room_id=None): + """Returns the current token for rooms stream. + + By default, it returns the current global stream token. Specifying a + `room_id` causes it to return the current room specific topological + token. + """ + token = yield self._stream_id_gen.get_current_token() + if room_id is None: defer.returnValue("s%d" % (token,)) else: topo = yield self.runInteraction( - "_get_max_topological_txn", self._get_max_topological_txn + "_get_max_topological_txn", self._get_max_topological_txn, + room_id, ) defer.returnValue("t%d-%d" % (topo, token)) + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + def get_stream_token_for_event(self, event_id): """The stream token for an event Args: @@ -576,23 +574,23 @@ class StreamStore(SQLBaseStore): row["topological_ordering"], row["stream_ordering"],) ) - def get_max_topological_token_for_stream_and_room(self, room_id, stream_key): + def get_max_topological_token(self, room_id, stream_key): sql = ( "SELECT max(topological_ordering) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) return self._execute( - "get_max_topological_token_for_stream_and_room", None, + "get_max_topological_token", None, sql, room_id, stream_key, ).addCallback( lambda r: r[0][0] if r else 0 ) - def _get_max_topological_txn(self, txn): + def _get_max_topological_txn(self, txn, room_id): txn.execute( "SELECT MAX(topological_ordering) FROM events" - " WHERE outlier = ?", - (False,) + " WHERE room_id = ?", + (room_id,) ) rows = txn.fetchall() @@ -675,32 +673,60 @@ class StreamStore(SQLBaseStore): retcols=["stream_ordering", "topological_ordering"], ) - stream_ordering = results["stream_ordering"] - topological_ordering = results["topological_ordering"] - - query_before = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND (topological_ordering < ?" - " OR (topological_ordering = ? AND stream_ordering < ?))" - " ORDER BY topological_ordering DESC, stream_ordering DESC" - " LIMIT ?" + token = RoomStreamToken( + results["topological_ordering"], + results["stream_ordering"], ) - query_after = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND (topological_ordering > ?" - " OR (topological_ordering = ? AND stream_ordering > ?))" - " ORDER BY topological_ordering ASC, stream_ordering ASC" - " LIMIT ?" - ) + if isinstance(self.database_engine, Sqlite3Engine): + # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)`` + # So we give pass it to SQLite3 as the UNION ALL of the two queries. + + query_before = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND topological_ordering < ?" + " UNION ALL" + " SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?" + " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" + ) + before_args = ( + room_id, token.topological, + room_id, token.topological, token.stream, + before_limit, + ) - txn.execute( - query_before, - ( - room_id, topological_ordering, topological_ordering, - stream_ordering, before_limit, + query_after = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND topological_ordering > ?" + " UNION ALL" + " SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?" + " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?" ) - ) + after_args = ( + room_id, token.topological, + room_id, token.topological, token.stream, + after_limit, + ) + else: + query_before = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND %s" + " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" + ) % (upper_bound(token, self.database_engine, inclusive=False),) + + before_args = (room_id, before_limit) + + query_after = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND %s" + " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + after_args = (room_id, after_limit) + + txn.execute(query_before, before_args) rows = self.cursor_to_dict(txn) events_before = [r["event_id"] for r in rows] @@ -712,17 +738,11 @@ class StreamStore(SQLBaseStore): )) else: start_token = str(RoomStreamToken( - topological_ordering, - stream_ordering - 1, + token.topological, + token.stream - 1, )) - txn.execute( - query_after, - ( - room_id, topological_ordering, topological_ordering, - stream_ordering, after_limit, - ) - ) + txn.execute(query_after, after_args) rows = self.cursor_to_dict(txn) events_after = [r["event_id"] for r in rows] @@ -733,10 +753,7 @@ class StreamStore(SQLBaseStore): rows[-1]["stream_ordering"], )) else: - end_token = str(RoomStreamToken( - topological_ordering, - stream_ordering, - )) + end_token = str(token) return { "before": { @@ -748,3 +765,50 @@ class StreamStore(SQLBaseStore): "token": end_token, }, } + + @defer.inlineCallbacks + def get_all_new_events_stream(self, from_id, current_id, limit): + """Get all new events""" + + def get_all_new_events_stream_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (from_id, current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.runInteraction( + "get_all_new_events_stream", get_all_new_events_stream_txn, + ) + + events = yield self._get_events(event_ids) + + defer.returnValue((upper_bound, events)) + + def get_federation_out_pos(self, typ): + return self._simple_select_one_onecol( + table="federation_stream_position", + retcol="stream_id", + keyvalues={"type": typ}, + desc="get_federation_out_pos" + ) + + def update_federation_out_pos(self, typ, stream_id): + return self._simple_update_one( + table="federation_stream_position", + keyvalues={"type": typ}, + updatevalues={"stream_id": stream_id}, + desc="update_federation_out_pos", + ) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index a0e6b42b30..5a2c1aa59b 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token() + return self._account_data_id_gen.get_current_token() @cached() def get_tags_for_user(self, user_id): @@ -68,6 +68,9 @@ class TagsStore(SQLBaseStore): A deferred list of tuples of stream_id int, user_id string, room_id string, tag string and content string. """ + if last_id == current_id: + defer.returnValue([]) + def get_all_updated_tags_txn(txn): sql = ( "SELECT stream_id, user_id, room_id" @@ -200,7 +203,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -222,7 +225,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index d338dfcf0a..809fdd311f 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -16,16 +16,41 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached +from twisted.internet import defer + from canonicaljson import encode_canonical_json + +from collections import namedtuple + import logging +import ujson as json logger = logging.getLogger(__name__) +_TransactionRow = namedtuple( + "_TransactionRow", ( + "id", "transaction_id", "destination", "ts", "response_code", + "response_json", + ) +) + +_UpdateTransactionRow = namedtuple( + "_TransactionRow", ( + "response_code", "response_json", + ) +) + + class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ + def __init__(self, hs): + super(TransactionStore, self).__init__(hs) + + self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -61,7 +86,7 @@ class TransactionStore(SQLBaseStore): ) if result and result["response_code"]: - return result["response_code"], result["response_json"] + return result["response_code"], json.loads(str(result["response_json"])) else: return None @@ -86,6 +111,7 @@ class TransactionStore(SQLBaseStore): "origin": origin, "response_code": code, "response_json": buffer(encode_canonical_json(response_dict)), + "ts": self._clock.time_msec(), }, or_ignore=True, desc="set_received_txn_response", @@ -107,50 +133,7 @@ class TransactionStore(SQLBaseStore): Returns: list: A list of previous transaction ids. """ - - return self.runInteraction( - "prep_send_transaction", - self._prep_send_transaction, - transaction_id, destination, origin_server_ts - ) - - def _prep_send_transaction(self, txn, transaction_id, destination, - origin_server_ts): - - next_id = self._transaction_id_gen.get_next() - - # First we find out what the prev_txns should be. - # Since we know that we are only sending one transaction at a time, - # we can simply take the last one. - query = ( - "SELECT * FROM sent_transactions" - " WHERE destination = ?" - " ORDER BY id DESC LIMIT 1" - ) - - txn.execute(query, (destination,)) - results = self.cursor_to_dict(txn) - - prev_txns = [r["transaction_id"] for r in results] - - # Actually add the new transaction to the sent_transactions table. - - self._simple_insert_txn( - txn, - table="sent_transactions", - values={ - "id": next_id, - "transaction_id": transaction_id, - "destination": destination, - "ts": origin_server_ts, - "response_code": 0, - "response_json": None, - } - ) - - # TODO Update the tx id -> pdu id mapping - - return prev_txns + return defer.succeed([]) def delivered_txn(self, transaction_id, destination, code, response_dict): """Persists the response for an outgoing transaction. @@ -161,58 +144,9 @@ class TransactionStore(SQLBaseStore): code (int) response_json (str) """ - return self.runInteraction( - "delivered_txn", - self._delivered_txn, - transaction_id, destination, code, - buffer(encode_canonical_json(response_dict)), - ) - - def _delivered_txn(self, txn, transaction_id, destination, - code, response_json): - self._simple_update_one_txn( - txn, - table="sent_transactions", - keyvalues={ - "transaction_id": transaction_id, - "destination": destination, - }, - updatevalues={ - "response_code": code, - "response_json": None, # For now, don't persist response_json - } - ) - - def get_transactions_after(self, transaction_id, destination): - """Get all transactions after a given local transaction_id. - - Args: - transaction_id (str) - destination (str) - - Returns: - list: A list of dicts - """ - return self.runInteraction( - "get_transactions_after", - self._get_transactions_after, transaction_id, destination - ) + pass - def _get_transactions_after(self, txn, transaction_id, destination): - query = ( - "SELECT * FROM sent_transactions" - " WHERE destination = ? AND id >" - " (" - " SELECT id FROM sent_transactions" - " WHERE transaction_id = ? AND destination = ?" - " )" - ) - - txn.execute(query, (destination, transaction_id, destination)) - - return self.cursor_to_dict(txn) - - @cached() + @cached(max_entries=10000) def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -266,25 +200,48 @@ class TransactionStore(SQLBaseStore): def _set_destination_retry_timings(self, txn, destination, retry_last_ts, retry_interval): - txn.call_after(self.get_destination_retry_timings.invalidate, (destination,)) + self.database_engine.lock_table(txn, "destinations") - self._simple_upsert_txn( + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + + # We need to be careful here as the data may have changed from under us + # due to a worker setting the timings. + + prev_row = self._simple_select_one_txn( txn, - "destinations", + table="destinations", keyvalues={ "destination": destination, }, - values={ - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - insertion_values={ - "destination": destination, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - } + retcols=("retry_last_ts", "retry_interval"), + allow_none=True, ) + if not prev_row: + self._simple_insert_txn( + txn, + table="destinations", + values={ + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + } + ) + elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + self._simple_update_one_txn( + txn, + "destinations", + keyvalues={ + "destination": destination, + }, + updatevalues={ + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + ) + def get_destinations_needing_retry(self): """Get all destinations which are due a retry for sending a transaction. @@ -305,3 +262,12 @@ class TransactionStore(SQLBaseStore): txn.execute(query, (self._clock.time_msec(),)) return self.cursor_to_dict(txn) + + def _cleanup_transactions(self): + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a02dfc7d58..46cf93ff87 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,7 +21,7 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = _load_max_id(db_conn, table, column) + self._next_id = _load_current_id(db_conn, table, column) def get_next(self): with self._lock: @@ -29,12 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_max_id(db_conn, table, column): +def _load_current_id(db_conn, table, column, step=1): cur = db_conn.cursor() - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + if step == 1: + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + else: + cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return int(val) if val else 1 + current_id = int(val) if val else step + return (max if step > 0 else min)(current_id, step) class StreamIdGenerator(object): @@ -45,17 +49,32 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. + Args: + db_conn(connection): A database connection to use to fetch the + initial value of the generator from. + table(str): A database table to read the initial value of the id + generator from. + column(str): The column of the database table to read the initial + value from the id generator from. + extra_tables(list): List of pairs of database tables and columns to + use to source the initial value of the generator from. The value + with the largest magnitude is used. + step(int): which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. + Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[]): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + assert step != 0 self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._step = step + self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self._current_max = max( - self._current_max, - _load_max_id(db_conn, table, column) + self._current = (max if step > 0 else min)( + self._current, + _load_current_id(db_conn, table, column, step) ) self._unfinished_ids = deque() @@ -66,8 +85,8 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current_max += 1 - next_id = self._current_max + self._current += self._step + next_id = self._current self._unfinished_ids.append(next_id) @@ -88,8 +107,12 @@ class StreamIdGenerator(object): # ... persist events ... """ with self._lock: - next_ids = range(self._current_max + 1, self._current_max + n + 1) - self._current_max += n + next_ids = range( + self._current + self._step, + self._current + self._step * (n + 1), + self._step + ) + self._current += n * self._step for next_id in next_ids: self._unfinished_ids.append(next_id) @@ -105,15 +128,15 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - 1 + return self._unfinished_ids[0] - self._step - return self._current_max + return self._current class ChainedIdGenerator(object): @@ -125,7 +148,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() def get_next(self): @@ -137,7 +160,7 @@ class ChainedIdGenerator(object): with self._lock: self._current_max += 1 next_id = self._current_max - chained_id = self.chained_generator.get_max_token() + chained_id = self.chained_generator.get_current_token() self._unfinished_ids.append((next_id, chained_id)) @@ -151,7 +174,7 @@ class ChainedIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -160,4 +183,4 @@ class ChainedIdGenerator(object): stream_id, chained_id = self._unfinished_ids[0] return (stream_id - 1, chained_id) - return (self._current_max, self.chained_generator.get_max_token()) + return (self._current_max, self.chained_generator.get_current_token()) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index d4c0bb6732..4d44c3d4ca 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -41,12 +41,13 @@ class EventSources(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_current_token(self, direction='f'): + def get_current_token(self): push_rules_key, _ = self.store.get_push_rules_stream_token() + to_device_key = self.store.get_to_device_stream_token() token = StreamToken( room_key=( - yield self.sources["room"].get_current_key(direction) + yield self.sources["room"].get_current_key() ), presence_key=( yield self.sources["presence"].get_current_key() @@ -61,5 +62,32 @@ class EventSources(object): yield self.sources["account_data"].get_current_key() ), push_rules_key=push_rules_key, + to_device_key=to_device_key, + ) + defer.returnValue(token) + + @defer.inlineCallbacks + def get_current_token_for_room(self, room_id): + push_rules_key, _ = self.store.get_push_rules_stream_token() + to_device_key = self.store.get_to_device_stream_token() + + token = StreamToken( + room_key=( + yield self.sources["room"].get_current_key_for_room(room_id) + ), + presence_key=( + yield self.sources["presence"].get_current_key() + ), + typing_key=( + yield self.sources["typing"].get_current_key() + ), + receipt_key=( + yield self.sources["receipt"].get_current_key() + ), + account_data_key=( + yield self.sources["account_data"].get_current_key() + ), + push_rules_key=push_rules_key, + to_device_key=to_device_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 5b166835bd..ffab12df09 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,48 @@ from synapse.api.errors import SynapseError from collections import namedtuple -Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) +Requester = namedtuple("Requester", [ + "user", "access_token_id", "is_guest", "device_id", "app_service", +]) +""" +Represents the user making a request + +Attributes: + user (UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time + app_service (ApplicationService|None): the AS requesting on behalf of the user +""" + + +def create_requester(user_id, access_token_id=None, is_guest=False, + device_id=None, app_service=None): + """ + Create a new ``Requester`` object + + Args: + user_id (str|UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time + app_service (ApplicationService|None): the AS requesting on behalf of the user + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + return Requester(user_id, access_token_id, is_guest, device_id, app_service) + + +def get_domain_from_id(string): + try: + return string.split(":", 1)[1] + except IndexError: + raise SynapseError(400, "Invalid ID: %r" % (string,)) class DomainSpecificString( @@ -116,6 +157,7 @@ class StreamToken( "receipt_key", "account_data_key", "push_rules_key", + "to_device_key", )) ): _SEPARATOR = "_" @@ -152,6 +194,7 @@ class StreamToken( or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.push_rules_key) < int(self.push_rules_key)) + or (int(other.to_device_key) < int(self.to_device_key)) ) def copy_and_advance(self, key, new_value): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 3b9da5b34a..c05b9450be 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.errors import SynapseError from synapse.util.logcontext import PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -33,7 +34,7 @@ class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. - TODO(paul): Also move the sleep() functionallity into it + TODO(paul): Also move the sleep() functionality into it """ def time(self): @@ -45,13 +46,18 @@ class Clock(object): return int(self.time() * 1000) def looping_call(self, f, msec): + """Call a function repeatedly. + + Waits `msec` initially before calling `f` for the first time. + + Args: + f(function): The function to call repeatedly. + msec(float): How long to wait between calls in milliseconds. + """ l = task.LoopingCall(f) l.start(msec / 1000.0, now=False) return l - def stop_looping_call(self, loop): - loop.stop() - def call_later(self, delay, callback, *args, **kwargs): """Call something later @@ -83,7 +89,7 @@ class Clock(object): def timed_out_fn(): try: - ret_deferred.errback(RuntimeError("Timed out")) + ret_deferred.errback(SynapseError(504, "Timed out")) except: pass diff --git a/synapse/util/async.py b/synapse/util/async.py index 640fae3890..347fb1e380 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,7 +16,12 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext +from .logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, +) +from synapse.util import unwrapFirstError + +from contextlib import contextmanager @defer.inlineCallbacks @@ -97,6 +102,15 @@ class ObservableDeferred(object): def observers(self): return self._observers + def has_called(self): + return self._result is not None + + def has_succeeded(self): + return self._result is not None and self._result[0] is True + + def get_result(self): + return self._result[1] + def __getattr__(self, name): return getattr(self._deferred, name) @@ -107,3 +121,159 @@ class ObservableDeferred(object): return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( id(self), self._result, self._deferred, ) + + +def concurrently_execute(func, args, limit): + """Executes the function with each argument conncurrently while limiting + the number of concurrent executions. + + Args: + func (func): Function to execute, should return a deferred. + args (list): List of arguments to pass to func, each invocation of func + gets a signle argument. + limit (int): Maximum number of conccurent executions. + + Returns: + deferred: Resolved when all function invocations have finished. + """ + it = iter(args) + + @defer.inlineCallbacks + def _concurrently_execute_inner(): + try: + while True: + yield func(it.next()) + except StopIteration: + pass + + return preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(_concurrently_execute_inner)() + for _ in xrange(limit) + ], consumeErrors=True)).addErrback(unwrapFirstError) + + +class Linearizer(object): + """Linearizes access to resources based on a key. Useful to ensure only one + thing is happening at a time on a given resource. + + Example: + + with (yield linearizer.queue("test_key")): + # do some work. + + """ + def __init__(self): + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + # If there is already a deferred in the queue, we pull it out so that + # we can wait on it later. + # Then we replace it with a deferred that we resolve *after* the + # context manager has exited. + # We only return the context manager after the previous deferred has + # resolved. + # This all has the net effect of creating a chain of deferreds that + # wait for the previous deferred before starting their work. + current_defer = self.key_to_defer.get(key) + + new_defer = defer.Deferred() + self.key_to_defer[key] = new_defer + + if current_defer: + with PreserveLoggingContext(): + yield current_defer + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + current_d = self.key_to_defer.get(key) + if current_d is new_defer: + self.key_to_defer.pop(key, None) + + defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. + + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield curr_writer + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 1a14904194..ebd715c5dc 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -14,14 +14,81 @@ # limitations under the License. import synapse.metrics +from lrucache import LruCache +import os + +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) DEBUG_CACHES = False metrics = synapse.metrics.get_metrics_for("synapse.util.caches") caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) +# cache_counter = metrics.register_cache( +# "cache", +# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, +# labels=["name"], +# ) + + +def register_cache(name, cache): + caches_by_name[name] = cache + return metrics.register_cache( + "cache", + lambda: len(cache), + name, + ) + + +_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR)) +caches_by_name["string_cache"] = _string_cache + + +KNOWN_KEYS = { + key: key for key in + ( + "auth_events", + "content", + "depth", + "event_id", + "hashes", + "origin", + "origin_server_ts", + "prev_events", + "room_id", + "sender", + "signatures", + "state_key", + "type", + "unsigned", + "user_id", + ) +} + + +def intern_string(string): + """Takes a (potentially) unicode string and interns using custom cache + """ + return _string_cache.setdefault(string, string) + + +def intern_dict(dictionary): + """Takes a dictionary and interns well known keys and their values + """ + return { + KNOWN_KEYS.get(key, key): _intern_known_values(key, value) + for key, value in dictionary.items() + } + + +def _intern_known_values(key, value): + intern_str_keys = ("event_id", "room_id") + intern_unicode_keys = ("sender", "user_id", "type", "state_key") + + if key in intern_str_keys: + return intern(value.encode('ascii')) + + if key in intern_unicode_keys: + return intern_string(value) + + return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19fd..8dba61d49f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -22,17 +22,17 @@ from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) -from . import caches_by_name, DEBUG_CACHES, cache_counter +from . import DEBUG_CACHES, register_cache from twisted.internet import defer - -from collections import OrderedDict +from collections import namedtuple import os import functools import inspect import threading + logger = logging.getLogger(__name__) @@ -43,23 +43,27 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) class Cache(object): + __slots__ = ( + "cache", + "max_entries", + "name", + "keylen", + "sequence", + "thread", + "metrics", + ) - def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): - if lru: - cache_type = TreeCache if tree else dict - self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type - ) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries + def __init__(self, name, max_entries=1000, keylen=1, tree=False): + cache_type = TreeCache if tree else dict + self.cache = LruCache( + max_size=max_entries, keylen=keylen, cache_type=cache_type + ) self.name = name self.keylen = keylen self.sequence = 0 self.thread = None - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -71,32 +75,28 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) + def get(self, key, default=_CacheSentinel, callback=None): + val = self.cache.get(key, _CacheSentinel, callback=callback) if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return val - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() if default is _CacheSentinel: raise KeyError() else: return default - def update(self, sequence, key, value): + def update(self, sequence, key, value, callback=None): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) - - def prefill(self, key, value): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) + self.prefill(key, value, callback=callback) - self.cache[key] = value + def prefill(self, key, value, callback=None): + self.cache.set(key, value, callback=callback) def invalidate(self, key): self.check_thread() @@ -141,9 +141,21 @@ class CacheDescriptor(object): The wrapped function has another additional callable, called "prefill", which can be used to insert values into the cache specifically, without calling the calculation function. + + Cached functions can be "chained" (i.e. a cached function can call other cached + functions and get appropriately invalidated when they called caches are + invalidated) by adding a special "cache_context" argument to the function + and passing that as a kwarg to all caches called. For example:: + + @cachedInlineCallbacks(cache_context=True) + def foo(self, key, cache_context): + r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) + defer.returnValue(r1 + r2) + """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, - inlineCallbacks=False): + def __init__(self, orig, max_entries=1000, num_args=1, tree=False, + inlineCallbacks=False, cache_context=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -155,34 +167,64 @@ class CacheDescriptor(object): self.max_entries = max_entries self.num_args = num_args - self.lru = lru self.tree = tree - self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] + all_args = inspect.getargspec(orig) + self.arg_names = all_args.args[1:num_args + 1] + + if "cache_context" in all_args.args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + try: + self.arg_names.remove("cache_context") + except ValueError: + pass + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) + + self.add_cache_context = cache_context if len(self.arg_names) < self.num_args: raise Exception( "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" + " (@cached cannot key off of *args or **kwargs)" % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, - lru=self.lru, tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + + # Add temp cache_context so inspect.getcallargs doesn't explode + if self.add_cache_context: + kwargs["cache_context"] = None + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext(cache, cache_key) + try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key, callback=invalidate_callback) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +246,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +255,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +283,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +307,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,34 +321,45 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] # cached is a dict arg -> deferred, where deferred results in a # 2-tuple (`arg`, `result`) - cached = {} + results = {} + cached_defers = {} missing = [] for arg in list_args: key = list(keyargs) key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + res = cache.get(tuple(key), callback=invalidate_callback) + if not res.has_succeeded(): + res = res.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached_defers[arg] = res + else: + results[arg] = res.get_result() except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,50 +382,67 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update( + sequence, tuple(key), observer, + callback=invalidate_callback + ) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + cached_defers[arg] = res + + if cached_defers: + def update_results_dict(res): + results.update(res) + return results - return preserve_context_over_deferred(defer.gatherResults( - cached.values(), - consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))) + return preserve_context_over_deferred(defer.gatherResults( + cached_defers.values(), + consumeErrors=True, + ).addCallback(update_results_dict).addErrback( + unwrapFirstError + )) + else: + return results obj.__dict__[self.orig.__name__] = wrapped return wrapped -def cached(max_entries=1000, num_args=1, lru=True, tree=False): +class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): + def invalidate(self): + self.cache.invalidate(self.key) + + +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, + cache_context=cache_context, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, inlineCallbacks=True, + cache_context=cache_context, ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +472,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index f92d80542b..b0ca1bb79d 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -15,7 +15,7 @@ from synapse.util.caches.lrucache import LruCache from collections import namedtuple -from . import caches_by_name, cache_counter +from . import register_cache import threading import logging @@ -43,7 +43,7 @@ class DictionaryCache(object): __slots__ = [] self.sentinel = Sentinel() - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -58,7 +58,7 @@ class DictionaryCache(object): def get(self, key, dict_keys=None): entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() if dict_keys is None: return DictionaryEntry(entry.full, dict(entry.value)) @@ -69,7 +69,7 @@ class DictionaryCache(object): if k in entry.value }) - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return DictionaryEntry(False, {}) def invalidate(self, key): diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 2b68c1ac93..080388958f 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache import logging @@ -49,7 +49,7 @@ class ExpiringCache(object): self._cache = {} - caches_by_name[cache_name] = self._cache + self.metrics = register_cache(cache_name, self._cache) def start(self): if not self._expiry_ms: @@ -78,9 +78,9 @@ class ExpiringCache(object): def __getitem__(self, key): try: entry = self._cache[key] - cache_counter.inc_hits(self._cache_name) + self.metrics.inc_hits() except KeyError: - cache_counter.inc_misses(self._cache_name) + self.metrics.inc_misses() raise if self._reset_expiry_on_get: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index f7423f2fab..9c4c679175 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -29,19 +29,32 @@ def enumerate_leaves(node, depth): yield m +class _Node(object): + __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] + + def __init__(self, prev_node, next_node, key, value, callbacks=set()): + self.prev_node = prev_node + self.next_node = next_node + self.key = key + self.value = value + self.callbacks = callbacks + + class LruCache(object): """ Least-recently-used cache. Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. + + Can also set callbacks on objects when getting/setting which are fired + when that key gets invalidated/evicted. """ def __init__(self, max_size, keylen=1, cache_type=dict): cache = cache_type() self.cache = cache # Used for introspection. - list_root = [] - list_root[:] = [list_root, list_root, None, None] - - PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 + list_root = _Node(None, None, None, None) + list_root.next_node = list_root + list_root.prev_node = list_root lock = threading.Lock() @@ -53,65 +66,83 @@ class LruCache(object): return inner - def add_node(key, value): + def add_node(key, value, callbacks=set()): prev_node = list_root - next_node = prev_node[NEXT] - node = [prev_node, next_node, key, value] - prev_node[NEXT] = node - next_node[PREV] = node + next_node = prev_node.next_node + node = _Node(prev_node, next_node, key, value, callbacks) + prev_node.next_node = node + next_node.prev_node = node cache[key] = node def move_node_to_front(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node + prev_node = node.prev_node + next_node = node.next_node + prev_node.next_node = next_node + next_node.prev_node = prev_node prev_node = list_root - next_node = prev_node[NEXT] - node[PREV] = prev_node - node[NEXT] = next_node - prev_node[NEXT] = node - next_node[PREV] = node + next_node = prev_node.next_node + node.prev_node = prev_node + node.next_node = next_node + prev_node.next_node = node + next_node.prev_node = node def delete_node(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node + prev_node = node.prev_node + next_node = node.next_node + prev_node.next_node = next_node + next_node.prev_node = prev_node + + for cb in node.callbacks: + cb() + node.callbacks.clear() @synchronized - def cache_get(key, default=None): + def cache_get(key, default=None, callback=None): node = cache.get(key, None) if node is not None: move_node_to_front(node) - return node[VALUE] + if callback: + node.callbacks.add(callback) + return node.value else: return default @synchronized - def cache_set(key, value): + def cache_set(key, value, callback=None): node = cache.get(key, None) if node is not None: + if value != node.value: + for cb in node.callbacks: + cb() + node.callbacks.clear() + + if callback: + node.callbacks.add(callback) + move_node_to_front(node) - node[VALUE] = value + node.value = value else: - add_node(key, value) + if callback: + callbacks = set([callback]) + else: + callbacks = set() + add_node(key, value, callbacks) if len(cache) > max_size: - todelete = list_root[PREV] + todelete = list_root.prev_node delete_node(todelete) - cache.pop(todelete[KEY], None) + cache.pop(todelete.key, None) @synchronized def cache_set_default(key, value): node = cache.get(key, None) if node is not None: - return node[VALUE] + return node.value else: add_node(key, value) if len(cache) > max_size: - todelete = list_root[PREV] + todelete = list_root.prev_node delete_node(todelete) - cache.pop(todelete[KEY], None) + cache.pop(todelete.key, None) return value @synchronized @@ -119,8 +150,8 @@ class LruCache(object): node = cache.get(key, None) if node: delete_node(node) - cache.pop(node[KEY], None) - return node[VALUE] + cache.pop(node.key, None) + return node.value else: return default @@ -137,8 +168,11 @@ class LruCache(object): @synchronized def cache_clear(): - list_root[NEXT] = list_root - list_root[PREV] = list_root + list_root.next_node = list_root + list_root.prev_node = list_root + for node in cache.values(): + for cb in node.callbacks: + cb() cache.clear() @synchronized diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py new file mode 100644 index 0000000000..00af539880 --- /dev/null +++ b/synapse/util/caches/response_cache.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.async import ObservableDeferred + + +class ResponseCache(object): + """ + This caches a deferred response. Until the deferred completes it will be + returned from the cache. This means that if the client retries the request + while the response is still being computed, that original response will be + used rather than trying to compute a new response. + """ + + def __init__(self, hs, timeout_ms=0): + self.pending_result_cache = {} # Requests that haven't finished yet. + + self.clock = hs.get_clock() + self.timeout_sec = timeout_ms / 1000. + + def get(self, key): + result = self.pending_result_cache.get(key) + if result is not None: + return result.observe() + else: + return None + + def set(self, key, deferred): + result = ObservableDeferred(deferred, consumeErrors=True) + self.pending_result_cache[key] = result + + def remove(r): + if self.timeout_sec: + self.clock.call_later( + self.timeout_sec, + self.pending_result_cache.pop, key, None, + ) + else: + self.pending_result_cache.pop(key, None) + return r + + result.addBoth(remove) + return result.observe() diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index ea8a74ca69..b72bb0ff02 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache from blist import sorteddict @@ -42,7 +42,7 @@ class StreamChangeCache(object): self._cache = sorteddict() self._earliest_known_stream_pos = current_stream_pos self.name = name - caches_by_name[self.name] = self._cache + self.metrics = register_cache(self.name, self._cache) for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) @@ -53,19 +53,19 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos < self._earliest_known_stream_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True latest_entity_change_pos = self._entity_to_key.get(entity, None) if latest_entity_change_pos is None: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False if stream_pos < latest_entity_change_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False def get_entities_changed(self, entities, stream_pos): @@ -82,10 +82,10 @@ class StreamChangeCache(object): self._cache[k] for k in keys[i:] ).intersection(entities) - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() else: result = entities - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return result @@ -121,3 +121,9 @@ class StreamChangeCache(object): k, r = self._cache.popitem() self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) self._entity_to_key.pop(r, None) + + def get_max_pos_of_last_change(self, entity): + """Returns an upper bound of the stream id of the last change to an + entity. + """ + return self._entity_to_key.get(entity, self._earliest_known_stream_pos) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 03bc1401b7..c31585aea3 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -64,6 +64,9 @@ class TreeCache(object): self.size -= cnt return popped + def values(self): + return [e.value for e in self.root.values()] + def __len__(self): return self.size diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 8875813de4..e68f94ce77 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -15,7 +15,9 @@ from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_fn +) from synapse.util import unwrapFirstError @@ -25,6 +27,20 @@ import logging logger = logging.getLogger(__name__) +def user_left_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) + + +def user_joined_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) + + class Distributor(object): """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py new file mode 100644 index 0000000000..45be47159a --- /dev/null +++ b/synapse/util/httpresourcetree.py @@ -0,0 +1,98 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource + +import logging + +logger = logging.getLogger(__name__) + + +def create_resource_tree(desired_tree, root_resource): + """Create the resource tree for this Home Server. + + This in unduly complicated because Twisted does not support putting + child resources more than 1 level deep at a time. + + Args: + web_client (bool): True to enable the web client. + root_resource (twisted.web.resource.Resource): The root + resource to add the tree to. + Returns: + twisted.web.resource.Resource: the ``root_resource`` with a tree of + child resources added to it. + """ + + # ideally we'd just use getChild and putChild but getChild doesn't work + # unless you give it a Request object IN ADDITION to the name :/ So + # instead, we'll store a copy of this mapping so we can actually add + # extra resources to existing nodes. See self._resource_id for the key. + resource_mappings = {} + for full_path, res in desired_tree.items(): + logger.info("Attaching %s to path %s", res, full_path) + last_resource = root_resource + for path_seg in full_path.split('/')[1:-1]: + if path_seg not in last_resource.listNames(): + # resource doesn't exist, so make a "dummy resource" + child_resource = Resource() + last_resource.putChild(path_seg, child_resource) + res_id = _resource_id(last_resource, path_seg) + resource_mappings[res_id] = child_resource + last_resource = child_resource + else: + # we have an existing Resource, use that instead. + res_id = _resource_id(last_resource, path_seg) + last_resource = resource_mappings[res_id] + + # =========================== + # now attach the actual desired resource + last_path_seg = full_path.split('/')[-1] + + # if there is already a resource here, thieve its children and + # replace it + res_id = _resource_id(last_resource, last_path_seg) + if res_id in resource_mappings: + # there is a dummy resource at this path already, which needs + # to be replaced with the desired resource. + existing_dummy_resource = resource_mappings[res_id] + for child_name in existing_dummy_resource.listNames(): + child_res_id = _resource_id( + existing_dummy_resource, child_name + ) + child_resource = resource_mappings[child_res_id] + # steal the children + res.putChild(child_name, child_resource) + + # finally, insert the desired resource in the right place + last_resource.putChild(last_path_seg, res) + res_id = _resource_id(last_resource, last_path_seg) + resource_mappings[res_id] = res + + return root_resource + + +def _resource_id(resource, path_seg): + """Construct an arbitrary resource ID so you can retrieve the mapping + later. + + If you want to represent resource A putChild resource B with path C, + the mapping should looks like _resource_id(A,C) = B. + + Args: + resource (Resource): The *parent* Resourceb + path_seg (str): The name of the child Resource to be attached. + Returns: + str: A unique string which can be a key to the child Resource. + """ + return "%s-%s" % (resource, path_seg) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 3fd5c3d9fd..d668e5a6b8 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -76,15 +76,26 @@ class JsonEncodedObject(object): d.update(self.unrecognized_keys) return d + def get_internal_dict(self): + d = { + k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + if k in self.valid_keys + } + d.update(self.unrecognized_keys) + return d + def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) -def _encode(obj): +def _encode(obj, internal=False): if type(obj) is list: - return [_encode(o) for o in obj] + return [_encode(o, internal=internal) for o in obj] if isinstance(obj, JsonEncodedObject): - return obj.get_dict() + if internal: + return obj.get_internal_dict() + else: + return obj.get_dict() return obj diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 5316259d15..6c83eb213d 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs): return res -def preserve_context_over_deferred(deferred): +def preserve_context_over_deferred(deferred, context=None): """Given a deferred wrap it such that any callbacks added later to it will be invoked with the current context. """ - current_context = LoggingContext.current_context() - d = _PreservingContextDeferred(current_context) + if context is None: + context = LoggingContext.current_context() + d = _PreservingContextDeferred(context) deferred.chainDeferred(d) return d @@ -316,8 +317,13 @@ def preserve_fn(f): def g(*args, **kwargs): with PreserveLoggingContext(current): - return f(*args, **kwargs) - + res = f(*args, **kwargs) + if isinstance(res, defer.Deferred): + return preserve_context_over_deferred( + res, context=LoggingContext.sentinel + ) + else: + return res return g diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py new file mode 100644 index 0000000000..97e0f00b67 --- /dev/null +++ b/synapse/util/manhole.py @@ -0,0 +1,70 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.conch.manhole import ColoredManhole +from twisted.conch.insults import insults +from twisted.conch import manhole_ssh +from twisted.cred import checkers, portal +from twisted.conch.ssh.keys import Key + +PUBLIC_KEY = ( + "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az" + "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS" + "kbh/C+BR3utDS555mV" +) + +PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- +MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW +4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw +vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb +Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1 +xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8 +PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2 +gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu +DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML +pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP +EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg== +-----END RSA PRIVATE KEY-----""" + + +def manhole(username, password, globals): + """Starts a ssh listener with password authentication using + the given username and password. Clients connecting to the ssh + listener will find themselves in a colored python shell with + the supplied globals. + + Args: + username(str): The username ssh clients should auth with. + password(str): The password ssh clients should auth with. + globals(dict): The variables to expose in the shell. + + Returns: + twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` + """ + + checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( + **{username: password} + ) + + rlm = manhole_ssh.TerminalRealm() + rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( + ColoredManhole, + dict(globals, __name__="__console__") + ) + + factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) + factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY) + factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY) + + return factory diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index c51b641125..4ea930d3e8 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.util.logcontext import LoggingContext import synapse.metrics +from functools import wraps import logging @@ -47,10 +49,22 @@ block_db_txn_duration = metrics.register_distribution( ) +def measure_func(name): + def wrapper(func): + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, name): + r = yield func(self, *args, **kwargs) + defer.returnValue(r) + return measured_func + return wrapper + + class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration" + "ru_stime", "db_txn_count", "db_txn_duration", "created_context" ] def __init__(self, clock, name): @@ -58,17 +72,22 @@ class Measure(object): self.name = name self.start_context = None self.start = None + self.created_context = False def __enter__(self): self.start = self.clock.time_msec() self.start_context = LoggingContext.current_context() - if self.start_context: - self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() - self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + if not self.start_context: + self.start_context = LoggingContext("Measure") + self.start_context.__enter__() + self.created_context = True + + self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() + self.db_txn_count = self.start_context.db_txn_count + self.db_txn_duration = self.start_context.db_txn_duration def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is not None or not self.start_context: + if isinstance(exc_type, Exception) or not self.start_context: return duration = self.clock.time_msec() - self.start @@ -78,8 +97,8 @@ class Measure(object): if context != self.start_context: logger.warn( - "Context have unexpectedly changed from '%s' to '%s'. (%r)", - context, self.start_context, self.name + "Context has unexpectedly changed from '%s' to '%s'. (%r)", + self.start_context, context, self.name ) return @@ -91,7 +110,12 @@ class Measure(object): block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name) block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name) - block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name) + block_db_txn_count.inc_by( + context.db_txn_count - self.db_txn_count, self.name + ) block_db_txn_duration.inc_by( context.db_txn_duration - self.db_txn_duration, self.name ) + + if self.created_context: + self.start_context.__exit__(exc_type, exc_val, exc_tb) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 4076eed269..1101881a2d 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -100,20 +100,6 @@ class _PerHostRatelimiter(object): self.current_processing = set() self.request_times = [] - def is_empty(self): - time_now = self.clock.time_msec() - self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size - ] - - return not ( - self.ready_request_queue - or self.sleeping_requests - or self.current_processing - or self.request_times - ) - @contextlib.contextmanager def ratelimit(self): # `contextlib.contextmanager` takes a generator and turns it into a diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 43cf11f3f6..e2de7fce91 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -121,15 +121,9 @@ class RetryDestinationLimiter(object): pass def __exit__(self, exc_type, exc_val, exc_tb): - def err(failure): - logger.exception( - "Failed to store set_destination_retry_timings", - failure.value - ) - valid_err_code = False - if exc_type is CodeMessageException: - valid_err_code = 0 <= exc_val.code < 500 + if exc_type is not None and issubclass(exc_type, CodeMessageException): + valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 if exc_type is None or valid_err_code: # We connected successfully. @@ -151,6 +145,15 @@ class RetryDestinationLimiter(object): retry_last_ts = int(self.clock.time_msec()) - self.store.set_destination_retry_timings( - self.destination, retry_last_ts, self.retry_interval - ).addErrback(err) + @defer.inlineCallbacks + def store_retry_timings(): + try: + yield self.store.set_destination_retry_timings( + self.destination, retry_last_ts, self.retry_interval + ) + except: + logger.exception( + "Failed to store set_destination_retry_timings", + ) + + store_retry_timings() diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py new file mode 100644 index 0000000000..f4a9abf83f --- /dev/null +++ b/synapse/util/rlimit.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import resource +import logging + + +logger = logging.getLogger("synapse.app.homeserver") + + +def change_resource_limit(soft_file_no): + try: + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + + if not soft_file_no: + soft_file_no = hard + + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard)) + logger.info("Set file limit to: %d", soft_file_no) + + resource.setrlimit( + resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) + ) + except (ValueError, resource.error) as e: + logger.warn("Failed to set file or core limit: %s", e) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b490bb8725..a100f151d4 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -21,10 +21,6 @@ _string_with_symbols = ( ) -def origin_from_ucid(ucid): - return ucid.split("@", 1)[1] - - def random_string(length): return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py new file mode 100644 index 0000000000..52086df465 --- /dev/null +++ b/synapse/util/versionstring.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import os +import logging + +logger = logging.getLogger(__name__) + + +def get_version_string(module): + try: + null = open(os.devnull, 'w') + cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: + git_branch = subprocess.check_output( + ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + git_branch = "b=" + git_branch + except subprocess.CalledProcessError: + git_branch = "" + + try: + git_tag = subprocess.check_output( + ['git', 'describe', '--exact-match'], + stderr=null, + cwd=cwd, + ).strip() + git_tag = "t=" + git_tag + except subprocess.CalledProcessError: + git_tag = "" + + try: + git_commit = subprocess.check_output( + ['git', 'rev-parse', '--short', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + except subprocess.CalledProcessError: + git_commit = "" + + try: + dirty_string = "-this_is_a_dirty_checkout" + is_dirty = subprocess.check_output( + ['git', 'describe', '--dirty=' + dirty_string], + stderr=null, + cwd=cwd, + ).strip().endswith(dirty_string) + + git_dirty = "dirty" if is_dirty else "" + except subprocess.CalledProcessError: + git_dirty = "" + + if git_branch or git_tag or git_commit or git_dirty: + git_version = ",".join( + s for s in + (git_branch, git_tag, git_commit, git_dirty,) + if s + ) + + return ( + "%s (%s)" % ( + module.__version__, git_version, + ) + ).encode("ascii") + except Exception as e: + logger.info("Failed to check for git repository: %s", e) + + return module.__version__.encode("ascii") diff --git a/synapse/visibility.py b/synapse/visibility.py new file mode 100644 index 0000000000..199b16d827 --- /dev/null +++ b/synapse/visibility.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import Membership, EventTypes + +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred + +import logging + + +logger = logging.getLogger(__name__) + + +VISIBILITY_PRIORITY = ( + "world_readable", + "shared", + "invited", + "joined", +) + + +MEMBERSHIP_PRIORITY = ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, +) + + +@defer.inlineCallbacks +def filter_events_for_clients(store, user_tuples, events, event_id_to_state): + """ Returns dict of user_id -> list of events that user is allowed to + see. + + Args: + user_tuples (str, bool): (user id, is_peeking) for each user to be + checked. is_peeking should be true if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the + given events + events ([synapse.events.EventBase]): list of events to filter + """ + forgotten = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(store.who_forgot_in_room)( + room_id, + ) + for room_id in frozenset(e.room_id for e in events) + ], consumeErrors=True)) + + # Set of membership event_ids that have been forgotten + event_id_forgotten = frozenset( + row["event_id"] for rows in forgotten for row in rows + ) + + ignore_dict_content = yield store.get_global_account_data_by_type_for_users( + "m.ignored_user_list", user_ids=[user_id for user_id, _ in user_tuples] + ) + + # FIXME: This will explode if people upload something incorrect. + ignore_dict = { + user_id: frozenset( + content.get("ignored_users", {}).keys() if content else [] + ) + for user_id, content in ignore_dict_content.items() + } + + def allowed(event, user_id, is_peeking, ignore_list): + """ + Args: + event (synapse.events.EventBase): event to check + user_id (str) + is_peeking (bool) + ignore_list (list): list of users to ignore + """ + if not event.is_state() and event.sender in ignore_list: + return False + + state = event_id_to_state[event.event_id] + + # get the room_visibility at the time of the event. + visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + if visibility_event: + visibility = visibility_event.content.get("history_visibility", "shared") + else: + visibility = "shared" + + if visibility not in VISIBILITY_PRIORITY: + visibility = "shared" + + # if it was world_readable, it's easy: everyone can read it + if visibility == "world_readable": + return True + + # Always allow history visibility events on boundaries. This is done + # by setting the effective visibility to the least restrictive + # of the old vs new. + if event.type == EventTypes.RoomHistoryVisibility: + prev_content = event.unsigned.get("prev_content", {}) + prev_visibility = prev_content.get("history_visibility", None) + + if prev_visibility not in VISIBILITY_PRIORITY: + prev_visibility = "shared" + + new_priority = VISIBILITY_PRIORITY.index(visibility) + old_priority = VISIBILITY_PRIORITY.index(prev_visibility) + if old_priority < new_priority: + visibility = prev_visibility + + # likewise, if the event is the user's own membership event, use + # the 'most joined' membership + membership = None + if event.type == EventTypes.Member and event.state_key == user_id: + membership = event.content.get("membership", None) + if membership not in MEMBERSHIP_PRIORITY: + membership = "leave" + + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership not in MEMBERSHIP_PRIORITY: + prev_membership = "leave" + + new_priority = MEMBERSHIP_PRIORITY.index(membership) + old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) + if old_priority < new_priority: + membership = prev_membership + + # otherwise, get the user's membership at the time of the event. + if membership is None: + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + if membership_event.event_id not in event_id_forgotten: + membership = membership_event.membership + + # if the user was a member of the room at the time of the event, + # they can see it. + if membership == Membership.JOIN: + return True + + if visibility == "joined": + # we weren't a member at the time of the event, so we can't + # see this event. + return False + + elif visibility == "invited": + # user can also see the event if they were *invited* at the time + # of the event. + return membership == Membership.INVITE + + else: + # visibility is shared: user can also see the event if they have + # become a member since the event + # + # XXX: if the user has subsequently joined and then left again, + # ideally we would share history up to the point they left. But + # we don't know when they left. + return not is_peeking + + defer.returnValue({ + user_id: [ + event + for event in events + if allowed(event, user_id, is_peeking, ignore_dict.get(user_id, [])) + ] + for user_id, is_peeking in user_tuples + }) + + +@defer.inlineCallbacks +def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context): + user_ids = set(u[0] for u in user_tuples) + event_id_to_state = {} + for event_id, context in event_id_to_context.items(): + state = yield store.get_events([ + e_id + for key, e_id in context.current_state_ids.iteritems() + if key == (EventTypes.RoomHistoryVisibility, "") + or (key[0] == EventTypes.Member and key[1] in user_ids) + ]) + event_id_to_state[event_id] = state + + res = yield filter_events_for_clients( + store, user_tuples, events, event_id_to_state + ) + defer.returnValue(res) + + +@defer.inlineCallbacks +def filter_events_for_client(store, user_id, events, is_peeking=False): + """ + Check which events a user is allowed to see + + Args: + user_id(str): user id to be checked + events([synapse.events.EventBase]): list of events to be checked + is_peeking(bool): should be True if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the given + events + + Returns: + [synapse.events.EventBase] + """ + types = ( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + event_id_to_state = yield store.get_state_for_events( + frozenset(e.event_id for e in events), + types=types + ) + res = yield filter_events_for_clients( + store, [(user_id, is_peeking)], events, event_id_to_state + ) + defer.returnValue(res.get(user_id, [])) |