From 58c9f206929560044fccae84c36fdd89724ccfc0 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 12 Feb 2016 13:46:59 +0000 Subject: Catch the exceptions thrown by twisted when you write to a closed connection --- synapse/rest/client/v1/login.py | 10 ++++++---- synapse/rest/client/v2_alpha/auth.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 7199113dac..79101106ac 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -17,6 +17,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID +from synapse.http.server import finish_request + from base import ClientV1RestServlet, client_path_patterns import simplejson as json @@ -263,7 +265,7 @@ class SAML2RestServlet(ClientV1RestServlet): '?status=authenticated&access_token=' + token + '&user_id=' + user_id + '&ava=' + urllib.quote(json.dumps(saml2_auth.ava))) - request.finish() + finish_request(request) defer.returnValue(None) defer.returnValue((200, {"status": "authenticated", "user_id": user_id, "token": token, @@ -272,7 +274,7 @@ class SAML2RestServlet(ClientV1RestServlet): request.redirect(urllib.unquote( request.args['RelayState'][0]) + '?status=not_authenticated') - request.finish() + finish_request(request) defer.returnValue(None) defer.returnValue((200, {"status": "not_authenticated"})) @@ -309,7 +311,7 @@ class CasRedirectServlet(ClientV1RestServlet): "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) }) request.redirect("%s?%s" % (self.cas_server_url, service_param)) - request.finish() + finish_request(request) class CasTicketServlet(ClientV1RestServlet): @@ -362,7 +364,7 @@ class CasTicketServlet(ClientV1RestServlet): redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) - request.finish() + finish_request(request) def add_login_token_to_redirect_url(self, url, token): url_parts = list(urlparse.urlparse(url)) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index ff71c40b43..78181b7b18 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.http.server import finish_request from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -130,7 +131,7 @@ class AuthRestServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) - request.finish() + finish_request(request) defer.returnValue(None) else: raise SynapseError(404, "Unknown auth stage type") @@ -176,7 +177,7 @@ class AuthRestServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) - request.finish() + finish_request(request) defer.returnValue(None) else: -- cgit 1.4.1 From cf81375b94c4763766440471e632fc4b103450ab Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Fri, 12 Feb 2016 15:11:49 +0000 Subject: Merge two of the room join codepaths There's at least one more to merge in. Side-effects: * Stop reporting None as displayname and avatar_url in some cases * Joining a room by alias populates guest-ness in join event * Remove unspec'd PUT version of /join/ which has not been called on matrix.org according to logs * Stop recording access_token_id on /join/room_id - currently we don't record it on /join/room_alias; I can try to thread it through at some point. --- synapse/api/errors.py | 5 ++++ synapse/handlers/profile.py | 11 +++++-- synapse/handlers/room.py | 44 +++++++++++++++++++++------ synapse/rest/client/v1/room.py | 68 ++++++++---------------------------------- synapse/types.py | 14 +++++++-- 5 files changed, 73 insertions(+), 69 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b106fbed6d..0c7858f78d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -84,6 +84,11 @@ class RegistrationError(SynapseError): pass +class BadIdentifierError(SynapseError): + """An error indicating an identifier couldn't be parsed.""" + pass + + class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 629e6e3594..32af622733 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -169,8 +169,15 @@ class ProfileHandler(BaseHandler): consumeErrors=True ).addErrback(unwrapFirstError) - state["displayname"] = displayname - state["avatar_url"] = avatar_url + if displayname is None: + del state["displayname"] + else: + state["displayname"] = displayname + + if avatar_url is None: + del state["avatar_url"] + else: + state["avatar_url"] = avatar_url defer.returnValue(None) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b2de2cd0c0..2950ed14e4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -527,7 +527,17 @@ class RoomMemberHandler(BaseHandler): defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def join_room_alias(self, joinee, room_alias, content={}): + def lookup_room_alias(self, room_alias): + """ + Gets the room ID for an alias. + + Args: + room_alias (str): The room alias to look up. + Returns: + A tuple of the room ID (str) and the hosts hosting the room ([str]) + Raises: + SynapseError if the room couldn't be looked up. + """ directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) @@ -539,24 +549,40 @@ class RoomMemberHandler(BaseHandler): if not hosts: raise SynapseError(404, "No known servers") - # If event doesn't include a display name, add one. - yield collect_presencelike_data(self.distributor, joinee, content) + defer.returnValue((room_id, hosts)) + + @defer.inlineCallbacks + def do_join(self, requester, room_id, hosts=None): + """ + Joins requester to room_id. + + Args: + requester (Requester): The user joining the room. + room_id (str): The room ID (not alias) being joined. + hosts ([str]): A list of hosts which are hopefully in the room. + Raises: + SynapseError if the room couldn't be joined. + """ + hosts = hosts or [] + + content = {"membership": Membership.JOIN} + if requester.is_guest: + content["kind"] = "guest" + + yield collect_presencelike_data(self.distributor, requester.user, content) - content.update({"membership": Membership.JOIN}) builder = self.event_builder_factory.new({ "type": EventTypes.Member, - "state_key": joinee.to_string(), + "state_key": requester.user.to_string(), "room_id": room_id, - "sender": joinee.to_string(), - "membership": Membership.JOIN, + "sender": requester.user.to_string(), + "membership": Membership.JOIN, # For backwards compatibility "content": content, }) event, context = yield self._create_new_client_event(builder) yield self._do_join(event, context, room_hosts=hosts) - defer.returnValue({"room_id": room_id}) - @defer.inlineCallbacks def _do_join(self, event, context, room_hosts=None): room_id = event.room_id diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 81bfe377bd..1dd33b0a56 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -216,11 +216,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): - - def register(self, http_server): - # /join/$room_identifier[/$txn_id] - PATTERNS = ("/join/(?P[^/]*)") - register_txn_path(self, PATTERNS, http_server) + PATTERNS = client_path_patterns("/join/(?P[^/]*)$") @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): @@ -229,60 +225,22 @@ class JoinRoomAliasServlet(ClientV1RestServlet): allow_guest=True, ) - # the identifier could be a room alias or a room id. Try one then the - # other if it fails to parse, without swallowing other valid - # SynapseErrors. + handler = self.handlers.room_member_handler - identifier = None - is_room_alias = False - try: - identifier = RoomAlias.from_string(room_identifier) - is_room_alias = True - except SynapseError: - identifier = RoomID.from_string(room_identifier) + room_id = None + hosts = [] + if RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, hosts = yield handler.lookup_room_alias(room_alias) + else: + room_id = RoomID.from_string(room_identifier).to_string() # TODO: Support for specifying the home server to join with? - if is_room_alias: - handler = self.handlers.room_member_handler - ret_dict = yield handler.join_room_alias( - requester.user, - identifier, - ) - defer.returnValue((200, ret_dict)) - else: # room id - msg_handler = self.handlers.message_handler - content = {"membership": Membership.JOIN} - if requester.is_guest: - content["kind"] = "guest" - yield msg_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": identifier.to_string(), - "sender": requester.user.to_string(), - "state_key": requester.user.to_string(), - }, - token_id=requester.access_token_id, - txn_id=txn_id, - is_guest=requester.is_guest, - ) - - defer.returnValue((200, {"room_id": identifier.to_string()})) - - @defer.inlineCallbacks - def on_PUT(self, request, room_identifier, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request, room_identifier, txn_id) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + yield handler.do_join( + requester, room_id, hosts=hosts + ) + defer.returnValue((200, {"room_id": room_id})) # TODO: Needs unit testing diff --git a/synapse/types.py b/synapse/types.py index 2095837ba6..0be8384e18 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, BadIdentifierError from collections import namedtuple @@ -51,13 +51,13 @@ class DomainSpecificString( def from_string(cls, s): """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0] != cls.SIGIL: - raise SynapseError(400, "Expected %s string to start with '%s'" % ( + raise BadIdentifierError(400, "Expected %s string to start with '%s'" % ( cls.__name__, cls.SIGIL, )) parts = s[1:].split(':', 1) if len(parts) != 2: - raise SynapseError( + raise BadIdentifierError( 400, "Expected %s of the form '%slocalname:domain'" % ( cls.__name__, cls.SIGIL, ) @@ -69,6 +69,14 @@ class DomainSpecificString( # names on one HS return cls(localpart=parts[0], domain=domain) + @classmethod + def is_valid(cls, s): + try: + cls.from_string(s) + return True + except: + return False + def to_string(self): """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) -- cgit 1.4.1 From 4de08a4672c62eebda2ad3ee89643c2c32242cbf Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Fri, 12 Feb 2016 16:17:24 +0000 Subject: Revert "Merge two of the room join codepaths" This reverts commit cf81375b94c4763766440471e632fc4b103450ab. It subtly violates a guest joining auth check --- synapse/api/errors.py | 5 ---- synapse/handlers/profile.py | 11 ++----- synapse/handlers/room.py | 44 ++++++--------------------- synapse/rest/client/v1/room.py | 68 ++++++++++++++++++++++++++++++++++-------- synapse/types.py | 14 ++------- 5 files changed, 69 insertions(+), 73 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0c7858f78d..b106fbed6d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -84,11 +84,6 @@ class RegistrationError(SynapseError): pass -class BadIdentifierError(SynapseError): - """An error indicating an identifier couldn't be parsed.""" - pass - - class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 32af622733..629e6e3594 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -169,15 +169,8 @@ class ProfileHandler(BaseHandler): consumeErrors=True ).addErrback(unwrapFirstError) - if displayname is None: - del state["displayname"] - else: - state["displayname"] = displayname - - if avatar_url is None: - del state["avatar_url"] - else: - state["avatar_url"] = avatar_url + state["displayname"] = displayname + state["avatar_url"] = avatar_url defer.returnValue(None) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2950ed14e4..b2de2cd0c0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -527,17 +527,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def lookup_room_alias(self, room_alias): - """ - Gets the room ID for an alias. - - Args: - room_alias (str): The room alias to look up. - Returns: - A tuple of the room ID (str) and the hosts hosting the room ([str]) - Raises: - SynapseError if the room couldn't be looked up. - """ + def join_room_alias(self, joinee, room_alias, content={}): directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) @@ -549,40 +539,24 @@ class RoomMemberHandler(BaseHandler): if not hosts: raise SynapseError(404, "No known servers") - defer.returnValue((room_id, hosts)) - - @defer.inlineCallbacks - def do_join(self, requester, room_id, hosts=None): - """ - Joins requester to room_id. - - Args: - requester (Requester): The user joining the room. - room_id (str): The room ID (not alias) being joined. - hosts ([str]): A list of hosts which are hopefully in the room. - Raises: - SynapseError if the room couldn't be joined. - """ - hosts = hosts or [] - - content = {"membership": Membership.JOIN} - if requester.is_guest: - content["kind"] = "guest" - - yield collect_presencelike_data(self.distributor, requester.user, content) + # If event doesn't include a display name, add one. + yield collect_presencelike_data(self.distributor, joinee, content) + content.update({"membership": Membership.JOIN}) builder = self.event_builder_factory.new({ "type": EventTypes.Member, - "state_key": requester.user.to_string(), + "state_key": joinee.to_string(), "room_id": room_id, - "sender": requester.user.to_string(), - "membership": Membership.JOIN, # For backwards compatibility + "sender": joinee.to_string(), + "membership": Membership.JOIN, "content": content, }) event, context = yield self._create_new_client_event(builder) yield self._do_join(event, context, room_hosts=hosts) + defer.returnValue({"room_id": room_id}) + @defer.inlineCallbacks def _do_join(self, event, context, room_hosts=None): room_id = event.room_id diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 1dd33b0a56..81bfe377bd 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -216,7 +216,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/join/(?P[^/]*)$") + + def register(self, http_server): + # /join/$room_identifier[/$txn_id] + PATTERNS = ("/join/(?P[^/]*)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): @@ -225,22 +229,60 @@ class JoinRoomAliasServlet(ClientV1RestServlet): allow_guest=True, ) - handler = self.handlers.room_member_handler + # the identifier could be a room alias or a room id. Try one then the + # other if it fails to parse, without swallowing other valid + # SynapseErrors. - room_id = None - hosts = [] - if RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - room_id, hosts = yield handler.lookup_room_alias(room_alias) - else: - room_id = RoomID.from_string(room_identifier).to_string() + identifier = None + is_room_alias = False + try: + identifier = RoomAlias.from_string(room_identifier) + is_room_alias = True + except SynapseError: + identifier = RoomID.from_string(room_identifier) # TODO: Support for specifying the home server to join with? - yield handler.do_join( - requester, room_id, hosts=hosts - ) - defer.returnValue((200, {"room_id": room_id})) + if is_room_alias: + handler = self.handlers.room_member_handler + ret_dict = yield handler.join_room_alias( + requester.user, + identifier, + ) + defer.returnValue((200, ret_dict)) + else: # room id + msg_handler = self.handlers.message_handler + content = {"membership": Membership.JOIN} + if requester.is_guest: + content["kind"] = "guest" + yield msg_handler.create_and_send_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": identifier.to_string(), + "sender": requester.user.to_string(), + "state_key": requester.user.to_string(), + }, + token_id=requester.access_token_id, + txn_id=txn_id, + is_guest=requester.is_guest, + ) + + defer.returnValue((200, {"room_id": identifier.to_string()})) + + @defer.inlineCallbacks + def on_PUT(self, request, room_identifier, txn_id): + try: + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) + except KeyError: + pass + + response = yield self.on_POST(request, room_identifier, txn_id) + + self.txns.store_client_transaction(request, txn_id, response) + defer.returnValue(response) # TODO: Needs unit testing diff --git a/synapse/types.py b/synapse/types.py index 0be8384e18..2095837ba6 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError, BadIdentifierError +from synapse.api.errors import SynapseError from collections import namedtuple @@ -51,13 +51,13 @@ class DomainSpecificString( def from_string(cls, s): """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0] != cls.SIGIL: - raise BadIdentifierError(400, "Expected %s string to start with '%s'" % ( + raise SynapseError(400, "Expected %s string to start with '%s'" % ( cls.__name__, cls.SIGIL, )) parts = s[1:].split(':', 1) if len(parts) != 2: - raise BadIdentifierError( + raise SynapseError( 400, "Expected %s of the form '%slocalname:domain'" % ( cls.__name__, cls.SIGIL, ) @@ -69,14 +69,6 @@ class DomainSpecificString( # names on one HS return cls(localpart=parts[0], domain=domain) - @classmethod - def is_valid(cls, s): - try: - cls.from_string(s) - return True - except: - return False - def to_string(self): """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) -- cgit 1.4.1 From dbeed36dec021df3036e088910c72d5727910dd3 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 15 Feb 2016 14:38:27 +0000 Subject: Merge some room joining codepaths Force joining by alias to go through the send_membership_event checks, rather than bypassing them straight into _do_join. This is the first of many stages of cleanup. --- synapse/handlers/room.py | 14 ++++++++++---- synapse/rest/client/v1/room.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b2de2cd0c0..89695cc0cf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -455,7 +455,7 @@ class RoomMemberHandler(BaseHandler): yield self.forget(requester.user, room_id) @defer.inlineCallbacks - def send_membership_event(self, event, context, is_guest=False): + def send_membership_event(self, event, context, is_guest=False, room_hosts=None): """ Change the membership status of a user in a room. Args: @@ -490,7 +490,7 @@ class RoomMemberHandler(BaseHandler): if not is_guest_access_allowed: raise AuthError(403, "Guest access not allowed") - yield self._do_join(event, context) + yield self._do_join(event, context, room_hosts=room_hosts) else: if event.membership == Membership.LEAVE: is_host_in_room = yield self.is_host_in_room(room_id, context) @@ -527,7 +527,8 @@ class RoomMemberHandler(BaseHandler): defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def join_room_alias(self, joinee, room_alias, content={}): + def join_room_alias(self, requester, room_alias, content={}): + joinee = requester.user directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) @@ -553,7 +554,12 @@ class RoomMemberHandler(BaseHandler): }) event, context = yield self._create_new_client_event(builder) - yield self._do_join(event, context, room_hosts=hosts) + yield self.send_membership_event( + event, + context, + is_guest=requester.is_guest, + room_hosts=hosts + ) defer.returnValue({"room_id": room_id}) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 81bfe377bd..76025213dc 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -246,7 +246,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if is_room_alias: handler = self.handlers.room_member_handler ret_dict = yield handler.join_room_alias( - requester.user, + requester, identifier, ) defer.returnValue((200, ret_dict)) -- cgit 1.4.1 From e71095801fc376aac30ff9408ae7f0203684024d Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 15 Feb 2016 15:39:16 +0000 Subject: Merge implementation of /join by alias or ID This code is kind of rough (passing the remote servers down a long chain), but is a step towards improvement. --- synapse/handlers/_base.py | 5 +++- synapse/handlers/message.py | 20 ++++++++----- synapse/handlers/room.py | 40 ++++++++++--------------- synapse/rest/client/v1/room.py | 68 +++++++++++++++++++----------------------- synapse/types.py | 8 +++++ 5 files changed, 71 insertions(+), 70 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 064e8723c8..8508ecdd49 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -188,9 +188,12 @@ class BaseHandler(object): ) @defer.inlineCallbacks - def handle_new_client_event(self, event, context, extra_users=[]): + def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): # We now need to go and hit out to wherever we need to hit out to. + if ratelimit: + self.ratelimit(event.sender) + self.auth.check(event, auth_events=context.current_state) yield self.maybe_kick_guest_users(event, context.current_state.values()) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 82c8cb5f0c..a94fad1735 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -216,7 +216,7 @@ class MessageHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def send_event(self, event, context, ratelimit=True, is_guest=False): + def send_event(self, event, context, ratelimit=True, is_guest=False, room_hosts=None): """ Persists and notifies local clients and federation of an event. @@ -230,9 +230,6 @@ class MessageHandler(BaseHandler): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) - if ratelimit: - self.ratelimit(event.sender) - if event.is_state(): prev_state = context.current_state.get((event.type, event.state_key)) if prev_state and event.user_id == prev_state.user_id: @@ -245,11 +242,18 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Member: member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, is_guest=is_guest) + yield member_handler.send_membership_event( + event, + context, + is_guest=is_guest, + ratelimit=ratelimit, + room_hosts=room_hosts + ) else: yield self.handle_new_client_event( event=event, context=context, + ratelimit=ratelimit, ) if event.type == EventTypes.Message: @@ -259,7 +263,8 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, - token_id=None, txn_id=None, is_guest=False): + token_id=None, txn_id=None, is_guest=False, + room_hosts=None): """ Creates an event, then sends it. @@ -274,7 +279,8 @@ class MessageHandler(BaseHandler): event, context, ratelimit=ratelimit, - is_guest=is_guest + is_guest=is_guest, + room_hosts=room_hosts, ) defer.returnValue(event) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 89695cc0cf..b748e81d20 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -455,7 +455,9 @@ class RoomMemberHandler(BaseHandler): yield self.forget(requester.user, room_id) @defer.inlineCallbacks - def send_membership_event(self, event, context, is_guest=False, room_hosts=None): + def send_membership_event( + self, event, context, is_guest=False, room_hosts=None, ratelimit=True + ): """ Change the membership status of a user in a room. Args: @@ -527,8 +529,17 @@ class RoomMemberHandler(BaseHandler): defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def join_room_alias(self, requester, room_alias, content={}): - joinee = requester.user + def lookup_room_alias(self, room_alias): + """ + Get the room ID associated with a room alias. + + Args: + room_alias (RoomAlias): The alias to look up. + Returns: + The room ID as a RoomID object. + Raises: + SynapseError if room alias could not be found. + """ directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) @@ -540,28 +551,7 @@ class RoomMemberHandler(BaseHandler): if not hosts: raise SynapseError(404, "No known servers") - # If event doesn't include a display name, add one. - yield collect_presencelike_data(self.distributor, joinee, content) - - content.update({"membership": Membership.JOIN}) - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "state_key": joinee.to_string(), - "room_id": room_id, - "sender": joinee.to_string(), - "membership": Membership.JOIN, - "content": content, - }) - event, context = yield self._create_new_client_event(builder) - - yield self.send_membership_event( - event, - context, - is_guest=requester.is_guest, - room_hosts=hosts - ) - - defer.returnValue({"room_id": room_id}) + defer.returnValue((RoomID.from_string(room_id), hosts)) @defer.inlineCallbacks def _do_join(self, event, context, room_hosts=None): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 76025213dc..340c24635d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -229,46 +229,40 @@ class JoinRoomAliasServlet(ClientV1RestServlet): allow_guest=True, ) - # the identifier could be a room alias or a room id. Try one then the - # other if it fails to parse, without swallowing other valid - # SynapseErrors. - - identifier = None - is_room_alias = False - try: - identifier = RoomAlias.from_string(room_identifier) - is_room_alias = True - except SynapseError: - identifier = RoomID.from_string(room_identifier) + if RoomID.is_valid(room_identifier): + room_id = room_identifier + room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.handlers.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id, room_hosts = yield handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) - # TODO: Support for specifying the home server to join with? + msg_handler = self.handlers.message_handler + content = {"membership": Membership.JOIN} + if requester.is_guest: + content["kind"] = "guest" + yield msg_handler.create_and_send_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": requester.user.to_string(), - if is_room_alias: - handler = self.handlers.room_member_handler - ret_dict = yield handler.join_room_alias( - requester, - identifier, - ) - defer.returnValue((200, ret_dict)) - else: # room id - msg_handler = self.handlers.message_handler - content = {"membership": Membership.JOIN} - if requester.is_guest: - content["kind"] = "guest" - yield msg_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": identifier.to_string(), - "sender": requester.user.to_string(), - "state_key": requester.user.to_string(), - }, - token_id=requester.access_token_id, - txn_id=txn_id, - is_guest=requester.is_guest, - ) + "membership": Membership.JOIN, # For backwards compatibility + }, + token_id=requester.access_token_id, + txn_id=txn_id, + is_guest=requester.is_guest, + room_hosts=room_hosts, + ) - defer.returnValue((200, {"room_id": identifier.to_string()})) + defer.returnValue((200, {"room_id": room_id})) @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): diff --git a/synapse/types.py b/synapse/types.py index 2095837ba6..d5bd95cbd3 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -73,6 +73,14 @@ class DomainSpecificString( """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + @classmethod + def is_valid(cls, s): + try: + cls.from_string(s) + return True + except: + return False + __str__ = to_string @classmethod -- cgit 1.4.1 From 150fcde0dce02670c2180f9d4657783eb204daa8 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 15 Feb 2016 16:16:03 +0000 Subject: Reuse update_membership from /join --- synapse/handlers/room.py | 12 +++++++++--- synapse/rest/client/v1/room.py | 21 +++++---------------- 2 files changed, 14 insertions(+), 19 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d17e5c1b7b..04916d4e24 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -403,7 +403,9 @@ class RoomMemberHandler(BaseHandler): remotedomains.add(member.domain) @defer.inlineCallbacks - def update_membership(self, requester, target, room_id, action, txn_id=None): + def update_membership( + self, requester, target, room_id, action, txn_id=None, room_hosts=None + ): effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -412,7 +414,7 @@ class RoomMemberHandler(BaseHandler): msg_handler = self.hs.get_handlers().message_handler - content = {"membership": unicode(effective_membership_state)} + content = {"membership": effective_membership_state} if requester.is_guest: content["kind"] = "guest" @@ -423,6 +425,9 @@ class RoomMemberHandler(BaseHandler): "room_id": room_id, "sender": requester.user.to_string(), "state_key": target.to_string(), + + # For backwards compatibility: + "membership": effective_membership_state, }, token_id=requester.access_token_id, txn_id=txn_id, @@ -447,7 +452,8 @@ class RoomMemberHandler(BaseHandler): event, context, ratelimit=True, - is_guest=requester.is_guest + is_guest=requester.is_guest, + room_hosts=room_hosts, ) if action == "forget": diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 340c24635d..f8cd746a88 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -242,23 +242,12 @@ class JoinRoomAliasServlet(ClientV1RestServlet): room_identifier, )) - msg_handler = self.handlers.message_handler - content = {"membership": Membership.JOIN} - if requester.is_guest: - content["kind"] = "guest" - yield msg_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": requester.user.to_string(), - - "membership": Membership.JOIN, # For backwards compatibility - }, - token_id=requester.access_token_id, + yield self.handlers.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", txn_id=txn_id, - is_guest=requester.is_guest, room_hosts=room_hosts, ) -- cgit 1.4.1 From e560045cfd73e2dbfa6a272fc298dab820e6e943 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 15 Feb 2016 18:13:10 +0000 Subject: Simplify room creation code --- synapse/handlers/room.py | 68 +++++++++++++++++------------------------- synapse/rest/client/v1/room.py | 18 ++--------- 2 files changed, 31 insertions(+), 55 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b2de2cd0c0..8e3c86d3a7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -81,20 +81,20 @@ class RoomCreationHandler(BaseHandler): } @defer.inlineCallbacks - def create_room(self, user_id, room_id, config): + def create_room(self, requester, config): """ Creates a new room. Args: - user_id (str): The ID of the user creating the new room. - room_id (str): The proposed ID for the new room. Can be None, in - which case one will be created for you. + requester (Requester): The user who requested the room creation. config (dict) : A dict of configuration options. Returns: The new room ID. Raises: - SynapseError if the room ID was taken, couldn't be stored, or - something went horribly wrong. + SynapseError if the room ID couldn't be stored, or something went + horribly wrong. """ + user_id = requester.user.to_string() + self.ratelimit(user_id) if "room_alias_name" in config: @@ -126,40 +126,28 @@ class RoomCreationHandler(BaseHandler): is_public = config.get("visibility", None) == "public" - if room_id: - # Ensure room_id is the correct type - room_id_obj = RoomID.from_string(room_id) - if not self.hs.is_mine(room_id_obj): - raise SynapseError(400, "Room id must be local") - - yield self.store.store_room( - room_id=room_id, - room_creator_user_id=user_id, - is_public=is_public - ) - else: - # autogen room IDs and try to create it. We may clash, so just - # try a few times till one goes through, giving up eventually. - attempts = 0 - room_id = None - while attempts < 5: - try: - random_string = stringutils.random_string(18) - gen_room_id = RoomID.create( - random_string, - self.hs.hostname, - ) - yield self.store.store_room( - room_id=gen_room_id.to_string(), - room_creator_user_id=user_id, - is_public=is_public - ) - room_id = gen_room_id.to_string() - break - except StoreError: - attempts += 1 - if not room_id: - raise StoreError(500, "Couldn't generate a room ID.") + # autogen room IDs and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + room_id = None + while attempts < 5: + try: + random_string = stringutils.random_string(18) + gen_room_id = RoomID.create( + random_string, + self.hs.hostname, + ) + yield self.store.store_room( + room_id=gen_room_id.to_string(), + room_creator_user_id=user_id, + is_public=is_public + ) + room_id = gen_room_id.to_string() + break + except StoreError: + attempts += 1 + if not room_id: + raise StoreError(500, "Couldn't generate a room ID.") if room_alias: directory_handler = self.hs.get_handlers().directory_handler diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 81bfe377bd..d3c1b359a2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -63,24 +63,12 @@ class RoomCreateRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - room_config = self.get_room_config(request) - info = yield self.make_room( - room_config, - requester.user, - None, - ) - room_config.update(info) - defer.returnValue((200, info)) - - @defer.inlineCallbacks - def make_room(self, room_config, auth_user, room_id): handler = self.handlers.room_creation_handler info = yield handler.create_room( - user_id=auth_user.to_string(), - room_id=room_id, - config=room_config + requester, self.get_room_config(request) ) - defer.returnValue(info) + + defer.returnValue((200, info)) def get_room_config(self, request): try: -- cgit 1.4.1 From 1a2197d7bf62437208643f750ee757b8b85e2db6 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 15 Feb 2016 18:13:10 +0000 Subject: Simplify room creation code --- synapse/handlers/room.py | 62 +++++++++++++++++------------------------- synapse/rest/client/v1/room.py | 18 ++---------- 2 files changed, 28 insertions(+), 52 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 505fb383ec..bdaa05e0b6 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -76,13 +76,11 @@ class RoomCreationHandler(BaseHandler): } @defer.inlineCallbacks - def create_room(self, user_id, room_id, config): + def create_room(self, requester, config): """ Creates a new room. Args: user_id (str): The ID of the user creating the new room. - room_id (str): The proposed ID for the new room. Can be None, in - which case one will be created for you. config (dict) : A dict of configuration options. Returns: The new room ID. @@ -90,6 +88,8 @@ class RoomCreationHandler(BaseHandler): SynapseError if the room ID was taken, couldn't be stored, or something went horribly wrong. """ + user_id = requester.user.to_string() + self.ratelimit(user_id) if "room_alias_name" in config: @@ -121,40 +121,28 @@ class RoomCreationHandler(BaseHandler): is_public = config.get("visibility", None) == "public" - if room_id: - # Ensure room_id is the correct type - room_id_obj = RoomID.from_string(room_id) - if not self.hs.is_mine(room_id_obj): - raise SynapseError(400, "Room id must be local") - - yield self.store.store_room( - room_id=room_id, - room_creator_user_id=user_id, - is_public=is_public - ) - else: - # autogen room IDs and try to create it. We may clash, so just - # try a few times till one goes through, giving up eventually. - attempts = 0 - room_id = None - while attempts < 5: - try: - random_string = stringutils.random_string(18) - gen_room_id = RoomID.create( - random_string, - self.hs.hostname, - ) - yield self.store.store_room( - room_id=gen_room_id.to_string(), - room_creator_user_id=user_id, - is_public=is_public - ) - room_id = gen_room_id.to_string() - break - except StoreError: - attempts += 1 - if not room_id: - raise StoreError(500, "Couldn't generate a room ID.") + # autogen room IDs and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + room_id = None + while attempts < 5: + try: + random_string = stringutils.random_string(18) + gen_room_id = RoomID.create( + random_string, + self.hs.hostname, + ) + yield self.store.store_room( + room_id=gen_room_id.to_string(), + room_creator_user_id=user_id, + is_public=is_public + ) + room_id = gen_room_id.to_string() + break + except StoreError: + attempts += 1 + if not room_id: + raise StoreError(500, "Couldn't generate a room ID.") if room_alias: directory_handler = self.hs.get_handlers().directory_handler diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f8cd746a88..5f5c26a91c 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -63,24 +63,12 @@ class RoomCreateRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - room_config = self.get_room_config(request) - info = yield self.make_room( - room_config, - requester.user, - None, - ) - room_config.update(info) - defer.returnValue((200, info)) - - @defer.inlineCallbacks - def make_room(self, room_config, auth_user, room_id): handler = self.handlers.room_creation_handler info = yield handler.create_room( - user_id=auth_user.to_string(), - room_id=room_id, - config=room_config + requester, self.get_room_config(request) ) - defer.returnValue(info) + + defer.returnValue((200, info)) def get_room_config(self, request): try: -- cgit 1.4.1 From 4bfb32f685cff919141a3fc0cd9179447febc765 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 15 Feb 2016 18:21:30 +0000 Subject: Branch off member and non member sends Unclean, needs tidy-up, but works --- synapse/handlers/directory.py | 2 +- synapse/handlers/federation.py | 4 +-- synapse/handlers/message.py | 66 +++++++++++++++++----------------- synapse/handlers/room.py | 80 ++++++++++++++++++++++++------------------ synapse/rest/client/v1/room.py | 21 ++++++++--- 5 files changed, 99 insertions(+), 74 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4efecb1ffd..e0a778e7ff 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -216,7 +216,7 @@ class DirectoryHandler(BaseHandler): aliases = yield self.store.get_aliases_for_room(room_id) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Aliases, "state_key": self.hs.hostname, "room_id": room_id, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index da55d43541..ac15f9e5dd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1658,7 +1658,7 @@ class FederationHandler(BaseHandler): self.auth.check(event, context.current_state) yield self._validate_keyserver(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context) + yield member_handler.send_membership_event(event, context, from_client=False) else: destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) yield self.replication_layer.forward_third_party_invite( @@ -1687,7 +1687,7 @@ class FederationHandler(BaseHandler): # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context) + yield member_handler.send_membership_event(event, context, from_client=False) @defer.inlineCallbacks def add_display_name_to_third_party_invite(self, event_dict, event, context): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a94fad1735..05dab172b8 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -216,7 +216,7 @@ class MessageHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def send_event(self, event, context, ratelimit=True, is_guest=False, room_hosts=None): + def send_nonmember_event(self, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -226,61 +226,63 @@ class MessageHandler(BaseHandler): ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. """ + if event.type == EventTypes.Member: + raise SynapseError( + 500, + "Tried to send member even through non-member codepath" + ) + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = context.current_state.get((event.type, event.state_key)) - if prev_state and event.user_id == prev_state.user_id: - prev_content = encode_canonical_json(prev_state.content) - next_content = encode_canonical_json(event.content) - if prev_content == next_content: - # Duplicate suppression for state updates with same sender - # and content. - defer.returnValue(prev_state) + prev_state = self.deduplicate_state_event(event, context) + if prev_state is not None: + defer.returnValue(prev_state) - if event.type == EventTypes.Member: - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event( - event, - context, - is_guest=is_guest, - ratelimit=ratelimit, - room_hosts=room_hosts - ) - else: - yield self.handle_new_client_event( - event=event, - context=context, - ratelimit=ratelimit, - ) + yield self.handle_new_client_event( + event=event, + context=context, + ratelimit=ratelimit, + ) if event.type == EventTypes.Message: presence = self.hs.get_handlers().presence_handler with PreserveLoggingContext(): presence.bump_presence_active_time(user) + def deduplicate_state_event(self, event, context): + prev_state = context.current_state.get((event.type, event.state_key)) + if prev_state and event.user_id == prev_state.user_id: + prev_content = encode_canonical_json(prev_state.content) + next_content = encode_canonical_json(event.content) + if prev_content == next_content: + return prev_state + return None + @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True, - token_id=None, txn_id=None, is_guest=False, - room_hosts=None): + def create_and_send_nonmember_event( + self, + event_dict, + ratelimit=True, + token_id=None, + txn_id=None + ): """ Creates an event, then sends it. - See self.create_event and self.send_event. + See self.create_event and self.send_nonmember_event. """ event, context = yield self.create_event( event_dict, token_id=token_id, txn_id=txn_id ) - yield self.send_event( + yield self.send_nonmember_event( event, context, ratelimit=ratelimit, - is_guest=is_guest, - room_hosts=room_hosts, ) defer.returnValue(event) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bdaa05e0b6..5d4e87b3b0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -179,13 +179,24 @@ class RoomCreationHandler(BaseHandler): ) msg_handler = self.hs.get_handlers().message_handler + room_member_handler = self.hs.get_handlers().room_member_handler for event in creation_events: - yield msg_handler.create_and_send_event(event, ratelimit=False) + if event["type"] == EventTypes.Member: + # TODO(danielwh): This is hideous + yield room_member_handler.update_membership( + requester, + user, + room_id, + "join", + ratelimit=False, + ) + else: + yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False) if "name" in config: name = config["name"] - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Name, "room_id": room_id, "sender": user_id, @@ -195,7 +206,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Topic, "room_id": room_id, "sender": user_id, @@ -204,13 +215,13 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) for invitee in invite_list: - yield msg_handler.create_and_send_event({ - "type": EventTypes.Member, - "state_key": invitee, - "room_id": room_id, - "sender": user_id, - "content": {"membership": Membership.INVITE}, - }, ratelimit=False) + room_member_handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", + ratelimit=False, + ) for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] @@ -222,7 +233,7 @@ class RoomCreationHandler(BaseHandler): medium, address, id_server, - token_id=None, + requester, txn_id=None, ) @@ -439,12 +450,14 @@ class RoomMemberHandler(BaseHandler): errcode=Codes.BAD_STATE ) - yield msg_handler.send_event( + member_handler = self.hs.get_handlers().room_member_handler + yield member_handler.send_membership_event( event, context, - ratelimit=ratelimit, is_guest=requester.is_guest, + ratelimit=ratelimit, room_hosts=room_hosts, + from_client=True, ) if action == "forget": @@ -452,7 +465,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def send_membership_event( - self, event, context, is_guest=False, room_hosts=None, ratelimit=True + self, event, context, is_guest=False, room_hosts=None, ratelimit=True, from_client=True, ): """ Change the membership status of a user in a room. @@ -461,6 +474,16 @@ class RoomMemberHandler(BaseHandler): Raises: SynapseError if there was a problem changing the membership. """ + if from_client: + user = UserID.from_string(event.sender) + + assert self.hs.is_mine(user), "User must be our own: %s" % (user,) + + if event.is_state(): + prev_state = self.hs.get_handlers().message_handler.deduplicate_state_event(event, context) + if prev_state is not None: + return + target_user_id = event.state_key target_user = UserID.from_string(event.state_key) @@ -549,13 +572,11 @@ class RoomMemberHandler(BaseHandler): room_id, event.user_id ) - defer.returnValue({"room_id": room_id}) return # FIXME: This isn't idempotency. if prev_state and prev_state.membership == event.membership: # double same action, treat this event as a NOOP. - defer.returnValue({}) return yield self.handle_new_client_event( @@ -569,8 +590,6 @@ class RoomMemberHandler(BaseHandler): user = UserID.from_string(event.user_id) user_left_room(self.distributor, user, event.room_id) - defer.returnValue({"room_id": room_id}) - @defer.inlineCallbacks def lookup_room_alias(self, room_alias): """ @@ -657,7 +676,7 @@ class RoomMemberHandler(BaseHandler): medium, address, id_server, - token_id, + requester, txn_id ): invitee = yield self._lookup_3pid( @@ -665,19 +684,12 @@ class RoomMemberHandler(BaseHandler): ) if invitee: - # make sure it looks like a user ID; it'll throw if it's invalid. - UserID.from_string(invitee) - yield self.hs.get_handlers().message_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": { - "membership": unicode("invite") - }, - "room_id": room_id, - "sender": inviter.to_string(), - "state_key": invitee, - }, - token_id=token_id, + handler = self.hs.get_handlers().room_member_handler + yield handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", txn_id=txn_id, ) else: @@ -687,7 +699,7 @@ class RoomMemberHandler(BaseHandler): address, room_id, inviter, - token_id, + requester.access_token_id, txn_id=txn_id ) @@ -798,7 +810,7 @@ class RoomMemberHandler(BaseHandler): ) ) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event( + yield msg_handler.create_and_send_nonmember_event( { "type": EventTypes.ThirdPartyInvite, "content": { diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5f5c26a91c..179fe9a010 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -150,10 +150,21 @@ class RoomStateEventRestServlet(ClientV1RestServlet): event_dict["state_key"] = state_key msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_event( - event_dict, token_id=requester.access_token_id, txn_id=txn_id, + event, context = yield msg_handler.create_event( + event_dict, + token_id=requester.access_token_id, + txn_id=txn_id, ) + if event_type == EventTypes.Member: + yield self.handlers.room_member_handler.send_membership_event( + event, + context, + is_guest=requester.is_guest, + ) + else: + yield msg_handler.send_nonmember_event(event, context) + defer.returnValue((200, {})) @@ -171,7 +182,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): content = _parse_json(request) msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_event( + event = yield msg_handler.create_and_send_nonmember_event( { "type": event_type, "content": content, @@ -434,7 +445,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content["medium"], content["address"], content["id_server"], - requester.access_token_id, + requester, txn_id ) defer.returnValue((200, {})) @@ -490,7 +501,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): content = _parse_json(request) msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_event( + event = yield msg_handler.create_and_send_nonmember_event( { "type": EventTypes.Redaction, "content": content, -- cgit 1.4.1 From 458782bf67ef7c188af752b0f455d4a0f9f4cdd5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 16 Feb 2016 18:00:30 +0000 Subject: Fix typo in request validation for adding push rules. --- synapse/rest/client/v1/push_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 96633a176c..7766b8be1d 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -400,7 +400,7 @@ def _filter_ruleset_with_path(ruleset, path): def _priority_class_from_spec(spec): if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) + raise InvalidRuleException("Unknown template: %s" % (spec['template'])) pc = PRIORITY_CLASS_MAP[spec['template']] if spec['scope'] == 'device': -- cgit 1.4.1 From 71d5d2c669139305b829bdfdbd403a0b8a52b66f Mon Sep 17 00:00:00 2001 From: Patrik Oldsberg Date: Wed, 17 Feb 2016 11:52:30 +0100 Subject: client/v1/room: include event_id in response to state event PUT, in accordance with the spec Signed-off-by: Patrik Oldsberg --- synapse/rest/client/v1/room.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index d3c1b359a2..24706f9383 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -150,11 +150,11 @@ class RoomStateEventRestServlet(ClientV1RestServlet): event_dict["state_key"] = state_key msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_event( + event = yield msg_handler.create_and_send_event( event_dict, token_id=requester.access_token_id, txn_id=txn_id, ) - defer.returnValue((200, {})) + defer.returnValue((200, {"event_id": event.event_id})) # TODO: Needs unit testing for generic events + feedback -- cgit 1.4.1 From e5999bfb1a4aab56acecb59ed6d068442f5b11a0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Feb 2016 17:10:40 +0000 Subject: Initial cut --- synapse/handlers/events.py | 43 +- synapse/handlers/message.py | 14 +- synapse/handlers/presence.py | 1662 ++++++++------------ synapse/handlers/profile.py | 3 + synapse/handlers/sync.py | 22 + synapse/rest/client/v1/presence.py | 26 +- synapse/rest/client/v1/room.py | 18 +- synapse/rest/client/v2_alpha/receipts.py | 3 + synapse/rest/client/v2_alpha/sync.py | 16 +- synapse/storage/__init__.py | 50 +- synapse/storage/prepare_database.py | 2 +- synapse/storage/presence.py | 170 +- .../storage/schema/delta/30/presence_stream.sql | 30 + synapse/storage/util/id_generators.py | 4 +- synapse/util/__init__.py | 2 +- tests/utils.py | 4 +- 16 files changed, 933 insertions(+), 1136 deletions(-) create mode 100644 synapse/storage/schema/delta/30/presence_stream.sql (limited to 'synapse/rest/client') diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4933c31c19..72a31a9755 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,8 @@ from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event from synapse.util.logcontext import preserve_context_over_fn +from synapse.api.constants import Membership, EventTypes +from synapse.events import EventBase from ._base import BaseHandler @@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) + presence_handler = self.hs.get_handlers().presence_handler - try: - if affect_presence: - yield self.started_stream(auth_user) - + context = yield presence_handler.user_syncing( + auth_user_id, affect_presence=affect_presence, + ) + with context: if timeout: # If they've set a timeout set a minimum limit. timeout = max(timeout, 500) @@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler): is_guest=is_guest, explicit_room_id=room_id ) + # When the user joins a new room, or another user joins a currently + # joined room, we need to send down presence for those users. + to_add = [] + for event in events: + if not isinstance(event, EventBase): + continue + if event.type == EventTypes.Member: + if event.membership != Membership.JOIN: + continue + # Send down presence. + if event.state_key == auth_user_id: + # Send down presence for everyone in the room. + users = yield self.store.get_users_in_room(event.room_id) + states = yield presence_handler.get_states( + users, + as_event=True, + ) + to_add.extend(states) + else: + + ev = yield presence_handler.get_state( + UserID.from_string(event.state_key), + as_event=True, + ) + to_add.append(ev) + + events.extend(to_add) + time_now = self.clock.time_msec() chunks = [ @@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) - finally: - if affect_presence: - self.stopped_stream(auth_user) - class EventHandler(BaseHandler): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 82c8cb5f0c..77894d9132 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,7 +21,6 @@ from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.types import UserID, RoomStreamToken, StreamToken @@ -254,8 +253,7 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Message: presence = self.hs.get_handlers().presence_handler - with PreserveLoggingContext(): - presence.bump_presence_active_time(user) + yield presence.bump_presence_active_time(user) @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, @@ -660,10 +658,6 @@ class MessageHandler(BaseHandler): room_id=room_id, ) - # TODO(paul): I wish I was called with user objects not user_id - # strings... - auth_user = UserID.from_string(user_id) - # TODO: These concurrently time_now = self.clock.time_msec() state = [ @@ -688,13 +682,11 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): states = yield presence_handler.get_states( - target_users=[UserID.from_string(m.user_id) for m in room_members], - auth_user=auth_user, + [m.user_id for m in room_members], as_event=True, - check_auth=False, ) - defer.returnValue(states.values()) + defer.returnValue(states) @defer.inlineCallbacks def get_receipts(): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b61394f2b5..26f2e669ce 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -13,13 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +"""This module is responsible for keeping track of presence status of local +and remote users. -from synapse.api.errors import SynapseError, AuthError +The methods that define policy are: + - PresenceHandler._update_states + - PresenceHandler._handle_timeouts + - should_notify +""" + +from twisted.internet import defer, reactor +from contextlib import contextmanager + +from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState +from synapse.storage.presence import UserPresenceState -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function +from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID import synapse.metrics @@ -33,33 +45,24 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -# Don't bother bumping "last active" time if it differs by less than 60 seconds +# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them +# "currently_active" LAST_ACTIVE_GRANULARITY = 60 * 1000 -# Keep no more than this number of offline serial revisions -MAX_OFFLINE_SERIALS = 1000 - - -# TODO(paul): Maybe there's one of these I can steal from somewhere -def partition(l, func): - """Partition the list by the result of func applied to each element.""" - ret = {} +# How long to wait until a new /events or /sync request before assuming +# the client has gone. +SYNC_ONLINE_TIMEOUT = 30 * 1000 - for x in l: - key = func(x) - if key not in ret: - ret[key] = [] - ret[key].append(x) +# How long to wait before marking the user as idle. Compared against last active +IDLE_TIMER = 5 * 60 * 1000 - return ret +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 30 * 60 * 1000 +# How often to resend presence to remote servers +FEDERATION_PING_INTERVAL = 25 * 60 * 1000 -def partitionbool(l, func): - def boolfunc(x): - return bool(func(x)) - - ret = partition(l, boolfunc) - return ret.get(True, []), ret.get(False, []) +assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER def user_presence_changed(distributor, user, statuscache): @@ -72,45 +75,13 @@ def collect_presencelike_data(distributor, user, content): class PresenceHandler(BaseHandler): - STATE_LEVELS = { - PresenceState.OFFLINE: 0, - PresenceState.UNAVAILABLE: 1, - PresenceState.ONLINE: 2, - PresenceState.FREE_FOR_CHAT: 3, - } - def __init__(self, hs): super(PresenceHandler, self).__init__(hs) - - self.homeserver = hs - + self.hs = hs self.clock = hs.get_clock() - - distributor = hs.get_distributor() - distributor.observe("registered_user", self.registered_user) - - distributor.observe( - "started_user_eventstream", self.started_user_eventstream - ) - distributor.observe( - "stopped_user_eventstream", self.stopped_user_eventstream - ) - - distributor.observe("user_joined_room", self.user_joined_room) - - distributor.declare("collect_presencelike_data") - - distributor.declare("changed_presencelike_data") - distributor.observe( - "changed_presencelike_data", self.changed_presencelike_data - ) - - # outbound signal from the presence module to advertise when a user's - # presence has changed - distributor.declare("user_presence_changed") - - self.distributor = distributor - + self.store = hs.get_datastore() + self.wheel_timer = WheelTimer() + self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() self.federation.register_edu_handler( @@ -138,348 +109,574 @@ class PresenceHandler(BaseHandler): ) ) - # IN-MEMORY store, mapping local userparts to sets of local users to - # be informed of state changes. - self._local_pushmap = {} - # map local users to sets of remote /domain names/ who are interested - # in them - self._remote_sendmap = {} - # map remote users to sets of local users who're interested in them - self._remote_recvmap = {} - # list of (serial, set of(userids)) tuples, ordered by serial, latest - # first - self._remote_offline_serials = [] - - # map any user to a UserPresenceCache - self._user_cachemap = {} - self._user_cachemap_latest_serial = 0 - - # map room_ids to the latest presence serial for a member of that - # room - self._room_serials = {} - - metrics.register_callback( - "userCachemap:size", - lambda: len(self._user_cachemap), + distributor = hs.get_distributor() + distributor.observe("user_joined_room", self.user_joined_room) + + active_presence = self.store.take_presence_startup_info() + + # A dictionary of the current state of users. This is prefilled with + # non-offline presence from the DB. We should fetch from the DB if + # we can't find a users presence in here. + self.user_to_current_state = { + state.user_id: state + for state in active_presence + } + + now = self.clock.time_msec() + for state in active_presence: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_active + IDLE_TIMER, + ) + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_user_sync + SYNC_ONLINE_TIMEOUT, + ) + if self.hs.is_mine_id(state.user_id): + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update + FEDERATION_PING_INTERVAL, + ) + else: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update + FEDERATION_TIMEOUT, + ) + + # Set of users who have presence in the `user_to_current_state` that + # have not yet been persisted + self.unpersisted_users_changes = set() + + reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + + self.serial_to_user = {} + self._next_serial = 1 + + # Keeps track of the number of *ongoing* syncs. While this is non zero + # a user will never go offline. + self.user_to_num_current_syncs = {} + + # Start a LoopingCall in 30s that fires every 5s. + # The initial delay is to allow disconnected clients a chance to + # reconnect before we treat them as offline. + self.clock.call_later( + 0 * 1000, + self.clock.looping_call, + self._handle_timeouts, + 5000, ) - def _get_or_make_usercache(self, user): - """If the cache entry doesn't exist, initialise a new one.""" - if user not in self._user_cachemap: - self._user_cachemap[user] = UserPresenceCache() - return self._user_cachemap[user] - - def _get_or_offline_usercache(self, user): - """If the cache entry doesn't exist, return an OFFLINE one but do not - store it into the cache.""" - if user in self._user_cachemap: - return self._user_cachemap[user] - else: - return UserPresenceCache() + @defer.inlineCallbacks + def _on_shutdown(self): + """Gets called when shutting down. This lets us persist any updates that + we haven't yet persisted, e.g. updates that only changes some internal + timers. This allows changes to persist across startup without having to + persist every single change. + + If this does not run it simply means that some of the timers will fire + earlier than they should when synapse is restarted. This affect of this + is some spurious presence changes that will self-correct. + """ + logger.info( + "Performing _on_shutdown. Persiting %d unpersisted changes", + len(self.user_to_current_state) + ) - def registered_user(self, user): - return self.store.create_presence(user.localpart) + if self.unpersisted_users_changes: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ]) + logger.info("Finished _on_shutdown") @defer.inlineCallbacks - def is_presence_visible(self, observer_user, observed_user): - assert(self.hs.is_mine(observed_user)) + def _update_states(self, new_states): + """Updates presence of users. Sets the appropriate timeouts. Pokes + the notifier and federation if and only if the changed presence state + should be sent to clients/servers. + """ + now = self.clock.time_msec() - if observer_user == observed_user: - defer.returnValue(True) + # NOTE: We purposefully don't yield between now and when we've + # calculated what we want to do with the new states, to avoid races. - if (yield self.store.user_rooms_intersect( - [u.to_string() for u in observer_user, observed_user])): - defer.returnValue(True) + to_notify = {} # Changes we want to notify everyone about + to_federation_ping = {} # These need sending keep-alives + for new_state in new_states: + user_id = new_state.user_id + prev_state = self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) - if (yield self.store.is_presence_visible( - observed_localpart=observed_user.localpart, - observer_userid=observer_user.to_string())): - defer.returnValue(True) + # If the users are ours then we want to set up a bunch of timers + # to time things out. + if self.hs.is_mine_id(user_id): + if new_state.state == PresenceState.ONLINE: + # Idle timer + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_active + IDLE_TIMER + ) - defer.returnValue(False) + if new_state.state != PresenceState.OFFLINE: + # User has stopped syncing + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_user_sync + SYNC_ONLINE_TIMEOUT + ) - @defer.inlineCallbacks - def get_state(self, target_user, auth_user, as_event=False, check_auth=True): - """Get the current presence state of the given user. + last_federate = new_state.last_federation_update + if now - last_federate > FEDERATION_PING_INTERVAL: + # Been a while since we've poked remote servers + new_state = new_state.copy_and_replace( + last_federation_update=now, + ) + to_federation_ping[user_id] = new_state - Args: - target_user (UserID): The user whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_user` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + else: + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_federation_update + FEDERATION_TIMEOUT + ) - Returns: - dict: The presence state of the `target_user`, whose format depends - on the `as_event` argument. - """ - if self.hs.is_mine(target_user): - if check_auth: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=target_user + if new_state.state == PresenceState.ONLINE: + currently_active = now - new_state.last_active < LAST_ACTIVE_GRANULARITY + new_state = new_state.copy_and_replace( + currently_active=currently_active, ) - if not visible: - raise SynapseError(404, "Presence information not visible") + # Check whether the change was something worth notifying about + if should_notify(prev_state, new_state): + new_state.copy_and_replace( + last_federation_update=now, + ) + to_notify[user_id] = new_state - if target_user in self._user_cachemap: - state = self._user_cachemap[target_user].get_state() - else: - state = yield self.store.get_presence_state(target_user.localpart) - if "mtime" in state: - del state["mtime"] - state["presence"] = state.pop("state") - else: - # TODO(paul): Have remote server send us permissions set - state = self._get_or_offline_usercache(target_user).get_state() + self.user_to_current_state[user_id] = new_state + + # TODO: We should probably ensure there are no races hereafter - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") + if to_notify: + yield self._persist_and_notify(to_notify.values()) + + self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes -= set(to_notify.keys()) + + to_federation_ping = { + user_id: state for user_id, state in to_federation_ping.items() + if user_id not in to_notify + } + if to_federation_ping: + _, _, hosts_to_states = yield self._get_interested_parties( + to_federation_ping.values() ) - if as_event: - content = state + self._push_to_remotes(hosts_to_states) - content["user_id"] = target_user.to_string() + def _handle_timeouts(self): + """Checks the presence of users that have timed out and updates as + appropriate. + """ + now = self.clock.time_msec() - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = self.wheel_timer.fetch(now) - defer.returnValue({"type": "m.presence", "content": content}) - else: - defer.returnValue(state) + changes = {} # Actual changes we need to notify people about - @defer.inlineCallbacks - def get_states(self, target_users, auth_user, as_event=False, check_auth=True): - """A batched version of the `get_state` method that accepts a list of - `target_users` + for user_id in set(users_to_check): + state = self.user_to_current_state.get(user_id, None) + if not state: + continue - Args: - target_users (list): The list of UserID's whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_users` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + if self.hs.is_mine_id(user_id): + if state.state == PresenceState.OFFLINE: + continue - Returns: - dict: A mapping from user -> presence_state + if state.state == PresenceState.ONLINE: + if now - state.last_active > IDLE_TIMER: + # Currently online, but last activity ages ago so auto + # idle + changes[user_id] = state.copy_and_replace( + state=PresenceState.UNAVAILABLE, + ) + elif now - state.last_active > LAST_ACTIVE_GRANULARITY: + # So that we send down a notification that we've + # stopped updating. + changes[user_id] = state + + if now - state.last_federation_update > FEDERATION_PING_INTERVAL: + # Need to send ping to other servers to ensure they don't + # timeout and set us to offline + changes[user_id] = state + + # If there are have been no sync for a while (and none ongoing), + # set presence to offline + if not self.user_to_num_current_syncs.get(user_id, 0): + if now - state.last_user_sync > SYNC_ONLINE_TIMEOUT: + changes[user_id] = state.copy_and_replace( + state=PresenceState.OFFLINE, + ) + else: + # We expect to be poked occaisonally by the other side. + # This is to protect against forgetful/buggy servers, so that + # no one gets stuck online forever. + if now - state.last_federation_update > FEDERATION_TIMEOUT: + if state.state != PresenceState.OFFLINE: + # The other side seems to have disappeared. + changes[user_id] = state.copy_and_replace( + state=PresenceState.OFFLINE, + ) + + preserve_fn(self._update_states)(changes.values()) + + @defer.inlineCallbacks + def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. """ - local_users, remote_users = partitionbool( - target_users, - lambda u: self.hs.is_mine(u) - ) + user_id = user.to_string() - if check_auth: - for user in local_users: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=user - ) + prev_state = yield self.current_state_for_user(user_id) - if not visible: - raise SynapseError(404, "Presence information not visible") + yield self._update_states([prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active=self.clock.time_msec(), + )]) - results = {} - if local_users: - for user in local_users: - if user in self._user_cachemap: - results[user] = self._user_cachemap[user].get_state() + @defer.inlineCallbacks + def user_syncing(self, user_id, affect_presence=True): + """Returns a context manager that should surround any stream requests + from the user. - local_to_user = {u.localpart: u for u in local_users} + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. - states = yield self.store.get_presence_states( - [u.localpart for u in local_users if u not in results] - ) + Args: + user_id (str) + affect_presence (bool): If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + if affect_presence: + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 + + prev_state = yield self.current_state_for_user(user_id) + if prev_state.state == PresenceState.OFFLINE: + # If they're currently offline then bring them online, otherwise + # just update the last sync times. + yield self._update_states([prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active=self.clock.time_msec(), + last_user_sync=self.clock.time_msec(), + )]) + else: + yield self._update_states([prev_state.copy_and_replace( + last_user_sync=self.clock.time_msec(), + )]) - for local_part, state in states.items(): - if state is None: - continue - res = {"presence": state["state"]} - if "status_msg" in state and state["status_msg"]: - res["status_msg"] = state["status_msg"] - results[local_to_user[local_part]] = res - - for user in remote_users: - # TODO(paul): Have remote server send us permissions set - results[user] = self._get_or_offline_usercache(user).get_state() - - for state in results.values(): - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) + @defer.inlineCallbacks + def _end(): + if affect_presence: + self.user_to_num_current_syncs[user_id] -= 1 - if as_event: - for user, state in results.items(): - content = state - content["user_id"] = user.to_string() + prev_state = yield self.current_state_for_user(user_id) + yield self._update_states([prev_state.copy_and_replace( + last_user_sync=self.clock.time_msec(), + )]) - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + @contextmanager + def _user_syncing(): + try: + yield + finally: + preserve_fn(_end)() - results[user] = {"type": "m.presence", "content": content} + defer.returnValue(_user_syncing()) - defer.returnValue(results) + @defer.inlineCallbacks + def current_state_for_user(self, user_id): + """Get the current presence state for a user. + """ + res = yield self.current_state_for_users([user_id]) + defer.returnValue(res[user_id]) @defer.inlineCallbacks - @log_function - def set_state(self, target_user, auth_user, state): - # return - # TODO (erikj): Turn this back on. Why did we end up sending EDUs - # everywhere? + def current_state_for_users(self, user_ids): + """Get the current presence state for multiple users. - if not self.hs.is_mine(target_user): - raise SynapseError(400, "User is not hosted on this Home Server") + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = yield self.store.get_presence_for_users(missing) + states.update({state.user_id: state for state in res}) + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + states.update({ + user_id: UserPresenceState.default(user_id) + for user_id in missing + }) - if target_user != auth_user: - raise AuthError(400, "Cannot set another user's presence") + defer.returnValue(states) - if "status_msg" not in state: - state["status_msg"] = None + @defer.inlineCallbacks + def _get_interested_parties(self, states): + """Given a list of states return which entities (rooms, users, servers) + are interested in the given states. - for k in state.keys(): - if k not in ("presence", "status_msg"): - raise SynapseError( - 400, "Unexpected presence state key '%s'" % (k,) - ) + Returns: + 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`, + with each item being a dict of `entity_name` -> `[UserPresenceState]` + """ + room_ids_to_states = {} + users_to_states = {} + for state in states: + events = yield self.store.get_rooms_for_user(state.user_id) + for e in events: + room_ids_to_states.setdefault(e.room_id, []).append(state) + + plist = yield self.store.get_presence_list_observers_accepted(state.user_id) + for u in plist: + users_to_states.setdefault(u, []).append(state) + + # Always notify self + users_to_states.setdefault(state.user_id, []).append(state) + + hosts_to_states = {} + for room_id, states in room_ids_to_states.items(): + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(states) - if state["presence"] not in self.STATE_LEVELS: - raise SynapseError(400, "'%s' is not a valid presence state" % ( - state["presence"], - )) + for user_id, states in users_to_states.items(): + host = UserID.from_string(user_id).domain + hosts_to_states.setdefault(host, []).extend(states) - logger.debug("Updating presence state of %s to %s", - target_user.localpart, state["presence"]) + # TODO: de-dup hosts_to_states, as a single host might have multiple + # of same presence - state_to_store = dict(state) - state_to_store["state"] = state_to_store.pop("presence") + defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states)) + + @defer.inlineCallbacks + def _persist_and_notify(self, states): + """Persist states in the database, poke the notifier and send to + interested remote servers + """ + stream_id, max_token = yield self.store.update_presence(states) - statuscache = self._get_or_offline_usercache(target_user) - was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] - now_level = self.STATE_LEVELS[state["presence"]] + parties = yield self._get_interested_parties(states) + room_ids_to_states, users_to_states, hosts_to_states = parties - yield self.store.set_presence_state( - target_user.localpart, state_to_store + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] ) - yield collect_presencelike_data(self.distributor, target_user, state) - if now_level > was_level: - state["last_active"] = self.clock.time_msec() + self._push_to_remotes(hosts_to_states) + + def _push_to_remotes(self, hosts_to_states): + """Sends state updates to remote servers. + + Args: + hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` + """ + now = self.clock.time_msec() + for host, states in hosts_to_states.items(): + self.federation.send_edu( + destination=host, + edu_type="m.presence", + content={ + "push": [ + _format_user_presence_state(state, now) + for state in states + ] + } + ) + + @defer.inlineCallbacks + def incoming_presence(self, origin, content): + """Called when we receive a `m.presence` EDU from a remote server. + """ + now = self.clock.time_msec() + updates = [] + for push in content.get("push", []): + # A "push" contains a list of presence that we are probably interested + # in. + # TODO: Actually check if we're interested, rather than blindly + # accepting presence updates. + user_id = push.get("user_id", None) + if not user_id: + logger.info( + "Got presence update from %r with no 'user_id': %r", + origin, push, + ) + continue - now_online = state["presence"] != PresenceState.OFFLINE - was_polling = target_user in self._user_cachemap + presence_state = push.get("presence", None) + if not presence_state: + logger.info( + "Got presence update from %r with no 'presence_state': %r", + origin, push, + ) + continue - if now_online and not was_polling: - yield self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - yield self.stop_polling_presence(target_user) + new_fields = { + "state": presence_state, + "last_federation_update": now, + } - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - yield self.changed_presencelike_data(target_user, state) + last_active_ago = push.get("last_active_ago", None) + if last_active_ago is not None: + new_fields["last_active"] = now - last_active_ago - def bump_presence_active_time(self, user, now=None): - if now is None: - now = self.clock.time_msec() + new_fields["status_msg"] = push.get("status_msg", None) - prev_state = self._get_or_make_usercache(user) - if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: - return + prev_state = yield self.current_state_for_user(user_id) + updates.append(prev_state.copy_and_replace(**new_fields)) - with PreserveLoggingContext(): - self.changed_presencelike_data(user, {"last_active": now}) + if updates: + yield self._update_states(updates) - def get_joined_rooms_for_user(self, user): - """Get the list of rooms a user is joined to. + @defer.inlineCallbacks + def get_state(self, target_user, as_event=False): + results = yield self.get_states( + [target_user.to_string()], + as_event=as_event, + ) + + defer.returnValue(results[0]) + + @defer.inlineCallbacks + def get_states(self, target_user_ids, as_event=False): + """Get the presence state for users. Args: - user(UserID): The user. + target_user_ids (list) + as_event (bool): Whether to format it as a client event or not. + Returns: - A Deferred of a list of room id strings. + list """ - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_joined_rooms_for_user(user) - def get_joined_users_for_room_id(self, room_id): - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_room_members(room_id) + updates = yield self.current_state_for_users(target_user_ids) + updates = updates.values() - @defer.inlineCallbacks - def changed_presencelike_data(self, user, state): - """Updates the presence state of a local user. + for user_id in set(target_user_ids) - set(u.user_id for u in updates): + updates.append(UserPresenceState.default(user_id)) - Args: - user(UserID): The user being updated. - state(dict): The new presence state for the user. - Returns: - A Deferred + now = self.clock.time_msec() + if as_event: + defer.returnValue([ + { + "type": "m.presence", + "content": _format_user_presence_state(state, now), + } + for state in updates + ]) + else: + defer.returnValue([ + _format_user_presence_state(state, now) for state in updates + ]) + + @defer.inlineCallbacks + def set_state(self, target_user, state): + """Set the presence state of the user. """ - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache(user, state) - yield self.push_presence(user, statuscache=statuscache) + status_msg = state.get("status_msg", None) + presence = state["presence"] - @log_function - def started_user_eventstream(self, user): - # TODO(paul): Use "last online" state - return self.set_state(user, user, {"presence": PresenceState.ONLINE}) + user_id = target_user.to_string() - @log_function - def stopped_user_eventstream(self, user): - # TODO(paul): Save current state as "last online" state - return self.set_state(user, user, {"presence": PresenceState.OFFLINE}) + prev_state = yield self.current_state_for_user(user_id) + + new_fields = { + "state": presence, + "status_msg": status_msg + } + + if presence == PresenceState.ONLINE: + new_fields["last_active"] = self.clock.time_msec() + + yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @defer.inlineCallbacks def user_joined_room(self, user, room_id): - """Called via the distributor whenever a user joins a room. - Notifies the new member of the presence of the current members. - Notifies the current members of the room of the new member's presence. - - Args: - user(UserID): The user who joined the room. - room_id(str): The room id the user joined. + """Called (via the distributor) when a user joins a room. This funciton + sends presence updates to servers, either: + 1. the joining user is a local user and we send their presence to + all servers in the room. + 2. the joining user is a remote user and so we send presence for all + local users in the room. """ + # We only need to send presence to servers that don't have it yet. We + # don't need to send to local clients here, as that is done as part + # of the event stream/sync. + # TODO: Only send to servers not already in the room. if self.hs.is_mine(user): - # No actual update but we need to bump the serial anyway for the - # event source - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache( - user, room_ids=[room_id] - ) - self.push_update_to_local_and_remote( - observed_user=user, - room_ids=[room_id], - statuscache=statuscache, - ) + state = yield self.current_state_for_user(user.to_string()) - # We also want to tell them about current presence of people. - curr_users = yield self.get_joined_users_for_room_id(room_id) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + self._push_to_remotes({host: (state,) for host in hosts}) + else: + user_ids = yield self.store.get_users_in_room(room_id) + user_ids = filter(self.hs.is_mine_id, user_ids) - for local_user in [c for c in curr_users if self.hs.is_mine(c)]: - statuscache = yield self.update_presence_cache( - local_user, room_ids=[room_id], add_to_cache=False - ) + states = yield self.current_state_for_users(user_ids) - with PreserveLoggingContext(): - self.push_update_to_local_and_remote( - observed_user=local_user, - users_to_push=[user], - statuscache=statuscache, - ) + self._push_to_remotes({user.domain: states.values()}) @defer.inlineCallbacks - def send_presence_invite(self, observer_user, observed_user): - """Request the presence of a local or remote user for a local user""" + def get_presence_list(self, observer_user, accepted=None): + """Returns the presence for all users in their presence list. + """ if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") + presence_list = yield self.store.get_presence_list( + observer_user.localpart, accepted=accepted + ) + + results = yield self.get_states( + target_user_ids=[row["observed_user_id"] for row in presence_list], + as_event=False, + ) + + is_accepted = { + row["observed_user_id"]: row["accepted"] for row in presence_list + } + + for result in results: + result.update({ + "accepted": is_accepted, + }) + + defer.returnValue(results) + + @defer.inlineCallbacks + def send_presence_invite(self, observer_user, observed_user): + """Sends a presence invite. + """ yield self.store.add_presence_list_pending( observer_user.localpart, observed_user.to_string() ) @@ -496,60 +693,41 @@ class PresenceHandler(BaseHandler): } ) - @defer.inlineCallbacks - def _should_accept_invite(self, observed_user, observer_user): - if not self.hs.is_mine(observed_user): - defer.returnValue(False) - - row = yield self.store.has_presence_state(observed_user.localpart) - if not row: - defer.returnValue(False) - - # TODO(paul): Eventually we'll ask the user's permission for this - # before accepting. For now just accept any invite request - defer.returnValue(True) - @defer.inlineCallbacks def invite_presence(self, observed_user, observer_user): - """Handles a m.presence_invite EDU. A remote or local user has - requested presence updates for a local user. If the invite is accepted - then allow the local or remote user to see the presence of the local - user. - - Args: - observed_user(UserID): The local user whose presence is requested. - observer_user(UserID): The remote or local user requesting presence. + """Handles new presence invites. """ - accept = yield self._should_accept_invite(observed_user, observer_user) - - if accept: - yield self.store.allow_presence_visible( - observed_user.localpart, observer_user.to_string() - ) + if not self.hs.is_mine(observed_user): + raise SynapseError(400, "User is not hosted on this Home Server") + # TODO: Don't auto accept if self.hs.is_mine(observer_user): - if accept: - yield self.accept_presence(observed_user, observer_user) - else: - yield self.deny_presence(observed_user, observer_user) + yield self.accept_presence(observed_user, observer_user) else: - edu_type = "m.presence_accept" if accept else "m.presence_deny" - - yield self.federation.send_edu( + self.federation.send_edu( destination=observer_user.domain, - edu_type=edu_type, + edu_type="m.presence_accept", content={ "observed_user": observed_user.to_string(), "observer_user": observer_user.to_string(), } ) + state_dict = yield self.get_state(observed_user, as_event=False) + + self.federation.send_edu( + destination=observer_user.domain, + edu_type="m.presence", + content={ + "push": [state_dict] + } + ) + @defer.inlineCallbacks def accept_presence(self, observed_user, observer_user): """Handles a m.presence_accept EDU. Mark a presence invite from a local or remote user as accepted in a local user's presence list. Starts polling for presence updates from the local or remote user. - Args: observed_user(UserID): The user to update in the presence list. observer_user(UserID): The owner of the presence list to update. @@ -558,15 +736,10 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - yield self.start_polling_presence( - observer_user, target_user=observed_user - ) - @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): """Handle a m.presence_deny EDU. Removes a local or remote user from a local user's presence list. - Args: observed_user(UserID): The local or remote user to remove from the list. @@ -584,7 +757,6 @@ class PresenceHandler(BaseHandler): def drop(self, observed_user, observer_user): """Remove a local or remote user from a local user's presence list and unsubscribe the local user from updates that user. - Args: observed_user(UserId): The local or remote user to remove from the list. @@ -599,710 +771,138 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - self.stop_polling_presence( - observer_user, target_user=observed_user - ) - - @defer.inlineCallbacks - def get_presence_list(self, observer_user, accepted=None): - """Get the presence list for a local user. The retured list includes - the current presence state for each user listed. - - Args: - observer_user(UserID): The local user whose presence list to fetch. - accepted(bool or None): If not none then only include users who - have or have not accepted the presence invite request. - Returns: - A Deferred list of presence state events. - """ - if not self.hs.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - presence_list = yield self.store.get_presence_list( - observer_user.localpart, accepted=accepted - ) - - results = [] - for row in presence_list: - observed_user = UserID.from_string(row["observed_user_id"]) - result = { - "observed_user": observed_user, "accepted": row["accepted"] - } - result.update( - self._get_or_offline_usercache(observed_user).get_state() - ) - if "last_active" in result: - result["last_active_ago"] = int( - self.clock.time_msec() - result.pop("last_active") - ) - results.append(result) - - defer.returnValue(results) - - @defer.inlineCallbacks - @log_function - def start_polling_presence(self, user, target_user=None, state=None): - """Subscribe a local user to presence updates from a local or remote - user. If no target_user is supplied then subscribe to all users stored - in the presence list for the local user. - - Additonally this pushes the current presence state of this user to all - target_users. That state can be provided directly or will be read from - the stored state for the local user. - - Also this attempts to notify the local user of the current state of - any local target users. - - Args: - user(UserID): The local user that whishes for presence updates. - target_user(UserID): The local or remote user whose updates are - wanted. - state(dict): Optional presence state for the local user. - """ - logger.debug("Start polling for presence from %s", user) - - if target_user: - target_users = set([target_user]) - room_ids = [] - else: - presence = yield self.store.get_presence_list( - user.localpart, accepted=True - ) - target_users = set([ - UserID.from_string(x["observed_user_id"]) for x in presence - ]) - - # Also include people in all my rooms - - room_ids = yield self.get_joined_rooms_for_user(user) - - if state is None: - state = yield self.store.get_presence_state(user.localpart) - else: - # statuscache = self._get_or_make_usercache(user) - # self._user_cachemap_latest_serial += 1 - # statuscache.update(state, self._user_cachemap_latest_serial) - pass - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=target_users, - room_ids=room_ids, - statuscache=self._get_or_make_usercache(user), - ) - - for target_user in target_users: - if self.hs.is_mine(target_user): - self._start_polling_local(user, target_user) - - # We want to tell the person that just came online - # presence state of people they are interested in? - self.push_update_to_clients( - users_to_push=[user], - ) - - deferreds = [] - remote_users = [u for u in target_users if not self.hs.is_mine(u)] - remoteusers_by_domain = partition(remote_users, lambda u: u.domain) - # Only poll for people in our get_presence_list - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append(self._start_polling_remote( - user, domain, remoteusers - )) - - yield defer.DeferredList(deferreds, consumeErrors=True) - - def _start_polling_local(self, user, target_user): - """Subscribe a local user to presence updates for a local user - - Args: - user(UserId): The local user that wishes for updates. - target_user(UserId): The local users whose updates are wanted. - """ - target_localpart = target_user.localpart - - if target_localpart not in self._local_pushmap: - self._local_pushmap[target_localpart] = set() - - self._local_pushmap[target_localpart].add(user) - - def _start_polling_remote(self, user, domain, remoteusers): - """Subscribe a local user to presence updates for remote users on a - given remote domain. - - Args: - user(UserID): The local user that wishes for updates. - domain(str): The remote server the local user wants updates from. - remoteusers(UserID): The remote users that local user wants to be - told about. - Returns: - A Deferred. - """ - to_poll = set() - - for u in remoteusers: - if u not in self._remote_recvmap: - self._remote_recvmap[u] = set() - to_poll.add(u) - - self._remote_recvmap[u].add(user) - - if not to_poll: - return defer.succeed(None) - - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"poll": [u.to_string() for u in to_poll]} - ) - - @log_function - def stop_polling_presence(self, user, target_user=None): - """Unsubscribe a local user from presence updates from a local or - remote user. If no target user is supplied then unsubscribe the user - from all presence updates that the user had subscribed to. - - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID or None): The user whose updates are no longer - wanted. - Returns: - A Deferred. - """ - logger.debug("Stop polling for presence from %s", user) - - if not target_user or self.hs.is_mine(target_user): - self._stop_polling_local(user, target_user=target_user) - - deferreds = [] - - if target_user: - if target_user not in self._remote_recvmap: - return - target_users = set([target_user]) - else: - target_users = self._remote_recvmap.keys() - - remoteusers = [u for u in target_users - if user in self._remote_recvmap[u]] - remoteusers_by_domain = partition(remoteusers, lambda u: u.domain) - - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append( - self._stop_polling_remote(user, domain, remoteusers) - ) - - return defer.DeferredList(deferreds, consumeErrors=True) - - def _stop_polling_local(self, user, target_user): - """Unsubscribe a local user from presence updates from a local user on - this server. - - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID): The user whose updates are no longer wanted. - """ - for localpart in self._local_pushmap.keys(): - if target_user and localpart != target_user.localpart: - continue - - if user in self._local_pushmap[localpart]: - self._local_pushmap[localpart].remove(user) - - if not self._local_pushmap[localpart]: - del self._local_pushmap[localpart] - - @log_function - def _stop_polling_remote(self, user, domain, remoteusers): - """Unsubscribe a local user from presence updates from remote users on - a given domain. - - Args: - user(UserID): The local user that no longer wishes for updates. - domain(str): The remote server to unsubscribe from. - remoteusers([UserID]): The users on that remote server that the - local user no longer wishes to be updated about. - Returns: - A Deferred. - """ - to_unpoll = set() - - for u in remoteusers: - self._remote_recvmap[u].remove(user) - - if not self._remote_recvmap[u]: - del self._remote_recvmap[u] - to_unpoll.add(u) - - if not to_unpoll: - return defer.succeed(None) - - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"unpoll": [u.to_string() for u in to_unpoll]} - ) - - @defer.inlineCallbacks - @log_function - def push_presence(self, user, statuscache): - """ - Notify local and remote users of a change in presence of a local user. - Pushes the update to local clients and remote domains that are directly - subscribed to the presence of the local user. - Also pushes that update to any local user or remote domain that shares - a room with the local user. - - Args: - user(UserID): The local user whose presence was updated. - statuscache(UserPresenceCache): Cache of the user's presence state - Returns: - A Deferred. - """ - assert(self.hs.is_mine(user)) - - logger.debug("Pushing presence update from %s", user) - - localusers = set(self._local_pushmap.get(user.localpart, set())) - remotedomains = set(self._remote_sendmap.get(user.localpart, set())) - - # Reflect users' status changes back to themselves, so UIs look nice - # and also user is informed of server-forced pushes - localusers.add(user) - - room_ids = yield self.get_joined_rooms_for_user(user) - - if not localusers and not room_ids: - defer.returnValue(None) - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=localusers, - remote_domains=remotedomains, - room_ids=room_ids, - statuscache=statuscache, - ) - yield user_presence_changed(self.distributor, user, statuscache) - - @defer.inlineCallbacks - def incoming_presence(self, origin, content): - """Handle an incoming m.presence EDU. - For each presence update in the "push" list update our local cache and - notify the appropriate local clients. Only clients that share a room - or are directly subscribed to the presence for a user should be - notified of the update. - For each subscription request in the "poll" list start pushing presence - updates to the remote server. - For unsubscribe request in the "unpoll" list stop pushing presence - updates to the remote server. - - Args: - orgin(str): The source of this m.presence EDU. - content(dict): The content of this m.presence EDU. - Returns: - A Deferred. - """ - deferreds = [] - - for push in content.get("push", []): - user = UserID.from_string(push["user_id"]) - - logger.debug("Incoming presence update from %s", user) - - observers = set(self._remote_recvmap.get(user, set())) - if observers: - logger.debug( - " | %d interested local observers %r", len(observers), observers - ) - - room_ids = yield self.get_joined_rooms_for_user(user) - if room_ids: - logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) - - state = dict(push) - del state["user_id"] - - if "presence" not in state: - logger.warning( - "Received a presence 'push' EDU from %s without a" - " 'presence' key", origin - ) - continue - - if "last_active_ago" in state: - state["last_active"] = int( - self.clock.time_msec() - state.pop("last_active_ago") - ) - - self._user_cachemap_latest_serial += 1 - yield self.update_presence_cache(user, state, room_ids=room_ids) - - if not observers and not room_ids: - logger.debug(" | no interested observers or room IDs") - continue - - self.push_update_to_clients( - users_to_push=observers, room_ids=room_ids - ) - - user_id = user.to_string() - - if state["presence"] == PresenceState.OFFLINE: - self._remote_offline_serials.insert( - 0, - (self._user_cachemap_latest_serial, set([user_id])) - ) - while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS: - self._remote_offline_serials.pop() # remove the oldest - if user in self._user_cachemap: - del self._user_cachemap[user] - else: - # Remove the user from remote_offline_serials now that they're - # no longer offline - for idx, elem in enumerate(self._remote_offline_serials): - (_, user_ids) = elem - user_ids.discard(user_id) - if not user_ids: - self._remote_offline_serials.pop(idx) - - for poll in content.get("poll", []): - user = UserID.from_string(poll) - - if not self.hs.is_mine(user): - continue - - # TODO(paul) permissions checks - - if user not in self._remote_sendmap: - self._remote_sendmap[user] = set() - - self._remote_sendmap[user].add(origin) - - deferreds.append(self._push_presence_remote(user, origin)) - - for unpoll in content.get("unpoll", []): - user = UserID.from_string(unpoll) - - if not self.hs.is_mine(user): - continue - - if user in self._remote_sendmap: - self._remote_sendmap[user].remove(origin) - - if not self._remote_sendmap[user]: - del self._remote_sendmap[user] - - yield defer.DeferredList(deferreds, consumeErrors=True) - - @defer.inlineCallbacks - def update_presence_cache(self, user, state={}, room_ids=None, - add_to_cache=True): - """Update the presence cache for a user with a new state and bump the - serial to the latest value. - - Args: - user(UserID): The user being updated - state(dict): The presence state being updated - room_ids(None or list of str): A list of room_ids to update. If - room_ids is None then fetch the list of room_ids the user is - joined to. - add_to_cache: Whether to add an entry to the presence cache if the - user isn't already in the cache. - Returns: - A Deferred UserPresenceCache for the user being updated. - """ - if room_ids is None: - room_ids = yield self.get_joined_rooms_for_user(user) - - for room_id in room_ids: - self._room_serials[room_id] = self._user_cachemap_latest_serial - if add_to_cache: - statuscache = self._get_or_make_usercache(user) - else: - statuscache = self._get_or_offline_usercache(user) - statuscache.update(state, serial=self._user_cachemap_latest_serial) - defer.returnValue(statuscache) + # TODO: Inform the remote that we've dropped the presence list. @defer.inlineCallbacks - def push_update_to_local_and_remote(self, observed_user, statuscache, - users_to_push=[], room_ids=[], - remote_domains=[]): - """Notify local clients and remote servers of a change in the presence - of a user. - - Args: - observed_user(UserID): The user to push the presence state for. - statuscache(UserPresenceCache): The cache for the presence state to - push. - users_to_push([UserID]): A list of local and remote users to - notify. - room_ids([str]): Notify the local and remote occupants of these - rooms. - remote_domains([str]): A list of remote servers to notify in - addition to those implied by the users_to_push and the - room_ids. - Returns: - A Deferred. - """ + def is_visible(self, observed_user, observer_user): + observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) + observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) - localusers, remoteusers = partitionbool( - users_to_push, - lambda u: self.hs.is_mine(u) - ) + observer_room_ids = set(r.room_id for r in observer_rooms) + observed_room_ids = set(r.room_id for r in observed_rooms) - localusers = set(localusers) + if observer_room_ids & observed_room_ids: + defer.returnValue(True) - self.push_update_to_clients( - users_to_push=localusers, room_ids=room_ids + accepted_observers = yield self.store.get_presence_list_observers_accepted( + observed_user.to_string() ) - remote_domains = set(remote_domains) - remote_domains |= set([r.domain for r in remoteusers]) - for room_id in room_ids: - remote_domains.update( - (yield self.store.get_joined_hosts_for_room(room_id)) - ) + defer.returnValue(observer_user.to_string() in accepted_observers) - remote_domains.discard(self.hs.hostname) - - deferreds = [] - for domain in remote_domains: - logger.debug(" | push to remote domain %s", domain) - deferreds.append( - self._push_presence_remote( - observed_user, domain, state=statuscache.get_state() - ) - ) - yield defer.DeferredList(deferreds, consumeErrors=True) +def should_notify(old_state, new_state): + """Decides if a presence state change should be sent to interested parties. + """ + if old_state.status_msg != new_state.status_msg: + return True - defer.returnValue((localusers, remote_domains)) + if old_state.state == PresenceState.ONLINE: + if new_state.state != PresenceState.ONLINE: + # Always notify for online -> anything + return True - def push_update_to_clients(self, users_to_push=[], room_ids=[]): - """Notify clients of a new presence event. + if new_state.currently_active != old_state.currently_active: + return True - Args: - users_to_push([UserID]): List of users to notify. - room_ids([str]): List of room_ids to notify. - """ - with PreserveLoggingContext(): - self.notifier.on_new_event( - "presence_key", - self._user_cachemap_latest_serial, - users_to_push, - room_ids, - ) + if new_state.last_active - old_state.last_active > LAST_ACTIVE_GRANULARITY: + # Always notify for a transition where last active gets bumped. + return True - @defer.inlineCallbacks - def _push_presence_remote(self, user, destination, state=None): - """Push a user's presence to a remote server. If a presence state event - that event is sent. Otherwise a new state event is constructed from the - stored presence state. - The last_active is replaced with last_active_ago in case the wallclock - time on the remote server is different to the time on this server. - Sends an EDU to the remote server with the current presence state. + if old_state.state != new_state.state: + # Nothing to report. + return True - Args: - user(UserID): The user to push the presence state for. - destination(str): The remote server to send state to. - state(dict): The state to push, or None to use the current stored - state. - Returns: - A Deferred. - """ - if state is None: - state = yield self.store.get_presence_state(user.localpart) - del state["mtime"] - state["presence"] = state.pop("state") - - if user in self._user_cachemap: - state["last_active"] = ( - self._user_cachemap[user].get_state()["last_active"] - ) + return False - yield collect_presencelike_data(self.distributor, user, state) - if "last_active" in state: - state = dict(state) - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) - - user_state = {"user_id": user.to_string(), } - user_state.update(state) +def _format_user_presence_state(state, now): + """Convert UserPresenceState to a format that can be sent down to clients + and to other servers. + """ + content = { + "presence": state.state, + "user_id": state.user_id, + } + if state.last_active: + content["last_active_ago"] = now - state.last_active + if state.status_msg and state.state != PresenceState.OFFLINE: + content["status_msg"] = state.status_msg + if state.state == PresenceState.ONLINE: + content["currently_active"] = state.currently_active - yield self.federation.send_edu( - destination=destination, - edu_type="m.presence", - content={"push": [user_state, ], } - ) + return content class PresenceEventSource(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() + self.store = hs.get_datastore() @defer.inlineCallbacks @log_function - def get_new_events(self, user, from_key, room_ids=None, **kwargs): - from_key = int(from_key) + def get_new_events(self, user, from_key, room_ids=None, include_offline=True, + **kwargs): + # The process for getting presence events are: + # 1. Get the rooms the user is in. + # 2. Get the list of user in the rooms. + # 3. Get the list of users that are in the user's presence list. + # 4. If there is a from_key set, cross reference the list of users + # with the `presence_stream_cache` to see which ones we actually + # need to check. + # 5. Load current state for the users. + # + # We don't try and limit the presence updates by the current token, as + # sending down the rare duplicate is not a concern. + + user_id = user.to_string() + if from_key is not None: + from_key = int(from_key) room_ids = room_ids or [] presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - max_serial = presence._user_cachemap_latest_serial - - clock = self.clock - latest_serial = 0 - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list - ) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] > from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - cached = cachemap[observed_user] - - if cached.serial <= from_key or cached.serial > max_serial: - continue - - latest_serial = max(cached.serial, latest_serial) - updates.append(cached.make_event(user=observed_user, clock=clock)) + if not room_ids: + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(e.room_id for e in rooms) - # TODO(paul): limit - - for serial, user_ids in presence._remote_offline_serials: - if serial <= from_key: - break - - if serial > max_serial: - continue - - latest_serial = max(latest_serial, serial) - for u in user_ids: - updates.append({ - "type": "m.presence", - "content": {"user_id": u, "presence": PresenceState.OFFLINE}, - }) - # TODO(paul): For the v2 API we want to tell the client their from_key - # is too old if we fell off the end of the _remote_offline_serials - # list, and get them to invalidate+resync. In v1 we have no such - # concept so this is a best-effort result. - - if updates: - defer.returnValue((updates, latest_serial)) - else: - defer.returnValue(([], presence._user_cachemap_latest_serial)) - - def get_current_key(self): - presence = self.hs.get_handlers().presence_handler - return presence._user_cachemap_latest_serial + user_ids_to_check = set() + for room_id in room_ids: + users = yield self.store.get_users_in_room(room_id) + user_ids_to_check.update(users) - @defer.inlineCallbacks - def get_pagination_rows(self, user, pagination_config, key): - # TODO (erikj): Does this make sense? Ordering? + plist = yield self.store.get_presence_list_accepted(user.localpart) + user_ids_to_check.update([row["observed_user_id"] for row in plist]) - from_key = int(pagination_config.from_key) + # Always include yourself. Only really matters for when the user is + # not in any rooms, but still. + user_ids_to_check.add(user_id) - if pagination_config.to_key: - to_key = int(pagination_config.to_key) - else: - to_key = -1 + max_token = self.store.get_current_presence_token() - presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list + if from_key: + user_ids_changed = self.store.presence_stream_cache.get_entities_changed( + user_ids_to_check, from_key, ) - room_ids = yield presence.get_joined_rooms_for_user(user) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] >= from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - if not (to_key < cachemap[observed_user].serial <= from_key): - continue - - updates.append((observed_user, cachemap[observed_user])) - - # TODO(paul): limit - - if updates: - clock = self.clock - - earliest_serial = max([x[1].serial for x in updates]) - data = [x[1].make_event(user=x[0], clock=clock) for x in updates] - - defer.returnValue((data, earliest_serial)) else: - defer.returnValue(([], 0)) - + user_ids_changed = user_ids_to_check -class UserPresenceCache(object): - """Store an observed user's state and status message. + updates = yield presence.current_state_for_users(user_ids_changed) - Includes the update timestamp. - """ - def __init__(self): - self.state = {"presence": PresenceState.OFFLINE} - self.serial = None - - def __repr__(self): - return "UserPresenceCache(state=%r, serial=%r)" % ( - self.state, self.serial - ) - - def update(self, state, serial): - assert("mtime_age" not in state) + now = self.clock.time_msec() - self.state.update(state) - # Delete keys that are now 'None' - for k in self.state.keys(): - if self.state[k] is None: - del self.state[k] - - self.serial = serial - - if "status_msg" in state: - self.status_msg = state["status_msg"] - else: - self.status_msg = None - - def get_state(self): - # clone it so caller can't break our cache - state = dict(self.state) - return state - - def make_event(self, user, clock): - content = self.get_state() - content["user_id"] = user.to_string() + defer.returnValue(([ + { + "type": "m.presence", + "content": _format_user_presence_state(s, now), + } + for s in updates.values() + if include_offline or s.state != PresenceState.OFFLINE + ], max_token)) - if "last_active" in content: - content["last_active_ago"] = int( - clock.time_msec() - content.pop("last_active") - ) + def get_current_key(self): + return self.store.get_current_presence_token() - return {"type": "m.presence", "content": content} + def get_pagination_rows(self, user, pagination_config, key): + return self.get_new_events(user, from_key=None, include_offline=False) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 629e6e3594..7084a7396f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -49,6 +49,9 @@ class ProfileHandler(BaseHandler): distributor = hs.get_distributor() self.distributor = distributor + distributor.declare("collect_presencelike_data") + distributor.declare("changed_presencelike_data") + distributor.observe("registered_user", self.registered_user) distributor.observe( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1d0f0058a2..c5c13e085b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -582,6 +582,28 @@ class SyncHandler(BaseHandler): if room_sync: joined.append(room_sync) + # For each newly joined room, we want to send down presence of + # existing users. + presence_handler = self.hs.get_handlers().presence_handler + extra_presence_users = set() + for room_id in newly_joined_rooms: + users = yield self.store.get_users_in_room(event.room_id) + extra_presence_users.update(users) + + # For each new member, send down presence. + for joined_sync in joined: + it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + extra_presence_users.add(event.state_key) + + states = yield presence_handler.get_states( + [u for u in extra_presence_users if u != user_id], + as_event=True, + ) + presence.extend(states) + account_data_for_user = sync_config.filter_collection.filter_account_data( self.account_data_for_user(account_data) ) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a6f8754e32..27ea5f2a43 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -17,7 +17,7 @@ """ from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, AuthError from synapse.types import UserID from .base import ClientV1RestServlet, client_path_patterns @@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - state = yield self.handlers.presence_handler.get_state( - target_user=user, auth_user=requester.user) + if requester.user != user: + allowed = yield self.handlers.presence_handler.is_visible( + observed_user=user, observer_user=requester.user, + ) + + if not allowed: + raise AuthError(403, "You are allowed to see their presence.") + + state = yield self.handlers.presence_handler.get_state(target_user=user) defer.returnValue((200, state)) @@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + if requester.user != user: + raise AuthError(403, "Can only set your own presence state") + state = {} try: content = json.loads(request.content.read()) @@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except: raise SynapseError(400, "Unable to parse state") - yield self.handlers.presence_handler.set_state( - target_user=user, auth_user=requester.user, state=state) + yield self.handlers.presence_handler.set_state(user, state) defer.returnValue((200, {})) @@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet): raise SynapseError(400, "Cannot get another user's presence list") presence = yield self.handlers.presence_handler.get_presence_list( - observer_user=user, accepted=True) - - for p in presence: - observed_user = p.pop("observed_user") - p["user_id"] = observed_user.to_string() + observer_user=user, accepted=True + ) defer.returnValue((200, presence)) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 24706f9383..a8e89c7fe9 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -304,18 +304,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet): if event["type"] != EventTypes.Member: continue chunk.append(event) - # FIXME: should probably be state_key here, not user_id - target_user = UserID.from_string(event["user_id"]) - # Presence is an optional cache; don't fail if we can't fetch it - try: - presence_handler = self.handlers.presence_handler - presence_state = yield presence_handler.get_state( - target_user=target_user, - auth_user=requester.user, - ) - event["content"].update(presence_state) - except: - pass defer.returnValue((200, { "chunk": chunk @@ -541,6 +529,10 @@ class RoomTypingRestServlet(ClientV1RestServlet): "/rooms/(?P[^/]*)/typing/(?P[^/]*)$" ) + def __init__(self, hs): + super(RoomTypingRestServlet, self).__init__(hs) + self.presence_handler = hs.get_handlers().presence_handler + @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) @@ -552,6 +544,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): typing_handler = self.handlers.typing_notification_handler + yield self.presence_handler.bump_presence_active_time(requester.user) + if content["typing"]: yield typing_handler.started_typing( target_user=target_user, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index eb4b369a3d..b831d8c95e 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_handlers().receipts_handler + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): @@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") + yield self.presence_handler.bump_presence_active_time(requester.user) + yield self.receipts_handler.received_client_receipt( room_id, receipt_type, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index accbc6cfac..de4a020ad4 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,6 +25,7 @@ from synapse.events.utils import ( ) from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.errors import SynapseError +from synapse.api.constants import PresenceState from ._base import client_v2_patterns import copy @@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet): self.sync_handler = hs.get_handlers().sync_handler self.clock = hs.get_clock() self.filtering = hs.get_filtering() + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_GET(self, request): @@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet): else: since_token = None - if set_presence == "online": - yield self.event_stream_handler.started_stream(user) + affect_presence = set_presence != PresenceState.OFFLINE - try: + if affect_presence: + yield self.presence_handler.set_state(user, {"presence": set_presence}) + + context = yield self.presence_handler.user_syncing( + user.to_string(), affect_presence=affect_presence, + ) + with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, full_state=full_state ) - finally: - if set_presence == "online": - self.event_stream_handler.stopped_stream(user) time_now = self.clock.time_msec() diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 5a9e7720d9..8c3cf9e801 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -20,7 +20,7 @@ from .appservice import ( from ._base import Cache from .directory import DirectoryStore from .events import EventsStore -from .presence import PresenceStore +from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore from .registration import RegistrationStore from .room import RoomStore @@ -47,6 +47,7 @@ from .account_data import AccountDataStore from util.id_generators import IdGenerator, StreamIdGenerator +from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -110,6 +111,9 @@ class DataStore(RoomMemberStore, RoomStore, self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) @@ -119,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) - events_max = self._stream_id_gen.get_max_token(None) + events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token(None) + account_max = self._account_data_id_gen.get_max_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) + self.__presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_max_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", min_presence_val, + prefilled_cache=presence_cache_prefill + ) + super(DataStore, self).__init__(hs) + def take_presence_startup_info(self): + active_on_startup = self.__presence_on_startup + self.__presence_on_startup = None + return active_on_startup + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will @@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore, txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) rows = txn.fetchall() + txn.close() cache = { row[0]: int(row[1]) @@ -174,6 +197,27 @@ class DataStore(RoomMemberStore, RoomStore, return cache, min_val + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active, last_federation_update," + " last_user_sync, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 850736c85e..0fd5d497ab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index ef525f34c5..b133979102 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,73 +14,128 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList +from synapse.api.constants import PresenceState +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from collections import namedtuple from twisted.internet import defer -class PresenceStore(SQLBaseStore): - def create_presence(self, user_localpart): - res = self._simple_insert( - table="presence", - values={"user_id": user_localpart}, - desc="create_presence", +class UserPresenceState(namedtuple("UserPresenceState", + ("user_id", "state", "last_active", "last_federation_update", + "last_user_sync", "status_msg", "currently_active"))): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. + """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active=0, + last_federation_update=0, + last_user_sync=0, + status_msg=None, + currently_active=False, ) - self.get_presence_state.invalidate((user_localpart,)) - return res - def has_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["user_id"], - allow_none=True, - desc="has_presence_state", +class PresenceStore(SQLBaseStore): + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_id_manager = yield self._presence_id_gen.get_next(self) + with stream_id_manager as stream_id: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, stream_id, presence_states, + ) + + defer.returnValue((stream_id, self._presence_id_gen.get_max_token())) + + def _update_presence_txn(self, txn, stream_id, presence_states): + for state in presence_states: + txn.call_after( + self.presence_stream_cache.entity_has_changed, + state.user_id, stream_id, + ) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active": state.last_active, + "last_federation_update": state.last_federation_update, + "last_user_sync": state.last_user_sync, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], ) - @cached(max_entries=2000) - def get_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["state", "status_msg", "mtime"], - desc="get_presence_state", + # Delete old rows to stop database from getting really big + sql = ( + "DELETE FROM presence_stream WHERE" + " stream_id < ?" + " AND user_id IN (%s)" ) - @cachedList(get_presence_state.cache, list_name="user_localparts", - inlineCallbacks=True) - def get_presence_states(self, user_localparts): + batches = ( + presence_states[i:i + 50] + for i in xrange(0, len(presence_states), 50) + ) + for states in batches: + args = [stream_id] + args.extend(s.user_id for s in states) + txn.execute( + sql % (",".join("?" for _ in states),), + args + ) + + @defer.inlineCallbacks + def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( - table="presence", + table="presence_stream", column="user_id", - iterable=user_localparts, - retcols=("user_id", "state", "status_msg", "mtime",), - desc="get_presence_states", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active", + "last_federation_update", + "last_user_sync", + "status_msg", + "currently_active", + ), ) - defer.returnValue({ - row["user_id"]: { - "state": row["state"], - "status_msg": row["status_msg"], - "mtime": row["mtime"], - } - for row in rows - }) + for row in rows: + row["currently_active"] = bool(row["currently_active"]) - @defer.inlineCallbacks - def set_presence_state(self, user_localpart, new_state): - res = yield self._simple_update_one( - table="presence", - keyvalues={"user_id": user_localpart}, - updatevalues={"state": new_state["state"], - "status_msg": new_state["status_msg"], - "mtime": self._clock.time_msec()}, - desc="set_presence_state", - ) + defer.returnValue([UserPresenceState(**row) for row in rows]) - self.get_presence_state.invalidate((user_localpart,)) - defer.returnValue(res) + def get_current_presence_token(self): + return self._presence_id_gen.get_max_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( @@ -128,6 +183,7 @@ class PresenceStore(SQLBaseStore): desc="set_presence_list_accepted", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -154,6 +210,19 @@ class PresenceStore(SQLBaseStore): desc="get_presence_list_accepted", ) + @cachedInlineCallbacks() + def get_presence_list_observers_accepted(self, observed_userid): + user_localparts = yield self._simple_select_onecol( + table="presence_list", + keyvalues={"observed_user_id": observed_userid, "accepted": True}, + retcol="user_id", + desc="get_presence_list_accepted", + ) + + defer.returnValue([ + "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts + ]) + @defer.inlineCallbacks def del_presence_list(self, observer_localpart, observed_userid): yield self._simple_delete_one( @@ -163,3 +232,4 @@ class PresenceStore(SQLBaseStore): desc="del_presence_list", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..14f5e3d30a --- /dev/null +++ b/synapse/storage/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active BIGINT, + last_federation_update BIGINT, + last_user_sync BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5c522f4ab9..5ce54f76de 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -130,9 +130,11 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, store): + def get_max_token(self, *args): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. + + Used to take a DataStore param, which is no longer needed. """ with self._lock: if self._unfinished_ids: diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 133671e238..3b9da5b34a 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -42,7 +42,7 @@ class Clock(object): def time_msec(self): """Returns the current system time in miliseconds since epoch.""" - return self.time() * 1000 + return int(self.time() * 1000) def looping_call(self, f, msec): l = task.LoopingCall(f) diff --git a/tests/utils.py b/tests/utils.py index 3b1eb50d8d..f71125042b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -224,12 +224,12 @@ class MockClock(object): def time_msec(self): return self.time() * 1000 - def call_later(self, delay, callback): + def call_later(self, delay, callback, *args, **kwargs): current_context = LoggingContext.current_context() def wrapped_callback(): LoggingContext.thread_local.current_context = current_context - callback() + callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] self.timers.append(t) -- cgit 1.4.1 From 591af2d074044a70a48b033c4dfc322f58189d3e Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 17 Feb 2016 15:50:13 +0000 Subject: Some cleanup I'm not particularly happy with the "action" switching, but there's no convenient way to defer the work that needs to happen after it, so... :( --- synapse/handlers/room.py | 124 +++++++++++++++++++---------------------- synapse/rest/client/v1/room.py | 6 +- 2 files changed, 61 insertions(+), 69 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f85a5f2677..cd04ac09fa 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -397,7 +397,7 @@ class RoomMemberHandler(BaseHandler): room_id, action, txn_id=None, - room_hosts=None, + remote_room_hosts=None, ratelimit=True, ): effective_membership_state = action @@ -448,7 +448,7 @@ class RoomMemberHandler(BaseHandler): context, is_guest=requester.is_guest, ratelimit=ratelimit, - room_hosts=room_hosts, + remote_room_hosts=remote_room_hosts, from_client=True, ) @@ -461,11 +461,12 @@ class RoomMemberHandler(BaseHandler): event, context, is_guest=False, - room_hosts=None, + remote_room_hosts=None, ratelimit=True, from_client=True, ): - """ Change the membership status of a user in a room. + """ + Change the membership status of a user in a room. Args: event (SynapseEvent): The membership event. @@ -482,78 +483,64 @@ class RoomMemberHandler(BaseHandler): Raises: SynapseError if there was a problem changing the membership. """ - user = UserID.from_string(event.sender) + target_user = UserID.from_string(event.state_key) + room_id = event.room_id if from_client: - assert self.hs.is_mine(user), "User must be our own: %s" % (user,) + sender = UserID.from_string(event.sender) + assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) if event.is_state(): message_handler = self.hs.get_handlers().message_handler - prev_state = message_handler.deduplicate_state_event(event, context) - if prev_state is not None: + prev_event = message_handler.deduplicate_state_event(event, context) + if prev_event is not None: return - target_user = UserID.from_string(event.state_key) - - prev_state = context.current_state.get( - (EventTypes.Member, target_user.to_string()), - None - ) - - room_id = event.room_id - - handled = False + action = "send" if event.membership == Membership.JOIN: if is_guest and not self._can_guest_join(context.current_state): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - - should_do_dance, room_hosts = self._should_do_dance( + do_remote_join_dance, remote_room_hosts = self._should_do_dance( context, - (self.get_inviter(target_user.to_string(), context.current_state)), - room_hosts, + (self.get_inviter(event.state_key, context.current_state)), + remote_room_hosts, ) - - if should_do_dance: - if len(room_hosts) == 0: - # return the same error as join_room_alias does - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield self.hs.get_handlers().federation_handler.do_invite_join( - room_hosts, - room_id, - event.user_id, - event.content, - ) - handled = True - if event.membership == Membership.LEAVE: + if do_remote_join_dance: + action = "remote_join" + elif event.membership == Membership.LEAVE: is_host_in_room = self.is_host_in_room(context.current_state) if not is_host_in_room: - # Rejecting an invite, rather than leaving a joined room - handler = self.hs.get_handlers().federation_handler - inviter = self.get_inviter(target_user.to_string(), context.current_state) - if not inviter: - # return the same error as join_room_alias does - raise SynapseError(404, "No known servers") - yield handler.do_remotely_reject_invite( - [inviter.domain], - room_id, - event.user_id - ) - handled = True - - # FIXME: This isn't idempotency. - if prev_state and prev_state.membership == event.membership: - # double same action, treat this event as a NOOP. - return - - if not handled: + action = "remote_reject" + + federation_handler = self.hs.get_handlers().federation_handler + + if action == "remote_join": + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield federation_handler.do_invite_join( + remote_room_hosts, + event.room_id, + event.user_id, + event.content, + ) + elif action == "remote_reject": + inviter = self.get_inviter(target_user.to_string(), context.current_state) + if not inviter: + raise SynapseError(404, "No known servers") + yield federation_handler.do_remotely_reject_invite( + [inviter.domain], + room_id, + event.user_id + ) + else: yield self.handle_new_client_event( event, context, @@ -561,14 +548,19 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) + prev_member_event = context.current_state.get( + (EventTypes.Member, target_user.to_string()), + None + ) + if event.membership == Membership.JOIN: - if not prev_state or prev_state.membership != Membership.JOIN: + if not prev_member_event or prev_member_event.membership != Membership.JOIN: # Only fire user_joined_room if the user has acutally joined the # room. Don't bother if the user is just changing their profile # info. yield user_joined_room(self.distributor, target_user, room_id) elif event.membership == Membership.LEAVE: - if prev_state and prev_state.membership == Membership.JOIN: + if prev_member_event and prev_member_event.membership == Membership.JOIN: user_left_room(self.distributor, target_user, room_id) def _can_guest_join(self, current_state): @@ -604,7 +596,9 @@ class RoomMemberHandler(BaseHandler): Args: room_alias (RoomAlias): The alias to look up. Returns: - The room ID as a RoomID object. + A tuple of: + The room ID as a RoomID object. + Hosts likely to be participating in the room ([str]). Raises: SynapseError if room alias could not be found. """ @@ -615,11 +609,9 @@ class RoomMemberHandler(BaseHandler): raise SynapseError(404, "No such room alias") room_id = mapping["room_id"] - hosts = mapping["servers"] - if not hosts: - raise SynapseError(404, "No known servers") + servers = mapping["servers"] - defer.returnValue((RoomID.from_string(room_id), hosts)) + defer.returnValue((RoomID.from_string(room_id), servers)) def get_inviter(self, user_id, current_state): prev_state = current_state.get((EventTypes.Member, user_id)) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 179fe9a010..1f5ee09dcc 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -230,11 +230,11 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier - room_hosts = None + remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.handlers.room_member_handler room_alias = RoomAlias.from_string(room_identifier) - room_id, room_hosts = yield handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: raise SynapseError(400, "%s was not legal room ID or room alias" % ( @@ -247,7 +247,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): room_id=room_id, action="join", txn_id=txn_id, - room_hosts=room_hosts, + remote_room_hosts=remote_room_hosts, ) defer.returnValue((200, {"room_id": room_id})) -- cgit 1.4.1 From b9977ea667889f6cf89464c92fc57cbcae7cca28 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 18 Feb 2016 16:05:13 +0000 Subject: Remove dead code for setting device specific rules. It wasn't possible to hit the code from the API because of a typo in parsing the request path. Since no-one was using the feature we might as well remove the dead code. --- synapse/push/__init__.py | 7 ++- synapse/push/action_generator.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/httppusher.py | 3 +- synapse/push/push_rule_evaluator.py | 15 ++---- synapse/push/pusherpool.py | 48 +++++++---------- synapse/rest/client/v1/push_rule.py | 90 ++------------------------------ synapse/rest/client/v1/pusher.py | 6 +-- synapse/storage/event_push_actions.py | 7 ++- synapse/storage/pusher.py | 6 +-- 10 files changed, 45 insertions(+), 141 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 8da2d8716c..4c6c3b83a2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -47,14 +47,13 @@ class Pusher(object): MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): self.hs = _hs self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.profile_tag = profile_tag self.user_id = user_id self.app_id = app_id self.app_display_name = app_display_name @@ -186,8 +185,8 @@ class Pusher(object): processed = False rule_evaluator = yield \ - push_rule_evaluator.evaluator_for_user_id_and_profile_tag( - self.user_id, self.profile_tag, single_event['room_id'], self.store + push_rule_evaluator.evaluator_for_user_id( + self.user_id, single_event['room_id'], self.store ) actions = yield rule_evaluator.actions_for_event(single_event) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index e0da0868ec..c6c1dc769e 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -44,5 +44,5 @@ class ActionGenerator: ) context.push_actions = [ - (uid, None, actions) for uid, actions in actions_by_user.items() + (uid, actions) for uid, actions in actions_by_user.items() ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8ac5ceb9ef..0a23b3f102 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -152,7 +152,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): elif res is True: continue - res = evaluator.matches(cond, uid, display_name, None) + res = evaluator.matches(cond, uid, display_name) if _id: cache[_id] = bool(res) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index cdc4494928..9be4869360 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -23,12 +23,11 @@ logger = logging.getLogger(__name__) class HttpPusher(Pusher): - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): super(HttpPusher, self).__init__( _hs, - profile_tag, user_id, app_id, app_display_name, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 2a2b4437dc..98e2a2015e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @defer.inlineCallbacks -def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): +def evaluator_for_user_id(user_id, room_id, store): rawrules = yield store.get_push_rules_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id) our_member_event = yield store.get_current_state( @@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): ) defer.returnValue(PushRuleEvaluator( - user_id, profile_tag, rawrules, enabled_map, + user_id, rawrules, enabled_map, room_id, our_member_event, store )) @@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count): class PushRuleEvaluator: DEFAULT_ACTIONS = [] - def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, + def __init__(self, user_id, raw_rules, enabled_map, room_id, our_member_event, store): self.user_id = user_id - self.profile_tag = profile_tag self.room_id = room_id self.our_member_event = our_member_event self.store = store @@ -152,7 +151,7 @@ class PushRuleEvaluator: matches = True for c in conditions: matches = evaluator.matches( - c, self.user_id, my_display_name, self.profile_tag + c, self.user_id, my_display_name ) if not matches: break @@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name, profile_tag): + def matches(self, condition, user_id, display_name): if condition['kind'] == 'event_match': return self._event_match(condition, user_id) - elif condition['kind'] == 'device': - if 'profile_tag' not in condition: - return True - return condition['profile_tag'] == profile_tag elif condition['kind'] == 'contains_display_name': return self._contains_display_name(display_name) elif condition['kind'] == 'room_member_count': diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d7dcb2de4b..a05aa5f661 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -29,6 +29,7 @@ class PusherPool: def __init__(self, _hs): self.hs = _hs self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() self.pushers = {} self.last_pusher_started = -1 @@ -38,8 +39,11 @@ class PusherPool: self._start_pushers(pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data): + def add_pusher(self, user_id, access_token, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data, + profile_tag=""): + time_now_msec = self.clock.time_msec() + # we try to create the pusher just to validate the config: it # will then get pulled out of the database, # recreated, added and started: this means we have only one @@ -47,23 +51,31 @@ class PusherPool: self._create_pusher({ "user_name": user_id, "kind": kind, - "profile_tag": profile_tag, "app_id": app_id, "app_display_name": app_display_name, "device_display_name": device_display_name, "pushkey": pushkey, - "ts": self.hs.get_clock().time_msec(), + "ts": time_now_msec, "lang": lang, "data": data, "last_token": None, "last_success": None, "failing_since": None }) - yield self._add_pusher_to_store( - user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, - pushkey, lang, data + yield self.store.add_pusher( + user_id=user_id, + access_token=access_token, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=time_now_msec, + lang=lang, + data=data, + profile_tag=profile_tag, ) + yield self._refresh_pusher(app_id, pushkey, user_id) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -94,30 +106,10 @@ class PusherPool: ) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) - @defer.inlineCallbacks - def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, lang, data): - yield self.store.add_pusher( - user_id=user_id, - access_token=access_token, - profile_tag=profile_tag, - kind=kind, - app_id=app_id, - app_display_name=app_display_name, - device_display_name=device_display_name, - pushkey=pushkey, - pushkey_ts=self.hs.get_clock().time_msec(), - lang=lang, - data=data, - ) - yield self._refresh_pusher(app_id, pushkey, user_id) - def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': return HttpPusher( self.hs, - profile_tag=pusherdict['profile_tag'], user_id=pusherdict['user_name'], app_id=pusherdict['app_id'], app_display_name=pusherdict['app_display_name'], diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 7766b8be1d..5db2805d68 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -60,7 +60,6 @@ class PushRuleRestServlet(ClientV1RestServlet): spec['template'], spec['rule_id'], content, - device=spec['device'] if 'device' in spec else None ) except InvalidRuleException as e: raise SynapseError(400, e.message) @@ -153,23 +152,7 @@ class PushRuleRestServlet(ClientV1RestServlet): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - if r['priority_class'] > PRIORITY_CLASS_MAP['override']: - # per-device rule - profile_tag = _profile_tag_from_conditions(r["conditions"]) - r = _strip_device_condition(r) - if not profile_tag: - continue - if profile_tag not in rules['device']: - rules['device'][profile_tag] = {} - rules['device'][profile_tag] = ( - _add_empty_priority_class_arrays( - rules['device'][profile_tag] - ) - ) - - rulearray = rules['device'][profile_tag][template_name] - else: - rulearray = rules['global'][template_name] + rulearray = rules['global'][template_name] template_rule = _rule_to_template(r) if template_rule: @@ -195,24 +178,6 @@ class PushRuleRestServlet(ClientV1RestServlet): path = path[1:] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) - elif path[0] == 'device': - path = path[1:] - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - if path[0] == '': - defer.returnValue((200, rules['device'])) - - profile_tag = path[0] - path = path[1:] - if profile_tag not in rules['device']: - ret = {} - ret = _add_empty_priority_class_arrays(ret) - defer.returnValue((200, ret)) - ruleset = rules['device'][profile_tag] - result = _filter_ruleset_with_path(ruleset, path) - defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -252,16 +217,9 @@ def _rule_spec_from_path(path): scope = path[1] path = path[2:] - if scope not in ['global', 'device']: + if scope != 'global': raise UnrecognizedRequestError() - device = None - if scope == 'device': - if len(path) == 0: - raise UnrecognizedRequestError() - device = path[0] - path = path[1:] - if len(path) == 0: raise UnrecognizedRequestError() @@ -278,8 +236,6 @@ def _rule_spec_from_path(path): 'template': template, 'rule_id': rule_id } - if device: - spec['profile_tag'] = device path = path[1:] @@ -289,7 +245,7 @@ def _rule_spec_from_path(path): return spec -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): +def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): if rule_template in ['override', 'underride']: if 'conditions' not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -322,12 +278,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if device: - conditions.append({ - 'kind': 'device', - 'profile_tag': device - }) - if 'actions' not in req_obj: raise InvalidRuleException("No actions found") actions = req_obj['actions'] @@ -349,17 +299,6 @@ def _add_empty_priority_class_arrays(d): return d -def _profile_tag_from_conditions(conditions): - """ - Given a list of conditions, return the profile tag of the - device rule if there is one - """ - for c in conditions: - if c['kind'] == 'device': - return c['profile_tag'] - return None - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -403,19 +342,11 @@ def _priority_class_from_spec(spec): raise InvalidRuleException("Unknown template: %s" % (spec['template'])) pc = PRIORITY_CLASS_MAP[spec['template']] - if spec['scope'] == 'device': - pc += len(PRIORITY_CLASS_MAP) - return pc def _priority_class_to_template_name(pc): - if pc > PRIORITY_CLASS_MAP['override']: - # per-device - prio_class_index = pc - len(PRIORITY_CLASS_MAP) - return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] - else: - return PRIORITY_CLASS_INVERSE_MAP[pc] + return PRIORITY_CLASS_INVERSE_MAP[pc] def _rule_to_template(rule): @@ -445,23 +376,12 @@ def _rule_to_template(rule): return templaterule -def _strip_device_condition(rule): - for i, c in enumerate(rule['conditions']): - if c['kind'] == 'device': - del rule['conditions'][i] - return rule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) def _namespaced_rule_id(spec, rule_id): - if spec['scope'] == 'global': - scope = 'global' - else: - scope = 'device/%s' % (spec['profile_tag']) - return "%s/%s/%s" % (scope, spec['template'], rule_id) + return "global/%s/%s" % (spec['template'], rule_id) def _rule_id_from_namespaced(in_rule_id): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5547f1b112..4c662e6e3c 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -45,7 +45,7 @@ class PusherRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) - reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', + reqd = ['kind', 'app_id', 'app_display_name', 'device_display_name', 'pushkey', 'lang', 'data'] missing = [] for i in reqd: @@ -73,14 +73,14 @@ class PusherRestServlet(ClientV1RestServlet): yield pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], app_display_name=content['app_display_name'], device_display_name=content['device_display_name'], pushkey=content['pushkey'], lang=content['lang'], - data=content['data'] + data=content['data'], + profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: raise SynapseError(400, "Config Error: " + pce.message, diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d77a817682..5820539a92 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ :param event: the event set actions for - :param tuples: list of tuples of (user_id, profile_tag, actions) + :param tuples: list of tuples of (user_id, actions) """ values = [] - for uid, profile_tag, actions in tuples: + for uid, actions in tuples: values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'profile_tag': profile_tag, 'actions': json.dumps(actions), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, @@ -43,7 +42,7 @@ class EventPushActionsStore(SQLBaseStore): 'highlight': 1 if _action_has_highlight(actions) else 0, }) - for uid, _, __ in tuples: + for uid, __ in tuples: txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, (event.room_id, uid) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8ec706178a..c23648cdbc 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -80,9 +80,9 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, + def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data): + pushkey, pushkey_ts, lang, data, profile_tag=""): try: next_id = yield self._pushers_id_gen.get_next() yield self._simple_upsert( @@ -95,12 +95,12 @@ class PusherStore(SQLBaseStore): dict( access_token=access_token, kind=kind, - profile_tag=profile_tag, app_display_name=app_display_name, device_display_name=device_display_name, ts=pushkey_ts, lang=lang, data=encode_canonical_json(data), + profile_tag=profile_tag, ), insertion_values=dict( id=next_id, -- cgit 1.4.1 From e12ec335a58bb7957cb7abfc1c96500bb4fb2627 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Feb 2016 17:01:53 +0000 Subject: "You are not..." --- synapse/rest/client/v1/presence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 27ea5f2a43..bbfa1d6ac4 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -41,7 +41,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): ) if not allowed: - raise AuthError(403, "You are allowed to see their presence.") + raise AuthError(403, "You are not allowed to see their presence.") state = yield self.handlers.presence_handler.get_state(target_user=user) -- cgit 1.4.1 From 577951b0324f67308f50c14fad703d2103621bc5 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 23 Feb 2016 15:11:25 +0000 Subject: Allow third_party_signed to be specified on /join --- synapse/api/auth.py | 57 ++++++++++++-------- synapse/federation/federation_server.py | 15 +++++- synapse/federation/transport/server.py | 12 ++++- synapse/handlers/federation.py | 93 +++++++++++++++++++++++++-------- synapse/handlers/room.py | 67 +++++++++++++++++++++--- synapse/python_dependencies.py | 2 +- synapse/rest/client/v1/room.py | 4 ++ 7 files changed, 196 insertions(+), 54 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e2f84c4d57..183245443c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -434,31 +434,46 @@ class Auth(object): if event.user_id != invite_event.user_id: return False - try: - public_key = invite_event.content["public_key"] - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - return False - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. - return True + if signed["mxid"] != event.state_key: return False - except (KeyError, SignatureVerifyException,): + if signed["token"] != token: return False + for public_key_object in self.get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + def get_public_keys(self, invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys + def _get_power_level_event(self, auth_events): key = (EventTypes.PowerLevels, "", ) return auth_events.get(key) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 90718192dd..e8bfbe7cb5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -543,8 +543,19 @@ class FederationServer(FederationBase): return event @defer.inlineCallbacks - def exchange_third_party_invite(self, invite): - ret = yield self.handler.exchange_third_party_invite(invite) + def exchange_third_party_invite( + self, + sender_user_id, + target_user_id, + room_id, + signed, + ): + ret = yield self.handler.exchange_third_party_invite( + sender_user_id, + target_user_id, + room_id, + signed, + ) defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 65e054f7dd..6e92e2f8f4 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -425,7 +425,17 @@ class On3pidBindServlet(BaseFederationServlet): last_exception = None for invite in content["invites"]: try: - yield self.handler.exchange_third_party_invite(invite) + if "signed" not in invite or "token" not in invite["signed"]: + message = ("Rejecting received notification of third-" + "party invite without signed: %s" % (invite,)) + logger.info(message) + raise SynapseError(400, message) + yield self.handler.exchange_third_party_invite( + invite["sender"], + invite["mxid"], + invite["room_id"], + invite["signed"], + ) except Exception as e: last_exception = e if last_exception: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ac15f9e5dd..3655b9e5e2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -14,6 +14,9 @@ # limitations under the License. """Contains handlers for federation events.""" +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 from ._base import BaseHandler @@ -1620,19 +1623,15 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def exchange_third_party_invite(self, invite): - sender = invite["sender"] - room_id = invite["room_id"] - - if "signed" not in invite or "token" not in invite["signed"]: - logger.info( - "Discarding received notification of third party invite " - "without signed: %s" % (invite,) - ) - return - + def exchange_third_party_invite( + self, + sender_user_id, + target_user_id, + room_id, + signed, + ): third_party_invite = { - "signed": invite["signed"], + "signed": signed, } event_dict = { @@ -1642,8 +1641,8 @@ class FederationHandler(BaseHandler): "third_party_invite": third_party_invite, }, "room_id": room_id, - "sender": sender, - "state_key": invite["mxid"], + "sender": sender_user_id, + "state_key": target_user_id, } if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): @@ -1656,11 +1655,11 @@ class FederationHandler(BaseHandler): ) self.auth.check(event, context.current_state) - yield self._validate_keyserver(event, auth_events=context.current_state) + yield self._check_signature(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event(event, context, from_client=False) else: - destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) + destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.replication_layer.forward_third_party_invite( destinations, room_id, @@ -1681,7 +1680,7 @@ class FederationHandler(BaseHandler): ) self.auth.check(event, auth_events=context.current_state) - yield self._validate_keyserver(event, auth_events=context.current_state) + yield self._check_signature(event, auth_events=context.current_state) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. @@ -1711,17 +1710,69 @@ class FederationHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def _validate_keyserver(self, event, auth_events): - token = event.content["third_party_invite"]["signed"]["token"] + def _check_signature(self, event, auth_events): + """ + Checks that the signature in the event is consistent with its invite. + :param event (Event): The m.room.member event to check + :param auth_events (dict<(event type, state_key), event>) + + :raises + AuthError if signature didn't match any keys, or key has been + revoked, + SynapseError if a transient error meant a key couldn't be checked + for revocation. + """ + signed = event.content["third_party_invite"]["signed"] + token = signed["token"] invite_event = auth_events.get( (EventTypes.ThirdPartyInvite, token,) ) + if not invite_event: + raise AuthError(403, "Could not find invite") + + last_exception = None + for public_key_object in self.hs.get_auth().get_public_keys(invite_event): + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + + public_key = public_key_object["public_key"] + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + if "key_validity_url" in public_key_object: + yield self._check_key_revocation( + public_key, + public_key_object["key_validity_url"] + ) + return + except Exception as e: + last_exception = e + raise last_exception + + @defer.inlineCallbacks + def _check_key_revocation(self, public_key, url): + """ + Checks whether public_key has been revoked. + + :param public_key (str): base-64 encoded public key. + :param url (str): Key revocation URL. + + :raises + AuthError if they key has been revoked. + SynapseError if a transient error meant a key couldn't be checked + for revocation. + """ try: response = yield self.hs.get_simple_http_client().get_json( - invite_event.content["key_validity_url"], - {"public_key": invite_event.content["public_key"]} + url, + {"public_key": public_key} ) except Exception: raise SynapseError( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b00cac4bd4..eb9700a35b 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -398,6 +398,7 @@ class RoomMemberHandler(BaseHandler): action, txn_id=None, remote_room_hosts=None, + third_party_signed=None, ratelimit=True, ): effective_membership_state = action @@ -406,6 +407,15 @@ class RoomMemberHandler(BaseHandler): elif action == "forget": effective_membership_state = "leave" + if third_party_signed is not None: + replication = self.hs.get_replication_layer() + yield replication.exchange_third_party_invite( + third_party_signed["sender"], + target.to_string(), + room_id, + third_party_signed, + ) + msg_handler = self.hs.get_handlers().message_handler content = {"membership": effective_membership_state} @@ -759,7 +769,7 @@ class RoomMemberHandler(BaseHandler): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_key, key_validity_url, display_name = ( + token, public_keys, fallback_public_key, display_name = ( yield self._ask_id_server_for_third_party_invite( id_server=id_server, medium=medium, @@ -774,14 +784,18 @@ class RoomMemberHandler(BaseHandler): inviter_avatar_url=inviter_avatar_url ) ) + msg_handler = self.hs.get_handlers().message_handler yield msg_handler.create_and_send_nonmember_event( { "type": EventTypes.ThirdPartyInvite, "content": { "display_name": display_name, - "key_validity_url": key_validity_url, - "public_key": public_key, + "public_keys": public_keys, + + # For backwards compatibility: + "key_validity_url": fallback_public_key["key_validity_url"], + "public_key": fallback_public_key["public_key"], }, "room_id": room_id, "sender": user.to_string(), @@ -806,6 +820,34 @@ class RoomMemberHandler(BaseHandler): inviter_display_name, inviter_avatar_url ): + """ + Asks an identity server for a third party invite. + + :param id_server (str): hostname + optional port for the identity server. + :param medium (str): The literal string "email". + :param address (str): The third party address being invited. + :param room_id (str): The ID of the room to which the user is invited. + :param inviter_user_id (str): The user ID of the inviter. + :param room_alias (str): An alias for the room, for cosmetic + notifications. + :param room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + :param room_join_rules (str): The join rules of the email + (e.g. "public"). + :param room_name (str): The m.room.name of the room. + :param inviter_display_name (str): The current display name of the + inviter. + :param inviter_avatar_url (str): The URL of the inviter's avatar. + + :return: A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. + """ + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( id_server_scheme, id_server, ) @@ -826,12 +868,21 @@ class RoomMemberHandler(BaseHandler): ) # TODO: Check for success token = data["token"] - public_key = data["public_key"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, id_server, + ), + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) display_name = data["display_name"] - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ) - defer.returnValue((token, public_key, key_validity_url, display_name)) + defer.returnValue((token, public_keys, fallback_public_key, display_name)) def forget(self, user, room_id): return self.store.forget(user.to_string(), room_id) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 75bf3d13aa..35933324a4 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "frozendict>=0.4": ["frozendict"], - "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], + "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6f5c5614a..07a2a5dd82 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -228,6 +228,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet): allow_guest=True, ) + content = _parse_json(request) + if RoomID.is_valid(room_identifier): room_id = room_identifier remote_room_hosts = None @@ -248,6 +250,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): action="join", txn_id=txn_id, remote_room_hosts=remote_room_hosts, + third_party_signed=content.get("third_party_signed", None), ) defer.returnValue((200, {"room_id": room_id})) @@ -451,6 +454,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): room_id=room_id, action=membership_action, txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), ) defer.returnValue((200, {})) -- cgit 1.4.1 From 869580206daa5aa940d079fc907f75dea7770505 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 24 Feb 2016 08:50:28 +0000 Subject: Ignore invalid POST bodies when joining rooms --- synapse/rest/client/v1/room.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 07a2a5dd82..f5ed4f7302 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -228,7 +228,12 @@ class JoinRoomAliasServlet(ClientV1RestServlet): allow_guest=True, ) - content = _parse_json(request) + try: + content = _parse_json(request) + except: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} if RoomID.is_valid(room_identifier): room_id = room_identifier @@ -427,7 +432,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): }: raise AuthError(403, "Guest access not allowed") - content = _parse_json(request) + try: + content = _parse_json(request) + except: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): yield self.handlers.room_member_handler.do_3pid_invite( -- cgit 1.4.1 From 9892d017b25380184fb8db47ef859b07053f00f9 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 24 Feb 2016 16:31:07 +0000 Subject: Remove unused get_rule_attr method --- synapse/rest/client/v1/push_rule.py | 8 -------- 1 file changed, 8 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 5db2805d68..6c8f09e898 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -200,14 +200,6 @@ class PushRuleRestServlet(ClientV1RestServlet): else: raise UnrecognizedRequestError() - def get_rule_attr(self, user_id, namespaced_rule_id, attr): - if attr == 'enabled': - return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( - user_id, namespaced_rule_id - ) - else: - raise UnrecognizedRequestError() - def _rule_spec_from_path(path): if len(path) < 2: -- cgit 1.4.1 From 15c2ac2cac7377e48eb0531f911c4b7ea1891457 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 25 Feb 2016 15:13:07 +0000 Subject: Make sure we return a JSON object when returning the values of specif… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ic keys from a push rule --- synapse/rest/client/v1/push_rule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 6c8f09e898..d26e4cde3e 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -324,7 +324,9 @@ def _filter_ruleset_with_path(ruleset, path): attr = path[0] if attr in the_rule: - return the_rule[attr] + # Make sure we return a JSON object as the attribute may be a + # JSON value. + return {attr: the_rule[attr]} else: raise UnrecognizedRequestError() -- cgit 1.4.1 From a53774721a90955cfb6180d332ca9f54f9b5e58a Mon Sep 17 00:00:00 2001 From: Gergely Polonkai Date: Fri, 26 Feb 2016 10:22:35 +0100 Subject: Add error codes for malformed/bad JSON in /login Signed-off-by: Gergely Polonkai --- synapse/rest/client/v1/login.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 79101106ac..a4f89aea7b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -404,10 +404,10 @@ def _parse_json(request): try: content = json.loads(request.content.read()) if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.") + raise SynapseError(400, "Content must be a JSON object.", errcode=Codes.BAD_JSON) return content except ValueError: - raise SynapseError(400, "Content not JSON.") + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) def register_servlets(hs, http_server): -- cgit 1.4.1 From 87acd8fb075114849ca06188ef333119ea73ad12 Mon Sep 17 00:00:00 2001 From: Gergely Polonkai Date: Fri, 26 Feb 2016 12:05:38 +0100 Subject: Fix to appease the PEP8 dragon --- synapse/rest/client/v1/login.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a4f89aea7b..f13272da8e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -404,7 +404,9 @@ def _parse_json(request): try: content = json.loads(request.content.read()) if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", errcode=Codes.BAD_JSON) + raise SynapseError( + 400, "Content must be a JSON object.", errcode=Codes.BAD_JSON + ) return content except ValueError: raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) -- cgit 1.4.1 From de27f7fc79b785961181d13749468ae3e2019772 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 26 Feb 2016 14:28:19 +0000 Subject: Add support for changing the actions for default rules See matrix-org/matrix-doc#283 Works by adding dummy rules to the push rules table with a negative priority class and then using those rules to clobber the default rule actions when adding the default rules in ``list_with_base_rules`` --- synapse/push/baserules.py | 57 ++++++++++++++++++++++++++++++++----- synapse/rest/client/v1/push_rule.py | 31 +++++++++++++++++--- synapse/storage/push_rule.py | 25 ++++++++++++++++ 3 files changed, 102 insertions(+), 11 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 0832c77cb4..86a2998bcc 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -13,46 +13,67 @@ # limitations under the License. from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +import copy def list_with_base_rules(rawrules): + """Combine the list of rules set by the user with the default push rules + + :param list rawrules: The rules the user has modified or set. + :returns: A new list with the rules set by the user combined with the + defaults. + """ ruleslist = [] + # Grab the base rules that the user has modified. + # The modified base rules have a priority_class of -1. + modified_base_rules = { + r['rule_id']: r for r in rawrules if r['priority_class'] < 0 + } + + # Remove the modified base rules from the list, They'll be added back + # in the default postions in the list. + rawrules = [r for r in rawrules if r['priority_class'] >= 0] + # shove the server default rules for each kind onto the end of each current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules )) for r in rawrules: if r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) ruleslist.append(r) while current_prio_class > 0: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) return ruleslist -def make_base_append_rules(kind): +def make_base_append_rules(kind, modified_base_rules): rules = [] if kind == 'override': @@ -62,15 +83,31 @@ def make_base_append_rules(kind): elif kind == 'content': rules = BASE_APPEND_CONTENT_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules -def make_base_prepend_rules(kind): +def make_base_prepend_rules(kind, modified_base_rules): rules = [] if kind == 'override': rules = BASE_PREPEND_OVERRIDE_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules @@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] +BASE_RULE_IDS = set() + for r in BASE_APPEND_CONTENT_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['content'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_PREPEND_OVERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_OVRRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_UNDERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['underride'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index d26e4cde3e..970a019223 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,7 +22,7 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -import synapse.push.baserules as baserules +from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) @@ -55,6 +55,10 @@ class PushRuleRestServlet(ClientV1RestServlet): yield self.set_rule_attr(requester.user.to_string(), spec, content) defer.returnValue((200, {})) + if spec['rule_id'].startswith('.'): + # Rule ids starting with '.' are reserved for server default rules. + raise SynapseError(400, "cannot add new rule_ids that start with '.'") + try: (conditions, actions) = _rule_tuple_from_request_object( spec['template'], @@ -128,7 +132,7 @@ class PushRuleRestServlet(ClientV1RestServlet): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist)) + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) rules = {'global': {}, 'device': {}} @@ -197,6 +201,18 @@ class PushRuleRestServlet(ClientV1RestServlet): return self.hs.get_datastore().set_push_rule_enabled( user_id, namespaced_rule_id, val ) + elif spec['attr'] == 'actions': + actions = val.get('actions') + _check_actions(actions) + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec['rule_id'] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) + return self.hs.get_datastore().set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) else: raise UnrecognizedRequestError() @@ -274,6 +290,15 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): raise InvalidRuleException("No actions found") actions = req_obj['actions'] + _check_actions(actions) + + return conditions, actions + + +def _check_actions(actions): + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + for a in actions: if a in ['notify', 'dont_notify', 'coalesce']: pass @@ -282,8 +307,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): else: raise InvalidRuleException("Unrecognised action") - return conditions, actions - def _add_empty_priority_class_arrays(d): for pc in PRIORITY_CLASS_MAP.keys(): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index e19a81e41f..bb5c14d912 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -294,6 +294,31 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, priority, + "[]", actions_json + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + {'actions': actions_json}, + ) + + return self.runInteraction( + "set_push_rule_actions", set_push_rule_actions_txn, + ) + class RuleNotFoundException(Exception): pass -- cgit 1.4.1 From f9af8962f8ea6201ed3910eb248b8668f1262fef Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Mar 2016 14:46:31 +0000 Subject: Allow alias creators to delete aliases --- synapse/handlers/directory.py | 27 ++++++++++++++++++----- synapse/rest/client/v1/directory.py | 3 --- synapse/storage/directory.py | 15 ++++++++++++- synapse/storage/schema/delta/30/alias_creator.sql | 16 ++++++++++++++ 4 files changed, 51 insertions(+), 10 deletions(-) create mode 100644 synapse/storage/schema/delta/30/alias_creator.sql (limited to 'synapse/rest/client') diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index e0a778e7ff..cce6f76f0e 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -17,9 +17,9 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.errors import SynapseError, Codes, CodeMessageException +from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserID import logging import string @@ -38,7 +38,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def _create_association(self, room_alias, room_id, servers=None): + def _create_association(self, room_alias, room_id, servers=None, creator=None): # general association creation for both human users and app services for wchar in string.whitespace: @@ -60,7 +60,8 @@ class DirectoryHandler(BaseHandler): yield self.store.create_room_alias_association( room_alias, room_id, - servers + servers, + creator=creator, ) @defer.inlineCallbacks @@ -77,7 +78,7 @@ class DirectoryHandler(BaseHandler): 400, "This alias is reserved by an application service.", errcode=Codes.EXCLUSIVE ) - yield self._create_association(room_alias, room_id, servers) + yield self._create_association(room_alias, room_id, servers, creator=user_id) @defer.inlineCallbacks def create_appservice_association(self, service, room_alias, room_id, @@ -95,7 +96,11 @@ class DirectoryHandler(BaseHandler): def delete_association(self, user_id, room_alias): # association deletion for human users - # TODO Check if server admin + can_delete = yield self._user_can_delete_alias(room_alias, user_id) + if not can_delete: + raise AuthError( + 403, "You don't have permission to delete the alias.", + ) can_delete = yield self.can_modify_alias( room_alias, @@ -257,3 +262,13 @@ class DirectoryHandler(BaseHandler): return # either no interested services, or no service with an exclusive lock defer.returnValue(True) + + @defer.inlineCallbacks + def _user_can_delete_alias(self, alias, user_id): + creator = yield self.store.get_room_alias_creator(alias.to_string()) + + if creator and creator == user_id: + defer.returnValue(True) + + is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) + defer.returnValue(is_admin) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 74ec1e50e0..55c22000fd 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -118,9 +118,6 @@ class ClientDirectoryServer(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = requester.user - is_admin = yield self.auth.is_server_admin(user) - if not is_admin: - raise AuthError(403, "You need to be a server admin") room_alias = RoomAlias.from_string(room_alias) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 1556619d5e..012a0b414a 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -70,13 +70,14 @@ class DirectoryStore(SQLBaseStore): ) @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers): + def create_room_alias_association(self, room_alias, room_id, servers, creator=None): """ Creates an associatin between a room alias and room_id/servers Args: room_alias (RoomAlias) room_id (str) servers (list) + creator (str): Optional user_id of creator. Returns: Deferred @@ -87,6 +88,7 @@ class DirectoryStore(SQLBaseStore): { "room_alias": room_alias.to_string(), "room_id": room_id, + "creator": creator, }, desc="create_room_alias_association", ) @@ -107,6 +109,17 @@ class DirectoryStore(SQLBaseStore): ) self.get_aliases_for_room.invalidate((room_id,)) + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={ + "room_alias": room_alias, + }, + retcol="creator", + desc="get_room_alias_creator", + allow_none=True + ) + @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/schema/delta/30/alias_creator.sql new file mode 100644 index 0000000000..c9d0dde638 --- /dev/null +++ b/synapse/storage/schema/delta/30/alias_creator.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_aliases ADD COLUMN creator TEXT; -- cgit 1.4.1 From ddf9e7b3027eee61086ebfb447c5fa33e9b640fe Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 3 Mar 2016 14:57:45 +0000 Subject: Hook up the push rules to the notifier --- synapse/handlers/message.py | 4 ++-- synapse/notifier.py | 2 +- synapse/rest/client/v1/push_rule.py | 44 ++++++++++++++++++++++++------------- synapse/streams/events.py | 4 ++++ synapse/types.py | 7 ++++++ 5 files changed, 43 insertions(+), 18 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index afa7c9c36c..2fa12c8f2b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -647,8 +647,8 @@ class MessageHandler(BaseHandler): user_id, messages, is_peeking=is_peeking ) - start_token = StreamToken(token[0], 0, 0, 0, 0) - end_token = StreamToken(token[1], 0, 0, 0, 0) + start_token = StreamToken.START.copy_and_replace("room_key", token[0]) + end_token = StreamToken.START.copy_and_replace("room_key", token[1]) time_now = self.clock.time_msec() diff --git a/synapse/notifier.py b/synapse/notifier.py index 3c36a20868..9b69b0333a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -284,7 +284,7 @@ class Notifier(object): @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, - from_token=StreamToken("s0", "0", "0", "0", "0")): + from_token=StreamToken.START): """Wait until the callback returns a non empty response or the timeout fires. """ diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 970a019223..cf68725ca1 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet): SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") + def __init__(self, hs): + super(PushRuleRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + @defer.inlineCallbacks def on_PUT(self, request): spec = _rule_spec_from_path(request.postpath) @@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) + user_id = requester.user.to_string() + if 'attr' in spec: - yield self.set_rule_attr(requester.user.to_string(), spec, content) + yield self.set_rule_attr(user_id, spec, content) + self.notify_user(user_id) defer.returnValue((200, {})) if spec['rule_id'].startswith('.'): @@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet): after = _namespaced_rule_id(spec, after[0]) try: - yield self.hs.get_datastore().add_push_rule( - user_id=requester.user.to_string(), + yield self.store.add_push_rule( + user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, conditions=conditions, @@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet): before=before, after=after ) + self.notify_user(user_id) except InconsistentRuleException as e: raise SynapseError(400, e.message) except RuleNotFoundException as e: @@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet): spec = _rule_spec_from_path(request.postpath) requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.hs.get_datastore().delete_push_rule( - requester.user.to_string(), namespaced_rule_id + yield self.store.delete_push_rule( + user_id, namespaced_rule_id ) + self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: if e.code == 404: @@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - user = requester.user + user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.hs.get_datastore().get_push_rules_for_user( - user.to_string() - ) + rawrules = yield self.store.get_push_rules_for_user(user_id) ruleslist = [] for rawrule in rawrules: @@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet): rules['global'] = _add_empty_priority_class_arrays(rules['global']) - enabled_map = yield self.hs.get_datastore().\ - get_push_rules_enabled_for_user(user.to_string()) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) for r in ruleslist: rulearray = None @@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet): pattern_type = c.pop("pattern_type", None) if pattern_type == "user_id": - c["pattern"] = user.to_string() + c["pattern"] = user_id elif pattern_type == "user_localpart": - c["pattern"] = user.localpart + c["pattern"] = requester.user.localpart rulearray = rules['global'][template_name] @@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_OPTIONS(self, _): return 200, {} + def notify_user(self, user_id): + stream_id = self.store.get_push_rules_stream_token() + self.notifier.on_new_event( + "push_rules_key", stream_id, users=[user_id] + ) + def set_rule_attr(self, user_id, spec, val): if spec['attr'] == 'enabled': if isinstance(val, dict) and "enabled" in val: @@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.hs.get_datastore().set_push_rule_enabled( + return self.store.set_push_rule_enabled( user_id, namespaced_rule_id, val ) elif spec['attr'] == 'actions': @@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return self.hs.get_datastore().set_push_rule_actions( + return self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 5ddf4e988b..d4c0bb6732 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -38,9 +38,12 @@ class EventSources(object): name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() } + self.store = hs.get_datastore() @defer.inlineCallbacks def get_current_token(self, direction='f'): + push_rules_key, _ = self.store.get_push_rules_stream_token() + token = StreamToken( room_key=( yield self.sources["room"].get_current_key(direction) @@ -57,5 +60,6 @@ class EventSources(object): account_data_key=( yield self.sources["account_data"].get_current_key() ), + push_rules_key=push_rules_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index d5bd95cbd3..5b166835bd 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -115,6 +115,7 @@ class StreamToken( "typing_key", "receipt_key", "account_data_key", + "push_rules_key", )) ): _SEPARATOR = "_" @@ -150,6 +151,7 @@ class StreamToken( or (int(other.typing_key) < int(self.typing_key)) or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.account_data_key) < int(self.account_data_key)) + or (int(other.push_rules_key) < int(self.push_rules_key)) ) def copy_and_advance(self, key, new_value): @@ -174,6 +176,11 @@ class StreamToken( return StreamToken(**d) +StreamToken.START = StreamToken( + *(["s0"] + ["0"] * (len(StreamToken._fields) - 1)) +) + + class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): """Tokens are positions between events. The token "s1" comes after event 1. -- cgit 1.4.1 From 3406eba4ef40de888ebb5b22c0ea4925b2dddeb1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 3 Mar 2016 16:11:59 +0000 Subject: Move the code for formatting push rules into a separate function --- synapse/push/clientformat.py | 112 ++++++++++++++++++++++++++++++++++++ synapse/rest/client/v1/push_rule.py | 90 ++--------------------------- 2 files changed, 116 insertions(+), 86 deletions(-) create mode 100644 synapse/push/clientformat.py (limited to 'synapse/rest/client') diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py new file mode 100644 index 0000000000..ae9db9ec2f --- /dev/null +++ b/synapse/push/clientformat.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.push.baserules import list_with_base_rules + +from synapse.push.rulekinds import ( + PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +) + +import copy +import simplejson as json + + +def format_push_rules_for_user(user, rawrules, enabled_map): + """Converts a list of rawrules and a enabled map into nested dictionaries + to match the Matrix client-server format for push rules""" + + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) + + rules = {'global': {}, 'device': {}} + + rules['global'] = _add_empty_priority_class_arrays(rules['global']) + + for r in ruleslist: + rulearray = None + + template_name = _priority_class_to_template_name(r['priority_class']) + + # Remove internal stuff. + for c in r["conditions"]: + c.pop("_id", None) + + pattern_type = c.pop("pattern_type", None) + if pattern_type == "user_id": + c["pattern"] = user.to_string() + elif pattern_type == "user_localpart": + c["pattern"] = user.localpart + + rulearray = rules['global'][template_name] + + template_rule = _rule_to_template(r) + if template_rule: + if r['rule_id'] in enabled_map: + template_rule['enabled'] = enabled_map[r['rule_id']] + elif 'enabled' in r: + template_rule['enabled'] = r['enabled'] + else: + template_rule['enabled'] = True + rulearray.append(template_rule) + + return rules + + +def _add_empty_priority_class_arrays(d): + for pc in PRIORITY_CLASS_MAP.keys(): + d[pc] = [] + return d + + +def _rule_to_template(rule): + unscoped_rule_id = None + if 'rule_id' in rule: + unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) + + template_name = _priority_class_to_template_name(rule['priority_class']) + if template_name in ['override', 'underride']: + templaterule = {k: rule[k] for k in ["conditions", "actions"]} + elif template_name in ["sender", "room"]: + templaterule = {'actions': rule['actions']} + unscoped_rule_id = rule['conditions'][0]['pattern'] + elif template_name == 'content': + if len(rule["conditions"]) != 1: + return None + thecond = rule["conditions"][0] + if "pattern" not in thecond: + return None + templaterule = {'actions': rule['actions']} + templaterule["pattern"] = thecond["pattern"] + + if unscoped_rule_id: + templaterule['rule_id'] = unscoped_rule_id + if 'default' in rule: + templaterule['default'] = rule['default'] + return templaterule + + +def _rule_id_from_namespaced(in_rule_id): + return in_rule_id.split('/')[-1] + + +def _priority_class_to_template_name(pc): + return PRIORITY_CLASS_INVERSE_MAP[pc] diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index cf68725ca1..edfe28c79b 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,12 +22,10 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS -from synapse.push.rulekinds import ( - PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP -) +from synapse.push.clientformat import format_push_rules_for_user +from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.rulekinds import PRIORITY_CLASS_MAP -import copy import simplejson as json @@ -133,48 +131,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # is probably not going to make a whole lot of difference rawrules = yield self.store.get_push_rules_for_user(user_id) - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) - - rules = {'global': {}, 'device': {}} - - rules['global'] = _add_empty_priority_class_arrays(rules['global']) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - for r in ruleslist: - rulearray = None - - template_name = _priority_class_to_template_name(r['priority_class']) - - # Remove internal stuff. - for c in r["conditions"]: - c.pop("_id", None) - - pattern_type = c.pop("pattern_type", None) - if pattern_type == "user_id": - c["pattern"] = user_id - elif pattern_type == "user_localpart": - c["pattern"] = requester.user.localpart - - rulearray = rules['global'][template_name] - - template_rule = _rule_to_template(r) - if template_rule: - if r['rule_id'] in enabled_map: - template_rule['enabled'] = enabled_map[r['rule_id']] - elif 'enabled' in r: - template_rule['enabled'] = r['enabled'] - else: - template_rule['enabled'] = True - rulearray.append(template_rule) + rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) path = request.postpath[1:] @@ -322,12 +281,6 @@ def _check_actions(actions): raise InvalidRuleException("Unrecognised action") -def _add_empty_priority_class_arrays(d): - for pc in PRIORITY_CLASS_MAP.keys(): - d[pc] = [] - return d - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -376,37 +329,6 @@ def _priority_class_from_spec(spec): return pc -def _priority_class_to_template_name(pc): - return PRIORITY_CLASS_INVERSE_MAP[pc] - - -def _rule_to_template(rule): - unscoped_rule_id = None - if 'rule_id' in rule: - unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) - - template_name = _priority_class_to_template_name(rule['priority_class']) - if template_name in ['override', 'underride']: - templaterule = {k: rule[k] for k in ["conditions", "actions"]} - elif template_name in ["sender", "room"]: - templaterule = {'actions': rule['actions']} - unscoped_rule_id = rule['conditions'][0]['pattern'] - elif template_name == 'content': - if len(rule["conditions"]) != 1: - return None - thecond = rule["conditions"][0] - if "pattern" not in thecond: - return None - templaterule = {'actions': rule['actions']} - templaterule["pattern"] = thecond["pattern"] - - if unscoped_rule_id: - templaterule['rule_id'] = unscoped_rule_id - if 'default' in rule: - templaterule['default'] = rule['default'] - return templaterule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) @@ -415,10 +337,6 @@ def _namespaced_rule_id(spec, rule_id): return "global/%s/%s" % (spec['template'], rule_id) -def _rule_id_from_namespaced(in_rule_id): - return in_rule_id.split('/')[-1] - - class InvalidRuleException(Exception): pass -- cgit 1.4.1 From b4022cc487921ec46942a6a72fb174bb7aa1e459 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 3 Mar 2016 16:43:42 +0000 Subject: Pass whole requester to ratelimiting This will enable more detailed decisions --- synapse/handlers/_base.py | 15 +++++-- synapse/handlers/directory.py | 20 ++++++---- synapse/handlers/federation.py | 4 +- synapse/handlers/message.py | 8 ++-- synapse/handlers/profile.py | 17 ++++---- synapse/handlers/room.py | 76 +++++++++++++++++++++--------------- synapse/rest/client/v1/directory.py | 6 ++- synapse/rest/client/v1/profile.py | 4 +- synapse/rest/client/v1/room.py | 8 ++-- tests/handlers/test_profile.py | 16 ++++++-- tests/replication/test_resource.py | 17 ++++---- tests/rest/client/v1/test_profile.py | 4 +- tests/utils.py | 5 +++ 13 files changed, 124 insertions(+), 76 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index bdade98bf7..2333fc0c09 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -160,10 +160,10 @@ class BaseHandler(object): ) defer.returnValue(res.get(user_id, [])) - def ratelimit(self, user_id): + def ratelimit(self, requester): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( - user_id, time_now, + requester.user.to_string(), time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, burst_count=self.hs.config.rc_message_burst_count, ) @@ -263,11 +263,18 @@ class BaseHandler(object): return False @defer.inlineCallbacks - def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): + def handle_new_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[] + ): # We now need to go and hit out to wherever we need to hit out to. if ratelimit: - self.ratelimit(event.sender) + self.ratelimit(requester) self.auth.check(event, auth_events=context.current_state) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index e0a778e7ff..88166f0187 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -212,17 +212,21 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def send_room_alias_update_event(self, user_id, room_id): + def send_room_alias_update_event(self, requester, user_id, room_id): aliases = yield self.store.get_aliases_for_room(room_id) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Aliases, - "state_key": self.hs.hostname, - "room_id": room_id, - "sender": user_id, - "content": {"aliases": aliases}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Aliases, + "state_key": self.hs.hostname, + "room_id": room_id, + "sender": user_id, + "content": {"aliases": aliases}, + }, + ratelimit=False + ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3655b9e5e2..6e50b0963e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1657,7 +1657,7 @@ class FederationHandler(BaseHandler): self.auth.check(event, context.current_state) yield self._check_signature(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, from_client=False) + yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.replication_layer.forward_third_party_invite( @@ -1686,7 +1686,7 @@ class FederationHandler(BaseHandler): # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, from_client=False) + yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks def add_display_name_to_third_party_invite(self, event_dict, event, context): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index afa7c9c36c..cace1cb82a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -215,7 +215,7 @@ class MessageHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def send_nonmember_event(self, event, context, ratelimit=True): + def send_nonmember_event(self, requester, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -241,6 +241,7 @@ class MessageHandler(BaseHandler): defer.returnValue(prev_state) yield self.handle_new_client_event( + requester=requester, event=event, context=context, ratelimit=ratelimit, @@ -268,9 +269,9 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_nonmember_event( self, + requester, event_dict, ratelimit=True, - token_id=None, txn_id=None ): """ @@ -280,10 +281,11 @@ class MessageHandler(BaseHandler): """ event, context = yield self.create_event( event_dict, - token_id=token_id, + token_id=requester.access_token_id, txn_id=txn_id ) yield self.send_nonmember_event( + requester, event, context, ratelimit=ratelimit, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c9ad5944e6..b45eafbb49 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -89,13 +89,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["displayname"]) @defer.inlineCallbacks - def set_displayname(self, target_user, auth_user, new_displayname): + def set_displayname(self, target_user, requester, new_displayname): """target_user is the user whose displayname is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if new_displayname == '': @@ -109,7 +109,7 @@ class ProfileHandler(BaseHandler): "displayname": new_displayname, }) - yield self._update_join_states(target_user) + yield self._update_join_states(requester) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -139,13 +139,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["avatar_url"]) @defer.inlineCallbacks - def set_avatar_url(self, target_user, auth_user, new_avatar_url): + def set_avatar_url(self, target_user, requester, new_avatar_url): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( @@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler): "avatar_url": new_avatar_url, }) - yield self._update_join_states(target_user) + yield self._update_join_states(requester) @defer.inlineCallbacks def collect_presencelike_data(self, user, state): @@ -199,11 +199,12 @@ class ProfileHandler(BaseHandler): defer.returnValue(response) @defer.inlineCallbacks - def _update_join_states(self, user): + def _update_join_states(self, requester): + user = requester.user if not self.hs.is_mine(user): return - self.ratelimit(user.to_string()) + self.ratelimit(requester) joins = yield self.store.get_rooms_for_user( user.to_string(), diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d2de23a6cc..91fe306cf4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken +from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester from synapse.api.constants import ( EventTypes, Membership, JoinRules, RoomCreationPreset, ) @@ -90,7 +90,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - self.ratelimit(user_id) + self.ratelimit(requester) if "room_alias_name" in config: for wchar in string.whitespace: @@ -185,23 +185,29 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Name, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"name": name}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Name, + "room_id": room_id, + "sender": user_id, + "state_key": "", + "content": {"name": name}, + }, + ratelimit=False) if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Topic, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"topic": topic}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Topic, + "room_id": room_id, + "sender": user_id, + "state_key": "", + "content": {"topic": topic}, + }, + ratelimit=False) for invitee in invite_list: room_member_handler.update_membership( @@ -231,7 +237,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() yield directory_handler.send_room_alias_update_event( - user_id, room_id + requester, user_id, room_id ) defer.returnValue(result) @@ -263,7 +269,11 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def send(etype, content, **kwargs): event = create(etype, content, **kwargs) - yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + creator, + event, + ratelimit=False + ) config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -454,12 +464,11 @@ class RoomMemberHandler(BaseHandler): member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event( + requester, event, context, - is_guest=requester.is_guest, ratelimit=ratelimit, remote_room_hosts=remote_room_hosts, - from_client=True, ) if action == "forget": @@ -468,17 +477,19 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def send_membership_event( self, + requester, event, context, - is_guest=False, remote_room_hosts=None, ratelimit=True, - from_client=True, ): """ Change the membership status of a user in a room. Args: + requester (Requester): The local user who requested the membership + event. If None, certain checks, like whether this homeserver can + act as the sender, will be skipped. event (SynapseEvent): The membership event. context: The context of the event. is_guest (bool): Whether the sender is a guest. @@ -486,19 +497,21 @@ class RoomMemberHandler(BaseHandler): the room, and could be danced with in order to join this homeserver for the first time. ratelimit (bool): Whether to rate limit this request. - from_client (bool): Whether this request is the result of a local - client request (rather than over federation). If so, we will - perform extra checks, like that this homeserver can act as this - client. Raises: SynapseError if there was a problem changing the membership. """ target_user = UserID.from_string(event.state_key) room_id = event.room_id - if from_client: + if requester is not None: sender = UserID.from_string(event.sender) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % + (sender, requester.user) + ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) + else: + requester = Requester(target_user, None, False) message_handler = self.hs.get_handlers().message_handler prev_event = message_handler.deduplicate_state_event(event, context) @@ -508,7 +521,7 @@ class RoomMemberHandler(BaseHandler): action = "send" if event.membership == Membership.JOIN: - if is_guest and not self._can_guest_join(context.current_state): + if requester.is_guest and not self._can_guest_join(context.current_state): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") @@ -551,6 +564,7 @@ class RoomMemberHandler(BaseHandler): ) else: yield self.handle_new_client_event( + requester, event, context, extra_users=[target_user], @@ -669,12 +683,12 @@ class RoomMemberHandler(BaseHandler): ) else: yield self._make_and_store_3pid_invite( + requester, id_server, medium, address, room_id, inviter, - requester.access_token_id, txn_id=txn_id ) @@ -732,12 +746,12 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _make_and_store_3pid_invite( self, + requester, id_server, medium, address, room_id, user, - token_id, txn_id ): room_state = yield self.hs.get_state_handler().get_current_state(room_id) @@ -787,6 +801,7 @@ class RoomMemberHandler(BaseHandler): msg_handler = self.hs.get_handlers().message_handler yield msg_handler.create_and_send_nonmember_event( + requester, { "type": EventTypes.ThirdPartyInvite, "content": { @@ -801,7 +816,6 @@ class RoomMemberHandler(BaseHandler): "sender": user.to_string(), "state_key": token, }, - token_id=token_id, txn_id=txn_id, ) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 74ec1e50e0..8c1a2614a0 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -75,7 +75,11 @@ class ClientDirectoryServer(ClientV1RestServlet): yield dir_handler.create_association( user_id, room_alias, room_id, servers ) - yield dir_handler.send_room_alias_update_event(user_id, room_id) + yield dir_handler.send_room_alias_update_event( + requester, + user_id, + room_id + ) except SynapseError as e: raise e except: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 3c5a212920..953764bd8e 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -51,7 +51,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_displayname( - user, requester.user, new_name) + user, requester, new_name) defer.returnValue((200, {})) @@ -88,7 +88,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_avatar_url( - user, requester.user, new_name) + user, requester, new_name) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f5ed4f7302..cbf3673eff 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -158,12 +158,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): if event_type == EventTypes.Member: yield self.handlers.room_member_handler.send_membership_event( + requester, event, context, - is_guest=requester.is_guest, ) else: - yield msg_handler.send_nonmember_event(event, context) + yield msg_handler.send_nonmember_event(requester, event, context) defer.returnValue((200, {"event_id": event.event_id})) @@ -183,13 +183,13 @@ class RoomSendEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( + requester, { "type": event_type, "content": content, "room_id": room_id, "sender": requester.user.to_string(), }, - token_id=requester.access_token_id, txn_id=txn_id, ) @@ -504,6 +504,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( + requester, { "type": EventTypes.Redaction, "content": content, @@ -511,7 +512,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "sender": requester.user.to_string(), "redacts": event_id, }, - token_id=requester.access_token_id, txn_id=txn_id, ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index a87703bbfd..4f2c14e4ff 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -23,7 +23,7 @@ from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID -from tests.utils import setup_test_homeserver +from tests.utils import setup_test_homeserver, requester_for_user class ProfileHandlers(object): @@ -84,7 +84,11 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.") + yield self.handler.set_displayname( + self.frank, + requester_for_user(self.frank), + "Frank Jr." + ) self.assertEquals( (yield self.store.get_profile_displayname(self.frank.localpart)), @@ -93,7 +97,11 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname(self.frank, self.bob, "Frank Jr.") + d = self.handler.set_displayname( + self.frank, + requester_for_user(self.bob), + "Frank Jr." + ) yield self.assertFailure(d, AuthError) @@ -136,7 +144,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): yield self.handler.set_avatar_url( - self.frank, self.frank, "http://my.server/pic.gif" + self.frank, requester_for_user(self.frank), "http://my.server/pic.gif" ) self.assertEquals( diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 38daaf87e2..daabc563b4 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -18,7 +18,7 @@ from synapse.types import Requester, UserID from twisted.internet import defer from tests import unittest -from tests.utils import setup_test_homeserver +from tests.utils import setup_test_homeserver, requester_for_user from mock import Mock, NonCallableMock import json import contextlib @@ -133,12 +133,15 @@ class ReplicationResourceCase(unittest.TestCase): @defer.inlineCallbacks def send_text_message(self, room_id, message): handler = self.hs.get_handlers().message_handler - event = yield handler.create_and_send_nonmember_event({ - "type": "m.room.message", - "content": {"body": "message", "msgtype": "m.text"}, - "room_id": room_id, - "sender": self.user.to_string(), - }) + event = yield handler.create_and_send_nonmember_event( + requester_for_user(self.user), + { + "type": "m.room.message", + "content": {"body": "message", "msgtype": "m.text"}, + "room_id": room_id, + "sender": self.user.to_string(), + } + ) defer.returnValue(event.event_id) @defer.inlineCallbacks diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 0785965de2..1d210f9bf8 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -86,7 +86,7 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") + self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.") @defer.inlineCallbacks @@ -155,5 +155,5 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") + self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") diff --git a/tests/utils.py b/tests/utils.py index c67fa1ca35..291b549053 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,6 +20,7 @@ from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer from synapse.federation.transport import server +from synapse.types import Requester from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.logcontext import LoggingContext @@ -510,3 +511,7 @@ class DeferredMockCallable(object): "call(%s)" % _format_call(c[0], c[1]) for c in calls ]) ) + + +def requester_for_user(user): + return Requester(user, None, False) -- cgit 1.4.1 From 1b4f4a936fb416d81203fcd66be690f9a04b2b62 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 14:44:01 +0000 Subject: Hook up the push rules stream to account_data in /sync --- synapse/handlers/sync.py | 22 +++++++ synapse/rest/client/v1/push_rule.py | 2 +- synapse/storage/__init__.py | 5 ++ synapse/storage/push_rule.py | 125 ++++++++++++++++-------------------- 4 files changed, 85 insertions(+), 69 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index fded6e4009..92eab20c7c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.metrics import Measure +from synapse.push.clientformat import format_push_rules_for_user from twisted.internet import defer @@ -224,6 +225,10 @@ class SyncHandler(BaseHandler): ) ) + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + tags_by_room = yield self.store.get_tags_for_user( sync_config.user.to_string() ) @@ -322,6 +327,14 @@ class SyncHandler(BaseHandler): defer.returnValue(room_sync) + @defer.inlineCallbacks + def push_rules_for_user(self, user): + user_id = user.to_string() + rawrules = yield self.store.get_push_rules_for_user(user_id) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) + rules = format_push_rules_for_user(user, rawrules, enabled_map) + defer.returnValue(rules) + def account_data_for_user(self, account_data): account_data_events = [] @@ -481,6 +494,15 @@ class SyncHandler(BaseHandler): ) ) + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + # Get a list of membership change events that have happened. rooms_changed = yield self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index edfe28c79b..981d7708db 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -156,7 +156,7 @@ class PushRuleRestServlet(ClientV1RestServlet): return 200, {} def notify_user(self, user_id): - stream_id = self.store.get_push_rules_stream_token() + stream_id, _ = self.store.get_push_rules_stream_token() self.notifier.on_new_event( "push_rules_key", stream_id, users=[user_id] ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e2d7b52569..7b7b03d052 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -160,6 +160,11 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + self._push_rules_stream_id_gen.get_max_token()[0], + ) + super(DataStore, self).__init__(hs) def take_presence_startup_info(self): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index e034024108..792fcbdf5b 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -244,15 +244,10 @@ class PushRuleStore(SQLBaseStore): ) if update_stream: - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ADD", + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ADD", + data={ "priority_class": priority_class, "priority": priority, "conditions": conditions_json, @@ -260,13 +255,6 @@ class PushRuleStore(SQLBaseStore): } ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) - ) - @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ @@ -284,22 +272,10 @@ class PushRuleStore(SQLBaseStore): "push_rules", {'user_name': user_id, 'rule_id': rule_id}, ) - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "DELETE", - } - ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="DELETE" ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -328,23 +304,9 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ENABLE" if enabled else "DISABLE", - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ENABLE" if enabled else "DISABLE" ) @defer.inlineCallbacks @@ -370,24 +332,9 @@ class PushRuleStore(SQLBaseStore): {'actions': actions_json}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ACTIONS", - "actions": actions_json, - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ACTIONS", data={"actions": actions_json} ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -396,6 +343,31 @@ class PushRuleStore(SQLBaseStore): stream_id, stream_ordering ) + def _insert_push_rules_update_txn( + self, txn, stream_id, stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self._simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after( + self.get_push_rules_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.get_push_rules_enabled_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + def get_all_push_rule_updates(self, last_id, current_id, limit): """Get all the push rules changes that have happend on the server""" def get_all_push_rule_updates_txn(txn): @@ -403,7 +375,7 @@ class PushRuleStore(SQLBaseStore): "SELECT stream_id, stream_ordering, user_id, rule_id," " op, priority_class, priority, conditions, actions" " FROM push_rules_stream" - " WHERE ? < stream_id and stream_id <= ?" + " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) @@ -418,6 +390,23 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_max_token() + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + logger.error("FNARG") + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + class RuleNotFoundException(Exception): pass -- cgit 1.4.1 From 239badea9be1dd7857833408209ef22dd99773de Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 7 Mar 2016 20:13:10 +0000 Subject: Use syntax that works on both py2.7 and py3 --- synapse/app/homeserver.py | 2 +- synapse/app/synctl.py | 6 +++--- synapse/config/__main__.py | 2 +- synapse/config/_base.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/register.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/util/caches/expiringcache.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 021dc1d610..fcdc8e6e10 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -722,7 +722,7 @@ def run(hs): if hs.config.daemonize: if hs.config.print_pidfile: - print hs.config.pid_file + print (hs.config.pid_file) daemon = Daemonize( app="synapse-homeserver", diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 9249e36d82..ab3a31d7b7 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -29,13 +29,13 @@ NORMAL = "\x1b[m" def start(configfile): - print "Starting ...", + print ("Starting ...") args = SYNAPSE args.extend(["--daemonize", "-c", configfile]) try: subprocess.check_call(args) - print GREEN + "started" + NORMAL + print (GREEN + "started" + NORMAL) except subprocess.CalledProcessError as e: print ( RED + @@ -48,7 +48,7 @@ def stop(pidfile): if os.path.exists(pidfile): pid = int(open(pidfile).read()) os.kill(pid, signal.SIGTERM) - print GREEN + "stopped" + NORMAL + print (GREEN + "stopped" + NORMAL) def main(): diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index 0a3b70e11f..58c97a70af 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -28,7 +28,7 @@ if __name__ == "__main__": sys.stderr.write("\n" + e.message + "\n") sys.exit(1) - print getattr(config, key) + print (getattr(config, key)) sys.exit(0) else: sys.stderr.write("Unknown command %r\n" % (action,)) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 15d78ff33a..7449f36491 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -104,7 +104,7 @@ class Config(object): dir_path = cls.abspath(dir_path) try: os.makedirs(dir_path) - except OSError, e: + except OSError as e: if e.errno != errno.EEXIST: raise if not os.path.isdir(dir_path): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6e50b0963e..27f2b40bfe 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -472,7 +472,7 @@ class FederationHandler(BaseHandler): limit=100, extremities=[e for e in extremities.keys()] ) - except SynapseError: + except SynapseError as e: logger.info( "Failed to backfill from %s because %s", dom, e, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c5e5b28811..e2ace6a4e5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -241,7 +241,7 @@ class RegistrationHandler(BaseHandler): password_hash=None ) yield registered_user(self.distributor, user) - except Exception, e: + except Exception as e: yield self.store.add_access_token_to_user(user_id, token) # Ignore Registration errors logger.exception(e) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f13272da8e..c14e8af00e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -252,7 +252,7 @@ class SAML2RestServlet(ClientV1RestServlet): SP = Saml2Client(conf) saml2_auth = SP.parse_authn_request_response( request.args['SAMLResponse'][0], BINDING_HTTP_POST) - except Exception, e: # Not authenticated + except Exception as e: # Not authenticated logger.exception(e) if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: username = saml2_auth.name_id.text diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index e863a8f8a9..2b68c1ac93 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -69,7 +69,7 @@ class ExpiringCache(object): if self._max_len and len(self._cache.keys()) > self._max_len: sorted_entries = sorted( self._cache.items(), - key=lambda (k, v): v.time, + key=lambda item: item[1].time, ) for k, _ in sorted_entries[self._max_len:]: -- cgit 1.4.1 From 7076082ae677b280c5b68df37b1fee2fc72752ff Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 8 Mar 2016 11:45:50 +0000 Subject: Fix relative imports so they work in both py3 and py27 --- synapse/push/__init__.py | 4 ++-- synapse/push/action_generator.py | 4 ++-- synapse/push/bulk_push_rule_evaluator.py | 6 +++--- synapse/push/push_rule_evaluator.py | 4 ++-- synapse/push/pusherpool.py | 2 +- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/client/v1/register.py | 2 +- synapse/rest/client/v1/room.py | 2 +- synapse/rest/client/v1/voip.py | 2 +- synapse/storage/__init__.py | 2 +- synapse/storage/end_to_end_keys.py | 2 +- synapse/storage/events.py | 2 +- synapse/storage/keys.py | 2 +- synapse/storage/media_repository.py | 2 +- synapse/storage/signatures.py | 2 +- 17 files changed, 22 insertions(+), 22 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 4c6c3b83a2..65ef1b68a3 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -21,7 +21,7 @@ from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure import synapse.util.async -import push_rule_evaluator as push_rule_evaluator +from .push_rule_evaluator import evaluator_for_user_id import logging import random @@ -185,7 +185,7 @@ class Pusher(object): processed = False rule_evaluator = yield \ - push_rule_evaluator.evaluator_for_user_id( + evaluator_for_user_id( self.user_id, single_event['room_id'], self.store ) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index c6c1dc769e..84efcdd184 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -import bulk_push_rule_evaluator +from .bulk_push_rule_evaluator import evaluator_for_room_id import logging @@ -35,7 +35,7 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context, handler): - bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id( + bulk_evaluator = yield evaluator_for_room_id( event.room_id, self.hs, self.store ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 5d8be483e5..87d5061fb0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -18,8 +18,8 @@ import ujson as json from twisted.internet import defer -import baserules -from push_rule_evaluator import PushRuleEvaluatorForEvent +from .baserules import list_with_base_rules +from .push_rule_evaluator import PushRuleEvaluatorForEvent from synapse.api.constants import EventTypes @@ -39,7 +39,7 @@ def _get_rules(room_id, user_ids, store): rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_by_user = { - uid: baserules.list_with_base_rules([ + uid: list_with_base_rules([ decode_rule_json(rule_list) for rule_list in rules_by_user.get(uid, []) ]) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 98e2a2015e..51f73a5b78 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -import baserules +from .baserules import list_with_base_rules import logging import simplejson as json @@ -91,7 +91,7 @@ class PushRuleEvaluator: rule['actions'] = json.loads(raw_rule['actions']) rules.append(rule) - self.rules = baserules.list_with_base_rules(rules) + self.rules = list_with_base_rules(rules) self.enabled_map = enabled_map diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index a05aa5f661..772a095f8b 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -16,7 +16,7 @@ from twisted.internet import defer -from httppusher import HttpPusher +from .httppusher import HttpPusher from synapse.push import PusherConfigException from synapse.util.logcontext import preserve_fn diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index e2f5eb7b29..aa05b3f023 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import logging diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index ad161bdbab..36c3520567 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns # TODO: Needs unit testing diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index c14e8af00e..f6902a60a8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID from synapse.http.server import finish_request -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 6d6d03c34c..040a7a7ffa 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index cbf3673eff..4b7d198c52 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,7 +16,7 @@ """ This module contains REST servlets to do with rooms: /rooms/ """ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index ec4cf8db79..c40442f958 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import hmac diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 6f37a85d09..168eb27b03 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -45,7 +45,7 @@ from .search import SearchStore from .tags import TagsStore from .account_data import AccountDataStore -from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator +from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 5dd32b1413..2e89066515 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore class EndToEndKeyStore(SQLBaseStore): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 60936500d8..552e7ca35b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, _RollbackButIsFineException +from ._base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer, reactor diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index fd05bfe54e..a495a8a7d9 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 0894384780..9d3ba32478 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore class MediaRepositoryStore(SQLBaseStore): diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 70c6a06cd1..b10f2a5787 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from _base import SQLBaseStore +from ._base import SQLBaseStore from unpaddedbase64 import encode_base64 from synapse.crypto.event_signing import compute_event_reference_hash -- cgit 1.4.1 From b7dbe5147a85bea8f14b78a27ff499fe5a0d444a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 9 Mar 2016 11:26:26 +0000 Subject: Add a parse_json_object function to deduplicate all the copy+pasted _parse_json functions. Also document the parse_.* functions. --- synapse/http/servlet.py | 70 ++++++++++++++++++++++++++-- synapse/rest/client/v1/directory.py | 16 ++----- synapse/rest/client/v1/login.py | 15 +----- synapse/rest/client/v1/push_rule.py | 16 ++----- synapse/rest/client/v1/pusher.py | 17 ++----- synapse/rest/client/v1/register.py | 14 +----- synapse/rest/client/v1/room.py | 26 ++++------- synapse/rest/client/v2_alpha/_base.py | 22 --------- synapse/rest/client/v2_alpha/account.py | 8 ++-- synapse/rest/client/v2_alpha/register.py | 8 ++-- synapse/rest/client/v2_alpha/tokenrefresh.py | 6 +-- 11 files changed, 97 insertions(+), 121 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 7bd87940b4..41c519c000 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -15,14 +15,27 @@ """ This module contains base REST classes for constructing REST servlets. """ -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes import logging +import simplejson logger = logging.getLogger(__name__) def parse_integer(request, name, default=None, required=False): + """Parse an integer parameter from the request string + + :param request: the twisted HTTP request. + :param name (str): the name of the query parameter. + :param default: value to use if the parameter is absent, defaults to None. + :param required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + :return: An int value or the default. + :raises + SynapseError if the parameter is absent and required, or if the + parameter is present and not an integer. + """ if name in request.args: try: return int(request.args[name][0]) @@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False): else: if required: message = "Missing integer query parameter %r" % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) else: return default def parse_boolean(request, name, default=None, required=False): + """Parse a boolean parameter from the request query string + + :param request: the twisted HTTP request. + :param name (str): the name of the query parameter. + :param default: value to use if the parameter is absent, defaults to None. + :param required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + :return: A bool value or the default. + :raises + SynapseError if the parameter is absent and required, or if the + parameter is present and not one of "true" or "false". + """ + if name in request.args: try: return { @@ -53,30 +79,64 @@ def parse_boolean(request, name, default=None, required=False): else: if required: message = "Missing boolean query parameter %r" % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) else: return default def parse_string(request, name, default=None, required=False, allowed_values=None, param_type="string"): + """Parse a string parameter from the request query string. + + :param request: the twisted HTTP request. + :param name (str): the name of the query parameter. + :param default: value to use if the parameter is absent, defaults to None. + :param required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + :param allowed_values (list): List of allowed values for the string, + or None if any value is allowed, defaults to None + :return: A string value or the default. + :raises + SynapseError if the parameter is absent and required, or if the + parameter is present, must be one of a list of allowed values and + is not one of those allowed values. + """ + if name in request.args: value = request.args[name][0] if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values) ) - raise SynapseError(message) + raise SynapseError(400, message) else: return value else: if required: message = "Missing %s query parameter %r" % (param_type, name) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) else: return default +def parse_json_object_from_request(request): + """Parse a JSON object from the body of a twisted HTTP request. + + :param request: the twisted HTTP request. + :raises + SynapseError if the request body couldn't be decoded as JSON or + if it wasn't a JSON object. + """ + try: + content = simplejson.loads(request.content.read()) + if type(content) != dict: + message = "Content must be a JSON object." + raise SynapseError(400, message, errcode=Codes.BAD_JSON) + return content + except simplejson.JSONDecodeError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + class RestServlet(object): """ A Synapse REST Servlet. diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8bfe9fdea8..60c5ec77aa 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,9 +18,10 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError, Codes from synapse.types import RoomAlias +from synapse.http.servlet import parse_json_object_from_request + from .base import ClientV1RestServlet, client_path_patterns -import simplejson as json import logging @@ -45,7 +46,7 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): - content = _parse_json(request) + content = parse_json_object_from_request(request) if "room_id" not in content: raise SynapseError(400, "Missing room_id key", errcode=Codes.BAD_JSON) @@ -135,14 +136,3 @@ class ClientDirectoryServer(ClientV1RestServlet): ) defer.returnValue((200, {})) - - -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f6902a60a8..fe593d07ce 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID from synapse.http.server import finish_request +from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -79,7 +80,7 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - login_submission = _parse_json(request) + login_submission = parse_json_object_from_request(request) try: if login_submission["type"] == LoginRestServlet.PASS_TYPE: if not self.password_enabled: @@ -400,18 +401,6 @@ class CasTicketServlet(ClientV1RestServlet): return (user, attributes) -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError( - 400, "Content must be a JSON object.", errcode=Codes.BAD_JSON - ) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.saml2_enabled: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 981d7708db..b5695d427a 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.errors import ( - SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError + SynapseError, UnrecognizedRequestError, NotFoundError, StoreError ) from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( @@ -25,8 +25,7 @@ from synapse.storage.push_rule import ( from synapse.push.clientformat import format_push_rules_for_user from synapse.push.baserules import BASE_RULE_IDS from synapse.push.rulekinds import PRIORITY_CLASS_MAP - -import simplejson as json +from synapse.http.servlet import parse_json_object_from_request class PushRuleRestServlet(ClientV1RestServlet): @@ -52,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") - content = _parse_json(request) + content = parse_json_object_from_request(request) user_id = requester.user.to_string() @@ -341,14 +340,5 @@ class InvalidRuleException(Exception): pass -# XXX: C+ped from rest/room.py - surely this should be common? -def _parse_json(request): - try: - content = json.loads(request.content.read()) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_servlets(hs, http_server): PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 4c662e6e3c..ee029b4f77 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,9 +17,10 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException +from synapse.http.servlet import parse_json_object_from_request + from .base import ClientV1RestServlet, client_path_patterns -import simplejson as json import logging logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ class PusherRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = requester.user - content = _parse_json(request) + content = parse_json_object_from_request(request) pusher_pool = self.hs.get_pusherpool() @@ -92,17 +93,5 @@ class PusherRestServlet(ClientV1RestServlet): return 200, {} -# XXX: C+ped from rest/room.py - surely this should be common? -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_servlets(hs, http_server): PusherRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 040a7a7ffa..c6a2ef2ccc 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -20,12 +20,12 @@ from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils +from synapse.http.servlet import parse_json_object_from_request from synapse.util.async import run_on_reactor from hashlib import sha1 import hmac -import simplejson as json import logging logger = logging.getLogger(__name__) @@ -98,7 +98,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - register_json = _parse_json(request) + register_json = parse_json_object_from_request(request) session = (register_json["session"] if "session" in register_json else None) @@ -355,15 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet): ) -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.") - return content - except ValueError: - raise SynapseError(400, "Content not JSON.") - - def register_servlets(hs, http_server): RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 4b7d198c52..7a9f3d11b9 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -22,6 +22,7 @@ from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event +from synapse.http.servlet import parse_json_object_from_request import simplejson as json import logging @@ -137,7 +138,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): requester = yield self.auth.get_user_by_req(request) - content = _parse_json(request) + content = parse_json_object_from_request(request) event_dict = { "type": event_type, @@ -179,7 +180,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - content = _parse_json(request) + content = parse_json_object_from_request(request) msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( @@ -229,7 +230,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): ) try: - content = _parse_json(request) + content = parse_json_object_from_request(request) except: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. @@ -433,7 +434,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): raise AuthError(403, "Guest access not allowed") try: - content = _parse_json(request) + content = parse_json_object_from_request(request) except: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. @@ -500,7 +501,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): requester = yield self.auth.get_user_by_req(request) - content = _parse_json(request) + content = parse_json_object_from_request(request) msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( @@ -548,7 +549,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) - content = _parse_json(request) + content = parse_json_object_from_request(request) typing_handler = self.handlers.typing_notification_handler @@ -580,7 +581,7 @@ class SearchRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - content = _parse_json(request) + content = parse_json_object_from_request(request) batch = request.args.get("next_batch", [None])[0] results = yield self.handlers.search_handler.search( @@ -592,17 +593,6 @@ class SearchRestServlet(ClientV1RestServlet): defer.returnValue((200, results)) -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_txn_path(servlet, regex_string, http_server, with_get=False): """Registers a transaction-based path. diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 24af322126..b6faa2b0e6 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -17,11 +17,9 @@ """ from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX -from synapse.api.errors import SynapseError import re import logging -import simplejson logger = logging.getLogger(__name__) @@ -44,23 +42,3 @@ def client_v2_patterns(path_regex, releases=(0,)): new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns - - -def parse_request_allow_empty(request): - content = request.content.read() - if content is None or content == '': - return None - try: - return simplejson.loads(content) - except simplejson.JSONDecodeError: - raise SynapseError(400, "Content not JSON.") - - -def parse_json_dict_from_request(request): - try: - content = simplejson.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.") - return content - except simplejson.JSONDecodeError: - raise SynapseError(400, "Content not JSON.") diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a614b79d45..688b051580 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -17,10 +17,10 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import LoginError, SynapseError, Codes -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.async import run_on_reactor -from ._base import client_v2_patterns, parse_json_dict_from_request +from ._base import client_v2_patterns import logging @@ -41,7 +41,7 @@ class PasswordRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) authed, result, params = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], @@ -114,7 +114,7 @@ class ThreepidRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) threePidCreds = body.get('threePidCreds') threePidCreds = body.get('three_pid_creds', threePidCreds) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index ec5c21fa1f..b090e66bcf 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -17,9 +17,9 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns, parse_json_dict_from_request +from ._base import client_v2_patterns import logging import hmac @@ -73,7 +73,7 @@ class RegisterRestServlet(RestServlet): ret = yield self.onEmailTokenRequest(request) defer.returnValue(ret) - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. @@ -236,7 +236,7 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def onEmailTokenRequest(self, request): - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) required = ['id_server', 'client_secret', 'email', 'send_attempt'] absent = [] diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 3553f6b040..a158c2209a 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -16,9 +16,9 @@ from twisted.internet import defer from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns, parse_json_dict_from_request +from ._base import client_v2_patterns class TokenRefreshRestServlet(RestServlet): @@ -35,7 +35,7 @@ class TokenRefreshRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_handlers().auth_handler -- cgit 1.4.1 From 40160e24ab93ca4261df82335ab5521f134e2eda Mon Sep 17 00:00:00 2001 From: blide Date: Thu, 10 Mar 2016 02:08:37 +0300 Subject: Register endpoint returns refresh_token Guest registration still doesn't return refresh_token --- synapse/rest/client/v2_alpha/register.py | 13 ++++++++----- tests/rest/client/v2_alpha/test_register.py | 30 +++++++++++++++++------------ 2 files changed, 26 insertions(+), 17 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b090e66bcf..533ff136eb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -187,7 +187,7 @@ class RegisterRestServlet(RestServlet): else: logger.info("bind_email not specified: not binding email") - result = self._create_registration_details(user_id, token) + result = yield self._create_registration_details(user_id, token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -198,7 +198,7 @@ class RegisterRestServlet(RestServlet): (user_id, token) = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue(self._create_registration_details(user_id, token)) + defer.returnValue((yield self._create_registration_details(user_id, token))) @defer.inlineCallbacks def _do_shared_secret_registration(self, username, password, mac): @@ -225,14 +225,17 @@ class RegisterRestServlet(RestServlet): (user_id, token) = yield self.registration_handler.register( localpart=username, password=password ) - defer.returnValue(self._create_registration_details(user_id, token)) + defer.returnValue((yield self._create_registration_details(user_id, token))) + @defer.inlineCallbacks def _create_registration_details(self, user_id, token): - return { + refresh_token = yield self.auth_handler.issue_refresh_token(user_id) + defer.returnValue({ "user_id": user_id, "access_token": token, "home_server": self.hs.hostname, - } + "refresh_token": refresh_token, + }) @defer.inlineCallbacks def onEmailTokenRequest(self, request): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b867599079..e19952b8b6 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -62,12 +62,15 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=(user_id, token) ) - result = yield self.servlet.on_POST(self.request) - self.assertEquals(result, (200, { - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname - })) + (code, result) = yield self.servlet.on_POST(self.request) + self.assertEquals(code, 200) + det_data = { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + } + self.assertDictContainsSubset(det_data, result) + self.assertIn("refresh_token", result) @defer.inlineCallbacks def test_POST_appservice_registration_invalid(self): @@ -112,12 +115,15 @@ class RegisterRestServletTestCase(unittest.TestCase): }) self.registration_handler.register = Mock(return_value=(user_id, token)) - result = yield self.servlet.on_POST(self.request) - self.assertEquals(result, (200, { - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname - })) + (code, result) = yield self.servlet.on_POST(self.request) + self.assertEquals(code, 200) + det_data = { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + } + self.assertDictContainsSubset(det_data, result) + self.assertIn("refresh_token", result) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.4.1 From aa11db5f119b9fa88242b0df95cfddd00e196ca1 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 11 Mar 2016 13:14:18 +0000 Subject: Fix cache invalidation so deleting access tokens (which we did when changing password) actually takes effect without HS restart. Reinstate the code to avoid logging out the session that changed the password, removed in 415c2f05491ce65a4fc34326519754cd1edd9c54 --- synapse/handlers/auth.py | 13 +++++++++---- synapse/push/pusherpool.py | 8 ++++---- synapse/rest/client/v2_alpha/account.py | 2 +- synapse/storage/registration.py | 28 ++++++++++++++++++++-------- 4 files changed, 34 insertions(+), 17 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7a4afe446d..a740cc3da3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -432,13 +432,18 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def set_password(self, user_id, newpassword): + def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) + except_access_token_ids = [requester.access_token_id] if requester else [] + yield self.store.user_set_password_hash(user_id, password_hash) - yield self.store.user_delete_access_tokens(user_id) - yield self.hs.get_pusherpool().remove_pushers_by_user(user_id) - yield self.store.flush_user(user_id) + yield self.store.user_delete_access_tokens_except( + user_id, except_access_token_ids + ) + yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens( + user_id, except_access_token_ids + ) @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 772a095f8b..28ec94d866 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -92,14 +92,14 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id): + def remove_pushers_by_user_except_access_tokens(self, user_id, except_token_ids): all = yield self.store.get_all_pushers() logger.info( - "Removing all pushers for user %s", - user_id, + "Removing all pushers for user %s except access tokens ids %r", + user_id, except_token_ids ) for p in all: - if p['user_name'] == user_id: + if p['user_name'] == user_id and p['access_token'] not in except_token_ids: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 688b051580..dd4ea45588 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet): new_password = params['new_password'] yield self.auth_handler.set_password( - user_id, new_password + user_id, new_password, requester ) defer.returnValue((200, {})) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index aa49f53458..5eef7ebcc7 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -208,14 +208,26 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks - def flush_user(self, user_id): - rows = yield self._execute( - 'flush_user', None, - "SELECT token FROM access_tokens WHERE user_id = ?", - user_id - ) - for r in rows: - self.get_user_by_access_token.invalidate((r,)) + def user_delete_access_tokens_except(self, user_id, except_token_ids): + def f(txn): + txn.execute( + "SELECT id, token FROM access_tokens WHERE user_id = ? LIMIT 50", + (user_id,) + ) + rows = txn.fetchall() + for r in rows: + if r[0] in except_token_ids: + continue + + txn.call_after(self.get_user_by_access_token.invalidate, (r[1],)) + txn.execute( + "DELETE FROM access_tokens WHERE id in (%s)" % ",".join( + ["?" for _ in rows] + ), [r[0] for r in rows] + ) + return len(rows) == 50 + while (yield self.runInteraction("user_delete_access_tokens_except", f)): + pass @cached() def get_user_by_access_token(self, token): -- cgit 1.4.1 From b13035cc91410634421820e5175d0596f5a67549 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Mar 2016 16:27:50 +0000 Subject: Implement logout --- synapse/rest/__init__.py | 2 ++ synapse/rest/client/v1/logout.py | 72 ++++++++++++++++++++++++++++++++++++++++ synapse/storage/registration.py | 49 +++++++++++++++++++-------- 3 files changed, 109 insertions(+), 14 deletions(-) create mode 100644 synapse/rest/client/v1/logout.py (limited to 'synapse/rest/client') diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 433237c204..6688fa8fa0 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client.v1 import ( push_rule, register as v1_register, login as v1_login, + logout, ) from synapse.rest.client.v2_alpha import ( @@ -72,6 +73,7 @@ class ClientRestResource(JsonResource): admin.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) + logout.register_servlets(hs, client_resource) # "v2" sync.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py new file mode 100644 index 0000000000..9bff02ee4e --- /dev/null +++ b/synapse/rest/client/v1/logout.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import AuthError, Codes + +from .base import ClientV1RestServlet, client_path_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class LogoutRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/logout$") + + def __init__(self, hs): + super(LogoutRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + + def on_OPTIONS(self, request): + return (200, {}) + + @defer.inlineCallbacks + def on_POST(self, request): + try: + access_token = request.args["access_token"][0] + except KeyError: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + yield self.store.delete_access_token(access_token) + defer.returnValue((200, {})) + + +class LogoutAllRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/logout/all$") + + def __init__(self, hs): + super(LogoutAllRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + def on_OPTIONS(self, request): + return (200, {}) + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + yield self.store.user_delete_access_tokens(user_id) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + LogoutRestServlet(hs).register(http_server) + LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5e7a4e371d..18898c44eb 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -195,24 +195,45 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids): + def user_delete_access_tokens(self, user_id, except_token_ids=[]): def f(txn): txn.execute( - "SELECT id, token FROM access_tokens " - "WHERE user_id = ? AND id NOT IN ? LIMIT 50", - (user_id, except_token_ids) + "SELECT token FROM access_tokens" + " WHERE user_id = ? AND id NOT IN (%s)" % ( + ",".join(["?" for _ in except_token_ids]), + ), + [user_id] + except_token_ids ) - rows = txn.fetchall() - for r in rows: - txn.call_after(self.get_user_by_access_token.invalidate, (r[1],)) - txn.execute( - "DELETE FROM access_tokens WHERE id in (%s)" % ",".join( - ["?" for _ in rows] - ), [r[0] for r in rows] + + while True: + rows = txn.fetchmany(100) + if not rows: + break + + for row in rows: + txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + + txn.execute( + "DELETE FROM access_tokens WHERE token in (%s)" % ( + ",".join(["?" for _ in rows]), + ), [r[0] for r in rows] + ) + + yield self.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self._simple_delete_one_txn( + txn, + table="access_tokens", + keyvalues={ + "token": access_token + }, ) - return len(rows) == 50 - while (yield self.runInteraction("user_delete_access_tokens", f)): - pass + + txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + + return self.runInteraction("delete_access_token", f) @cached() def get_user_by_access_token(self, token): -- cgit 1.4.1 From e9c1cabac26cf8e28152ebdb3caf29d4457eea0e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 11 Mar 2016 16:41:03 +0000 Subject: Use parse_json_object_from_request to parse JSON out of request bodies --- synapse/federation/transport/server.py | 4 ++-- synapse/http/servlet.py | 17 ++++++++++++----- synapse/rest/client/v1/presence.py | 13 +++++-------- synapse/rest/client/v1/profile.py | 8 ++++---- synapse/rest/client/v1/room.py | 14 ++++---------- synapse/rest/client/v2_alpha/account_data.py | 21 ++++----------------- synapse/rest/client/v2_alpha/filter.py | 10 ++-------- synapse/rest/client/v2_alpha/keys.py | 22 +++++++--------------- synapse/rest/client/v2_alpha/tags.py | 12 +++--------- synapse/rest/key/v2/remote_key_resource.py | 12 ++---------- tests/rest/client/v1/test_profile.py | 6 ++++-- 11 files changed, 49 insertions(+), 90 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 6e92e2f8f4..208bff8d4f 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource +from synapse.http.servlet import parse_json_object_from_request from synapse.util.ratelimitutils import FederationRateLimiter import functools @@ -419,8 +420,7 @@ class On3pidBindServlet(BaseFederationServlet): @defer.inlineCallbacks def on_POST(self, request): - content_bytes = request.content.read() - content = json.loads(content_bytes) + content = parse_json_object_from_request(request) if "invites" in content: last_exception = None for invite in content["invites"]: diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 41c519c000..1996f8b136 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -128,14 +128,21 @@ def parse_json_object_from_request(request): if it wasn't a JSON object. """ try: - content = simplejson.loads(request.content.read()) - if type(content) != dict: - message = "Content must be a JSON object." - raise SynapseError(400, message, errcode=Codes.BAD_JSON) - return content + content_bytes = request.content.read() + except: + raise SynapseError(400, "Error reading JSON content.") + + try: + content = simplejson.loads(content_bytes) except simplejson.JSONDecodeError: raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + if type(content) != dict: + message = "Content must be a JSON object." + raise SynapseError(400, message, errcode=Codes.BAD_JSON) + + return content + class RestServlet(object): diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index bbfa1d6ac4..27d9ed586b 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,9 +19,9 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError from synapse.types import UserID +from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns -import simplejson as json import logging logger = logging.getLogger(__name__) @@ -56,9 +56,10 @@ class PresenceStatusRestServlet(ClientV1RestServlet): raise AuthError(403, "Can only set your own presence state") state = {} - try: - content = json.loads(request.content.read()) + content = parse_json_object_from_request(request) + + try: state["presence"] = content.pop("presence") if "status_msg" in content: @@ -113,11 +114,7 @@ class PresenceListRestServlet(ClientV1RestServlet): raise SynapseError( 400, "Cannot modify another user's presence list") - try: - content = json.loads(request.content.read()) - except: - logger.exception("JSON parse error") - raise SynapseError(400, "Unable to parse content") + content = parse_json_object_from_request(request) if "invite" in content: for u in content["invite"]: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 953764bd8e..65c4e2ebef 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -18,8 +18,7 @@ from twisted.internet import defer from .base import ClientV1RestServlet, client_path_patterns from synapse.types import UserID - -import simplejson as json +from synapse.http.servlet import parse_json_object_from_request class ProfileDisplaynameRestServlet(ClientV1RestServlet): @@ -44,8 +43,9 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) + content = parse_json_object_from_request(request) + try: - content = json.loads(request.content.read()) new_name = content["displayname"] except: defer.returnValue((400, "Unable to parse name")) @@ -81,8 +81,8 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + content = parse_json_object_from_request(request) try: - content = json.loads(request.content.read()) new_name = content["avatar_url"] except: defer.returnValue((400, "Unable to parse name")) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 7a9f3d11b9..a1fa7daf79 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -24,7 +24,6 @@ from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event from synapse.http.servlet import parse_json_object_from_request -import simplejson as json import logging import urllib @@ -72,15 +71,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): defer.returnValue((200, info)) def get_room_config(self, request): - try: - user_supplied_config = json.loads(request.content.read()) - if "visibility" not in user_supplied_config: - # default visibility - user_supplied_config["visibility"] = "public" - return user_supplied_config - except (ValueError, TypeError): - raise SynapseError(400, "Body must be JSON.", - errcode=Codes.BAD_JSON) + user_supplied_config = parse_json_object_from_request(request) + # default visibility + user_supplied_config.setdefault("visibility", "public") + return user_supplied_config def on_OPTIONS(self, request): return (200, {}) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 1456881c1a..b16079cece 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -15,15 +15,13 @@ from ._base import client_v2_patterns -from synapse.http.servlet import RestServlet -from synapse.api.errors import AuthError, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import AuthError from twisted.internet import defer import logging -import simplejson as json - logger = logging.getLogger(__name__) @@ -47,11 +45,7 @@ class AccountDataServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") - try: - content_bytes = request.content.read() - body = json.loads(content_bytes) - except: - raise SynapseError(400, "Invalid JSON") + body = parse_json_object_from_request(request) max_id = yield self.store.add_account_data_for_user( user_id, account_data_type, body @@ -86,14 +80,7 @@ class RoomAccountDataServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") - try: - content_bytes = request.content.read() - body = json.loads(content_bytes) - except: - raise SynapseError(400, "Invalid JSON") - - if not isinstance(body, dict): - raise ValueError("Expected a JSON object") + body = parse_json_object_from_request(request) max_id = yield self.store.add_account_data_to_room( user_id, room_id, account_data_type, body diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 7c94f6ec41..510f8b2c74 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -16,12 +16,11 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID from ._base import client_v2_patterns -import simplejson as json import logging @@ -84,12 +83,7 @@ class CreateFilterRestServlet(RestServlet): if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only create filters for local users") - try: - content = json.loads(request.content.read()) - - # TODO(paul): check for required keys and invalid keys - except: - raise SynapseError(400, "Invalid filter definition") + content = parse_json_object_from_request(request) filter_id = yield self.filtering.add_user_filter( user_localpart=target_user.localpart, diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f989b08614..89ab39491c 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -15,16 +15,15 @@ from twisted.internet import defer -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID from canonicaljson import encode_canonical_json from ._base import client_v2_patterns -import simplejson as json import logging +import simplejson as json logger = logging.getLogger(__name__) @@ -68,10 +67,9 @@ class KeyUploadServlet(RestServlet): user_id = requester.user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. - try: - body = json.loads(request.content.read()) - except: - raise SynapseError(400, "Invalid key JSON") + + body = parse_json_object_from_request(request) + time_now = self.clock.time_msec() # TODO: Validate the JSON to make sure it has the right keys. @@ -173,10 +171,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) - try: - body = json.loads(request.content.read()) - except: - raise SynapseError(400, "Invalid key JSON") + body = parse_json_object_from_request(request) result = yield self.handle_request(body) defer.returnValue(result) @@ -272,10 +267,7 @@ class OneTimeKeyServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - try: - body = json.loads(request.content.read()) - except: - raise SynapseError(400, "Invalid key JSON") + body = parse_json_object_from_request(request) result = yield self.handle_request(body) defer.returnValue(result) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 79c436a8cf..dac8603b07 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -15,15 +15,13 @@ from ._base import client_v2_patterns -from synapse.http.servlet import RestServlet -from synapse.api.errors import AuthError, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import AuthError from twisted.internet import defer import logging -import simplejson as json - logger = logging.getLogger(__name__) @@ -72,11 +70,7 @@ class TagServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - try: - content_bytes = request.content.read() - body = json.loads(content_bytes) - except: - raise SynapseError(400, "Invalid tag JSON") + body = parse_json_object_from_request(request) max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 81ef1f4702..9552016fec 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.http.server import request_handler, respond_with_json_bytes -from synapse.http.servlet import parse_integer +from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.api.errors import SynapseError, Codes from twisted.web.resource import Resource @@ -22,7 +22,6 @@ from twisted.internet import defer from io import BytesIO -import json import logging logger = logging.getLogger(__name__) @@ -126,14 +125,7 @@ class RemoteKey(Resource): @request_handler @defer.inlineCallbacks def async_render_POST(self, request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise ValueError() - except ValueError: - raise SynapseError( - 400, "Content must be JSON object.", errcode=Codes.NOT_JSON - ) + content = parse_json_object_from_request(request) query = content["server_keys"] diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 1d210f9bf8..af02fce8fb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -95,7 +95,8 @@ class ProfileTestCase(unittest.TestCase): mocked_set.side_effect = AuthError(400, "message") (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % ("@4567:test"), '"Frank Jr."' + "PUT", "/profile/%s/displayname" % ("@4567:test"), + '{"displayname": "Frank Jr."}' ) self.assertTrue( @@ -121,7 +122,8 @@ class ProfileTestCase(unittest.TestCase): mocked_set.side_effect = SynapseError(400, "message") (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), None + "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), + '{"displayname":"bob"}' ) self.assertTrue( -- cgit 1.4.1 From 398cd1edfbefa207e44047bd63adcdcc6e859f2e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 14 Mar 2016 14:16:41 +0000 Subject: Fix regression where synapse checked whether push rules were valid JSON before the compatibility hack that handled clients sending invalid JSON --- synapse/http/servlet.py | 21 +++++++++++++++++---- synapse/rest/client/v1/push_rule.py | 4 ++-- 2 files changed, 19 insertions(+), 6 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1996f8b136..1c8bd8666f 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -119,13 +119,13 @@ def parse_string(request, name, default=None, required=False, return default -def parse_json_object_from_request(request): - """Parse a JSON object from the body of a twisted HTTP request. +def parse_json_value_from_request(request): + """Parse a JSON value from the body of a twisted HTTP request. :param request: the twisted HTTP request. + :returns: The JSON value. :raises - SynapseError if the request body couldn't be decoded as JSON or - if it wasn't a JSON object. + SynapseError if the request body couldn't be decoded as JSON. """ try: content_bytes = request.content.read() @@ -137,6 +137,19 @@ def parse_json_object_from_request(request): except simplejson.JSONDecodeError: raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + return content + + +def parse_json_object_from_request(request): + """Parse a JSON object from the body of a twisted HTTP request. + + :param request: the twisted HTTP request. + :raises + SynapseError if the request body couldn't be decoded as JSON or + if it wasn't a JSON object. + """ + content = parse_json_value_from_request(request) + if type(content) != dict: message = "Content must be a JSON object." raise SynapseError(400, message, errcode=Codes.BAD_JSON) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index b5695d427a..02d837ee6a 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -25,7 +25,7 @@ from synapse.storage.push_rule import ( from synapse.push.clientformat import format_push_rules_for_user from synapse.push.baserules import BASE_RULE_IDS from synapse.push.rulekinds import PRIORITY_CLASS_MAP -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import parse_json_value_from_request class PushRuleRestServlet(ClientV1RestServlet): @@ -51,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") - content = parse_json_object_from_request(request) + content = parse_json_value_from_request(request) user_id = requester.user.to_string() -- cgit 1.4.1 From 12904932c41c73714543b817157f09073fcc2625 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 15 Mar 2016 17:41:06 +0000 Subject: Hook up adding a pusher to the notifier for replication. --- synapse/notifier.py | 6 ++++++ synapse/rest/client/v1/pusher.py | 6 ++++++ 2 files changed, 12 insertions(+) (limited to 'synapse/rest/client') diff --git a/synapse/notifier.py b/synapse/notifier.py index 9b69b0333a..f00cd8c588 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -282,6 +282,12 @@ class Notifier(object): self.notify_replication() + def on_new_replication_data(self): + """Used to inform replication listeners that something has happend + without waking up any of the normal user event streams""" + with PreserveLoggingContext(): + self.notify_replication() + @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index ee029b4f77..9881f068c3 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -29,6 +29,10 @@ logger = logging.getLogger(__name__) class PusherRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/pushers/set$") + def __init__(self, hs): + super(PusherRestServlet, self).__init__(hs) + self.notifier = hs.get_notifier() + @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) @@ -87,6 +91,8 @@ class PusherRestServlet(ClientV1RestServlet): raise SynapseError(400, "Config Error: " + pce.message, errcode=Codes.MISSING_PARAM) + self.notifier.on_new_replication_data() + defer.returnValue((200, {})) def on_OPTIONS(self, _): -- cgit 1.4.1 From c12b9d719a3cf1eeb9c4c8d354dbaecab5e76233 Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 16 Mar 2016 11:56:24 +0000 Subject: Make registration idempotent: if you specify the same session, make it give you an access token for the user that was registered on previous uses of that session. Tweak the UI auth layer to not delete sessions when their auth has completed and hence expire themn so they don't hang around until server restart. Allow server-side data to be associated with UI auth sessions. --- synapse/handlers/auth.py | 60 +++++++++++++++++++++++++------- synapse/rest/client/v2_alpha/register.py | 27 +++++++++++++- 2 files changed, 74 insertions(+), 13 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5c0ea636bc..5dc9d91757 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -27,6 +27,7 @@ import logging import bcrypt import pymacaroons import simplejson +import time import synapse.util.stringutils as stringutils @@ -35,6 +36,7 @@ logger = logging.getLogger(__name__) class AuthHandler(BaseHandler): + SESSION_EXPIRE_SECS = 48 * 60 * 60 def __init__(self, hs): super(AuthHandler, self).__init__(hs) @@ -66,15 +68,18 @@ class AuthHandler(BaseHandler): 'auth' key: this method prompts for auth if none is sent. clientip (str): The IP address of the client. Returns: - A tuple of (authed, dict, dict) where authed is true if the client - has successfully completed an auth flow. If it is true, the first - dict contains the authenticated credentials of each stage. + A tuple of (authed, dict, dict, session_id) where authed is true if + the client has successfully completed an auth flow. If it is true + the first dict contains the authenticated credentials of each stage. If authed is false, the first dictionary is the server response to the login request and should be passed back to the client. In either case, the second dict contains the parameters for this request (which may have been given only in a previous call). + + session_id is the ID of this session, either passed in by the client + or assigned by the call to check_auth """ authdict = None @@ -103,7 +108,10 @@ class AuthHandler(BaseHandler): if not authdict: defer.returnValue( - (False, self._auth_dict_for_flows(flows, session), clientdict) + ( + False, self._auth_dict_for_flows(flows, session), + clientdict, session['id'] + ) ) if 'creds' not in session: @@ -122,12 +130,11 @@ class AuthHandler(BaseHandler): for f in flows: if len(set(f) - set(creds.keys())) == 0: logger.info("Auth completed with creds: %r", creds) - self._remove_session(session) - defer.returnValue((True, creds, clientdict)) + defer.returnValue((True, creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() - defer.returnValue((False, ret, clientdict)) + defer.returnValue((False, ret, clientdict, session['id'])) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -154,6 +161,29 @@ class AuthHandler(BaseHandler): defer.returnValue(True) defer.returnValue(False) + def set_session_data(self, session_id, key, value): + """ + Store a key-value pair into the sessions data associated with this + request. This data is stored server-side and cannot be modified by + the client. + :param session_id: (string) The ID of this session as returned from check_auth + :param key: (string) The key to store the data under + :param value: (any) The data to store + """ + sess = self._get_session_info(session_id) + sess.setdefault('serverdict', {})[key] = value + self._save_session(sess) + + def get_session_data(self, session_id, key, default=None): + """ + Retrieve data stored with set_session_data + :param session_id: (string) The ID of this session as returned from check_auth + :param key: (string) The key to store the data under + :param default: (any) Value to return if the key has not been set + """ + sess = self._get_session_info(session_id) + return sess.setdefault('serverdict', {}).get(key, default) + @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: @@ -263,7 +293,7 @@ class AuthHandler(BaseHandler): if not session_id: # create a new session while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) + session_id = stringutils.random_string_with_symbols(24) self.sessions[session_id] = { "id": session_id, } @@ -455,11 +485,17 @@ class AuthHandler(BaseHandler): def _save_session(self, session): # TODO: Persistent storage logger.debug("Saving session %s", session) + session["last_used"] = time.time() self.sessions[session["id"]] = session - - def _remove_session(self, session): - logger.debug("Removing session %s", session) - del self.sessions[session["id"]] + self._prune_sessions() + + def _prune_sessions(self): + for sid,sess in self.sessions.items(): + last_used = 0 + if 'last_used' in sess: + last_used = sess['last_used'] + if last_used < time.time() - AuthHandler.SESSION_EXPIRE_SECS: + del self.sessions[sid] def hash(self, password): """Computes a secure hash of password. diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 533ff136eb..649491bdf6 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -139,7 +139,7 @@ class RegisterRestServlet(RestServlet): [LoginType.EMAIL_IDENTITY] ] - authed, result, params = yield self.auth_handler.check_auth( + authed, result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) @@ -147,6 +147,24 @@ class RegisterRestServlet(RestServlet): defer.returnValue((401, result)) return + # have we already registered a user for this session + registered_user_id = self.auth_handler.get_session_data( + session_id, "registered_user_id", None + ) + if registered_user_id is not None: + logger.info( + "Already registered user ID %r for this session", + registered_user_id + ) + access_token = yield self.auth_handler.issue_access_token(registered_user_id) + refresh_token = yield self.auth_handler.issue_refresh_token(registered_user_id) + defer.returnValue((200, { + "user_id": registered_user_id, + "access_token": access_token, + "home_server": self.hs.hostname, + "refresh_token": refresh_token, + })) + # NB: This may be from the auth handler and NOT from the POST if 'password' not in params: raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) @@ -161,6 +179,13 @@ class RegisterRestServlet(RestServlet): guest_access_token=guest_access_token, ) + # remember that we've now registered that user account, and with what + # user ID (since the user may not have specified) + logger.info("%r", body) + self.auth_handler.set_session_data( + session_id, "registered_user_id", user_id + ) + if result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] -- cgit 1.4.1 From 99797947aa5a7cdf8fe12043b4f25a155bcf4555 Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 16 Mar 2016 12:51:34 +0000 Subject: pep8 & remove debug logging --- synapse/handlers/auth.py | 2 +- synapse/rest/client/v2_alpha/register.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5dc9d91757..a9f5e3710b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -490,7 +490,7 @@ class AuthHandler(BaseHandler): self._prune_sessions() def _prune_sessions(self): - for sid,sess in self.sessions.items(): + for sid, sess in self.sessions.items(): last_used = 0 if 'last_used' in sess: last_used = sess['last_used'] diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 649491bdf6..c440430e25 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -149,7 +149,7 @@ class RegisterRestServlet(RestServlet): # have we already registered a user for this session registered_user_id = self.auth_handler.get_session_data( - session_id, "registered_user_id", None + session_id, "registered_user_id", None ) if registered_user_id is not None: logger.info( @@ -157,7 +157,9 @@ class RegisterRestServlet(RestServlet): registered_user_id ) access_token = yield self.auth_handler.issue_access_token(registered_user_id) - refresh_token = yield self.auth_handler.issue_refresh_token(registered_user_id) + refresh_token = yield self.auth_handler.issue_refresh_token( + registered_user_id + ) defer.returnValue((200, { "user_id": registered_user_id, "access_token": access_token, @@ -181,9 +183,8 @@ class RegisterRestServlet(RestServlet): # remember that we've now registered that user account, and with what # user ID (since the user may not have specified) - logger.info("%r", body) self.auth_handler.set_session_data( - session_id, "registered_user_id", user_id + session_id, "registered_user_id", user_id ) if result and LoginType.EMAIL_IDENTITY in result: -- cgit 1.4.1 From f5e90422f5d70afaf9bdf97cc620b563cf31a8eb Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 16 Mar 2016 14:33:19 +0000 Subject: take extra return val from check_auth in account too --- synapse/rest/client/v2_alpha/account.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index dd4ea45588..7f8a6a4cf7 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -43,7 +43,7 @@ class PasswordRestServlet(RestServlet): body = parse_json_object_from_request(request) - authed, result, params = yield self.auth_handler.check_auth([ + authed, result, params, _ = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], [LoginType.EMAIL_IDENTITY] ], body, self.hs.get_ip_from_request(request)) -- cgit 1.4.1 From a7daa5ae131cc860769d859cf03b48cefdc0500a Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 16 Mar 2016 19:36:57 +0000 Subject: Make registration idempotent, part 2: be idempotent if the client specifies a username. --- synapse/handlers/auth.py | 14 ++++++++++++++ synapse/handlers/register.py | 12 +++++++++++- synapse/rest/client/v2_alpha/register.py | 22 +++++++++++++++++----- 3 files changed, 42 insertions(+), 6 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d7233cd0d6..82d458b424 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -160,6 +160,20 @@ class AuthHandler(BaseHandler): defer.returnValue(True) defer.returnValue(False) + def get_session_id(self, clientdict): + """ + Gets the session ID for a client given the client dictionary + :param clientdict: The dictionary sent by the client in the request + :return: The string session ID the client sent. If the client did not + send a session ID, returns None. + """ + sid = None + if clientdict and 'auth' in clientdict: + authdict = clientdict['auth'] + if 'session' in authdict: + sid = authdict['session'] + return sid + def set_session_data(self, session_id, key, value): """ Store a key-value pair into the sessions data associated with this diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6ffb8c0da6..f287ee247b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id = None @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None): + def check_username(self, localpart, guest_access_token=None, + assigned_user_id=None): yield run_on_reactor() if urllib.quote(localpart.encode('utf-8')) != localpart: @@ -60,6 +61,15 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() + if assigned_user_id: + if user_id == assigned_user_id: + return + else: + raise SynapseError( + 400, + "A different user ID has already been registered for this session", + ) + yield self.check_user_id_not_appservice_exclusive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c440430e25..b8590560d3 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.api.constants import LoginType +from synapse.types import UserID from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -122,10 +123,25 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) + session_id = self.auth_handler.get_session_id(body) + logger.error("session id: %r", session_id) + registered_user_id = None + if session_id: + # if we get a registered user id out of here, it means we previously + # registered a user for this session, so we could just return the + # user here. We carry on and go through the auth checks though, + # for paranoia. + registered_user_id = self.auth_handler.get_session_data( + session_id, "registered_user_id", None + ) + logger.error("already regged: %r", registered_user_id) + logger.error("check: %r", desired_username) + if desired_username is not None: yield self.registration_handler.check_username( desired_username, - guest_access_token=guest_access_token + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, ) if self.hs.config.enable_registration_captcha: @@ -147,10 +163,6 @@ class RegisterRestServlet(RestServlet): defer.returnValue((401, result)) return - # have we already registered a user for this session - registered_user_id = self.auth_handler.get_session_data( - session_id, "registered_user_id", None - ) if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", -- cgit 1.4.1 From f984decd6636baa4974a136e2ce8d4fecab3146f Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 16 Mar 2016 19:40:48 +0000 Subject: Unused import --- synapse/rest/client/v2_alpha/register.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b8590560d3..d3e66740ad 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,7 +16,6 @@ from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.types import UserID from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet, parse_json_object_from_request -- cgit 1.4.1 From 5670205e2a0e4b87005be743eb6cdfd817fe89ae Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 16 Mar 2016 19:49:42 +0000 Subject: remove debug logging --- synapse/rest/client/v2_alpha/register.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d3e66740ad..d32c06c882 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -123,7 +123,6 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) session_id = self.auth_handler.get_session_id(body) - logger.error("session id: %r", session_id) registered_user_id = None if session_id: # if we get a registered user id out of here, it means we previously @@ -133,8 +132,6 @@ class RegisterRestServlet(RestServlet): registered_user_id = self.auth_handler.get_session_data( session_id, "registered_user_id", None ) - logger.error("already regged: %r", registered_user_id) - logger.error("check: %r", desired_username) if desired_username is not None: yield self.registration_handler.check_username( -- cgit 1.4.1 From 2cd9260500efa82713edd365f54d491ac0328fb0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 17 Mar 2016 11:09:03 +0000 Subject: Update aliases event after deletion Attempt to update the appropriate `m.room.aliases` event after deleting an alias. This may fail due to the deleter not being in the room. Will also check if the canonical alias of the event is set to the deleted alias, and if so will attempt to delete it. --- synapse/handlers/directory.py | 52 ++++++++++++++++++++++++++++++++----- synapse/rest/client/v1/directory.py | 3 ++- 2 files changed, 48 insertions(+), 7 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index c4aaa11918..be9f2a21b2 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -32,6 +32,8 @@ class DirectoryHandler(BaseHandler): def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) + self.state = hs.get_state_handler() + self.federation = hs.get_replication_layer() self.federation.register_query_handler( "directory", self.on_directory_query @@ -93,7 +95,7 @@ class DirectoryHandler(BaseHandler): yield self._create_association(room_alias, room_id, servers) @defer.inlineCallbacks - def delete_association(self, user_id, room_alias): + def delete_association(self, requester, user_id, room_alias): # association deletion for human users can_delete = yield self._user_can_delete_alias(room_alias, user_id) @@ -112,7 +114,25 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) - yield self._delete_association(room_alias) + room_id = yield self._delete_association(room_alias) + + try: + yield self.send_room_alias_update_event( + requester, + requester.user.to_string(), + room_id + ) + + yield self._update_canonical_alias( + requester, + requester.user.to_string(), + room_id, + room_alias, + ) + except AuthError as e: + logger.info("Failed to update alias events: %s", e) + + defer.returnValue(room_id) @defer.inlineCallbacks def delete_appservice_association(self, service, room_alias): @@ -129,11 +149,9 @@ class DirectoryHandler(BaseHandler): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") - yield self.store.delete_room_alias(room_alias) + room_id = yield self.store.delete_room_alias(room_alias) - # TODO - Looks like _update_room_alias_event has never been implemented - # if room_id: - # yield self._update_room_alias_events(user_id, room_id) + defer.returnValue(room_id) @defer.inlineCallbacks def get_association(self, room_alias): @@ -233,6 +251,28 @@ class DirectoryHandler(BaseHandler): ratelimit=False ) + @defer.inlineCallbacks + def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + alias_event = yield self.state.get_current_state( + room_id, EventTypes.CanonicalAlias, "" + ) + + if alias_event.content.get("alias", "") != room_alias.to_string(): + return + + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": room_id, + "sender": user_id, + "content": {}, + }, + ratelimit=False + ) + @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): result = yield self.store.get_association_from_room_alias( diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 60c5ec77aa..59a23d6cb6 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -127,8 +127,9 @@ class ClientDirectoryServer(ClientV1RestServlet): room_alias = RoomAlias.from_string(room_alias) yield dir_handler.delete_association( - user.to_string(), room_alias + requester, user.to_string(), room_alias ) + logger.info( "User %s deleted alias %s", user.to_string(), -- cgit 1.4.1 From 3e7fac0d56dca5b389ef7a671c1cd6b0795724c8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 21 Mar 2016 14:03:20 +0000 Subject: Add published room list edit API --- synapse/api/auth.py | 54 ++++++++++++++++++++++++++++++++++--- synapse/handlers/directory.py | 16 +++++++++++ synapse/rest/client/v1/directory.py | 42 +++++++++++++++++++++++++++++ synapse/storage/room.py | 8 ++++++ 4 files changed, 116 insertions(+), 4 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3038df4ab8..4f9c3c9db8 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -814,17 +814,16 @@ class Auth(object): return auth_ids - @log_function - def _can_send_event(self, event, auth_events): + def _get_send_level(self, etype, state_key, auth_events): key = (EventTypes.PowerLevels, "", ) send_level_event = auth_events.get(key) send_level = None if send_level_event: send_level = send_level_event.content.get("events", {}).get( - event.type + etype ) if send_level is None: - if hasattr(event, "state_key"): + if state_key is not None: send_level = send_level_event.content.get( "state_default", 50 ) @@ -838,6 +837,13 @@ class Auth(object): else: send_level = 0 + return send_level + + @log_function + def _can_send_event(self, event, auth_events): + send_level = self._get_send_level( + event.type, event.get("state_key", None), auth_events + ) user_level = self._get_user_power_level(event.user_id, auth_events) if user_level < send_level: @@ -982,3 +988,43 @@ class Auth(object): "You don't have permission to add ops level greater " "than your own" ) + + @defer.inlineCallbacks + def check_can_change_room_list(self, room_id, user): + """Check if the user is allowed to edit the room's entry in the + published room list. + + Args: + room_id (str) + user (UserID) + """ + + is_admin = yield self.is_server_admin(user) + if is_admin: + defer.returnValue(True) + + user_id = user.to_string() + yield self.check_joined_room(room_id, user_id) + + # We currently require the user is a "moderator" in the room. We do this + # by checking if they would (theoretically) be able to change the + # m.room.aliases events + power_level_event = yield self.state.get_current_state( + room_id, EventTypes.PowerLevels, "" + ) + + auth_events = {} + if power_level_event: + auth_events[(EventTypes.PowerLevels, "")] = power_level_event + + send_level = self._get_send_level( + EventTypes.Aliases, "", auth_events + ) + user_level = self._get_user_power_level(user_id, auth_events) + + if user_level < send_level: + raise AuthError( + 403, + "This server requires you to be a moderator in the room to" + " edit its room list entry" + ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 6bcc5a5e2b..b2617c8898 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -317,3 +317,19 @@ class DirectoryHandler(BaseHandler): is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) defer.returnValue(is_admin) + + @defer.inlineCallbacks + def edit_published_room_list(self, requester, room_id, visibility): + if requester.is_guest: + raise AuthError(403, "Guests cannot edit the published room list") + + if visibility not in ["public", "private"]: + raise SynapseError(400, "Invalide visibility setting") + + room = yield self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Unknown room") + + yield self.auth.check_can_change_room_list(room_id, requester.user) + + yield self.store.set_room_is_public(room_id, visibility == "public") diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 59a23d6cb6..8ac09419dc 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -30,6 +30,7 @@ logger = logging.getLogger(__name__) def register_servlets(hs, http_server): ClientDirectoryServer(hs).register(http_server) + ClientDirectoryListServer(hs).register(http_server) class ClientDirectoryServer(ClientV1RestServlet): @@ -137,3 +138,44 @@ class ClientDirectoryServer(ClientV1RestServlet): ) defer.returnValue((200, {})) + + +class ClientDirectoryListServer(ClientV1RestServlet): + PATTERNS = client_path_patterns("/directory/list/room/(?P[^/]*)$") + + def __init__(self, hs): + super(ClientDirectoryListServer, self).__init__(hs) + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + room = yield self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Unknown room") + + defer.returnValue((200, { + "visibility": "public" if room["is_public"] else "private" + })) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + visibility = content.get("visibility", "public") + + yield self.handlers.directory_handler.edit_published_room_list( + requester, room_id, visibility, + ) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_DELETE(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + + yield self.handlers.directory_handler.edit_published_room_list( + requester, room_id, "private", + ) + + defer.returnValue((200, {})) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 46ab38a313..9be977f387 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -77,6 +77,14 @@ class RoomStore(SQLBaseStore): allow_none=True, ) + def set_room_is_public(self, room_id, is_public): + return self._simple_update_one( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="set_room_is_public", + ) + def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", -- cgit 1.4.1