From a24bc5b2dc3a5d81cdfbe7be367dbb461d85b999 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 23 May 2016 18:33:51 +0100 Subject: Add GET /notifications API --- synapse/rest/client/v2_alpha/notifications.py | 100 ++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 synapse/rest/client/v2_alpha/notifications.py (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py new file mode 100644 index 0000000000..505e998393 --- /dev/null +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import ( + RestServlet, parse_string, parse_integer +) +from synapse.events.utils import ( + serialize_event, format_event_for_client_v2_without_room_id, +) + +from ._base import client_v2_patterns + +import logging + +logger = logging.getLogger(__name__) + + +class NotificationsServlet(RestServlet): + PATTERNS = client_v2_patterns("/notifications$", releases=()) + + def __init__(self, hs): + super(NotificationsServlet, self).__init__() + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + from_token = parse_string(request, "from", required=False) + limit = parse_integer(request, "limit", default=50) + + limit = min(limit, 500) + + push_actions = yield self.store.get_push_actions_for_user( + user_id, from_token, limit + ) + + receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + user_id, 'm.read' + ) + + notif_event_ids = [pa["event_id"] for pa in push_actions] + notif_events = yield self.store.get_events(notif_event_ids) + + returned_push_actions = [] + + next_token = None + + for pa in push_actions: + returned_pa = { + "room_id": pa["room_id"], + "profile_tag": pa["profile_tag"], + "actions": pa["actions"], + "event": serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ), + } + + if pa["room_id"] not in receipts_by_room: + returned_pa["read"] = False + else: + receipt = receipts_by_room[pa["room_id"]] + + returned_pa["read"] = ( + pa["topological_ordering"] > receipt["topological_ordering"] + or ( + pa["topological_ordering"] == receipt["topological_ordering"] + and pa["stream_ordering"] > receipt["stream_ordering"] + ) + ) + returned_push_actions.append(returned_pa) + next_token = pa["stream_ordering"] + + defer.returnValue((200, { + "notifications": returned_push_actions, + "next_token": next_token, + })) + + +def register_servlets(hs, http_server): + NotificationsServlet(hs).register(http_server) -- cgit 1.5.1 From b791a530da1d89e36297cb626950cb42a7ea9226 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 23 May 2016 18:48:02 +0100 Subject: Actually make the 'read' flag correct --- synapse/rest/client/v2_alpha/notifications.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 505e998393..9600256962 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -81,10 +81,9 @@ class NotificationsServlet(RestServlet): receipt = receipts_by_room[pa["room_id"]] returned_pa["read"] = ( - pa["topological_ordering"] > receipt["topological_ordering"] - or ( - pa["topological_ordering"] == receipt["topological_ordering"] - and pa["stream_ordering"] > receipt["stream_ordering"] + receipt["topological_ordering"] >= pa["topological_ordering"] or ( + receipt["topological_ordering"] == pa["topological_ordering"] and + receipt["stream_ordering"] >= pa["stream_ordering"] ) ) returned_push_actions.append(returned_pa) -- cgit 1.5.1 From 37b7e846200f00a36c6084d426ab73ee5d0e0218 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 24 May 2016 11:33:32 +0100 Subject: Include the ts the notif was received at --- synapse/rest/client/v2_alpha/notifications.py | 1 + synapse/storage/event_push_actions.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 9600256962..4d84230e68 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -68,6 +68,7 @@ class NotificationsServlet(RestServlet): "room_id": pa["room_id"], "profile_tag": pa["profile_tag"], "actions": pa["actions"], + "ts": pa["received_ts"], "event": serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index a9cb042b5a..5123072c44 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -201,11 +201,13 @@ class EventPushActionsStore(SQLBaseStore): else: args = [user_id, limit] sql = ( - "SELECT event_id, room_id, stream_ordering, topological_ordering," - " actions, profile_tag" - " FROM event_push_actions" - " WHERE user_id = ? %s" - " ORDER BY stream_ordering DESC" + "SELECT epa.event_id, epa.room_id," + " epa.stream_ordering, epa.topological_ordering," + " epa.actions, epa.profile_tag, e.received_ts" + " FROM event_push_actions epa, events e" + " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id" + " AND epa.user_id = ? %s" + " ORDER BY epa.stream_ordering DESC" " LIMIT ?" % (before_clause,) ) -- cgit 1.5.1 From 0682ca04b3ac0a3e148633d020b3248dbe98f13d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:01:30 +0100 Subject: Fix CAS login Attempting to log in with CAS was giving a 500 error. --- synapse/rest/client/v1/login.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674a..d8c76a3465 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -427,6 +427,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): -- cgit 1.5.1 From 65666fedd5f60ec65fd86d9bbdff40fa67469025 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:17:25 +0100 Subject: Clean up CAS login code Remove some apparently unused code. Clean up parse_cas_response, mostly to catch the exception if the CAS response isn't valid XML. --- synapse/rest/client/v1/login.py | 158 +++++++++------------------------------- 1 file changed, 33 insertions(+), 125 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674a..fef7910c4f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -54,10 +54,6 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url - self.cas_required_attributes = hs.config.cas_required_attributes - self.servername = hs.config.server_name - self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() @@ -110,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) defer.returnValue(result) - # TODO Delete this after all CAS clients switch to token login instead - elif self.cas_enabled and (login_submission["type"] == - LoginRestServlet.CAS_TYPE): - uri = "%s/proxyValidate" % (self.cas_server_url,) - args = { - "ticket": login_submission["ticket"], - "service": login_submission["service"] - } - body = yield self.http_client.get_raw(uri, args) - result = yield self.do_cas_login(body) - defer.returnValue(result) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) defer.returnValue(result) @@ -191,51 +176,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - @defer.inlineCallbacks - def do_cas_login(self, cas_response_body): - user, attributes = self.parse_cas_response(cas_response_body) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.auth_handler - registered_user_id = yield auth_handler.check_user_exists(user_id) - if registered_user_id: - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - registered_user_id - ) - ) - result = { - "user_id": registered_user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_jwt_login(self, login_submission): token = login_submission.get("token", None) @@ -293,33 +233,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) - def _register_device(self, user_id, login_submission): """Register a device for a user. @@ -384,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) -# TODO Delete this after all CAS clients switch to token login instead -class CasRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas", releases=()) - - def __init__(self, hs): - super(CasRestServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - - def on_GET(self, request): - return (200, {"serverUrl": self.cas_server_url}) - - class CasRedirectServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) @@ -479,30 +380,39 @@ class CasTicketServlet(ClientV1RestServlet): return urlparse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + user = None + attributes = None + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = (root[0].tag.endswith("authenticationSuccess")) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + if attributes is None: + raise Exception("CAS response does not contain attributes") + except Exception: + logger.error("Error parsing CAS response", exc_info=1) + raise LoginError(401, "Invalid CAS response", + errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError(401, "Unsuccessful CAS response", + errcode=Codes.UNAUTHORIZED) + return user, attributes def register_servlets(hs, http_server): @@ -512,5 +422,3 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - CasRestServlet(hs).register(http_server) - # TODO PasswordResetRestServlet(hs).register(http_server) -- cgit 1.5.1 From 2510db3e760320b8eeffd9b9a0a0a193ce49f5ba Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Wed, 10 Aug 2016 12:57:30 +0100 Subject: Don't change status_msg on /sync --- synapse/handlers/presence.py | 9 ++++++--- synapse/rest/client/v2_alpha/sync.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6b70fa3817..2293b5fdf7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -672,7 +672,7 @@ class PresenceHandler(object): ]) @defer.inlineCallbacks - def set_state(self, target_user, state): + def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -689,10 +689,13 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) new_fields = { - "state": presence, - "status_msg": status_msg if presence != PresenceState.OFFLINE else None + "state": presence } + if not ignore_status_msg: + msg = status_msg if presence != PresenceState.OFFLINE else None + new_fields["status_msg"] = msg + if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 43d8e0bf39..b11acdbea7 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}) + yield self.presence_handler.set_state(user, {"presence": set_presence}, True) context = yield self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence, -- cgit 1.5.1 From 866a5320de439ab2019251aa8f8697c74aeeef8c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Aug 2016 10:03:19 +0100 Subject: Dont invoke get_handlers fromClientV1RestServlet hs.get_handlers() can not be invoked from split out processes. Moving the invocations down a level means that we can slowly split out individual servlets. --- synapse/rest/client/v1/admin.py | 8 ++++++ synapse/rest/client/v1/base.py | 1 - synapse/rest/client/v1/directory.py | 5 ++++ synapse/rest/client/v1/events.py | 4 +++ synapse/rest/client/v1/initial_sync.py | 4 +++ synapse/rest/client/v1/login.py | 3 +++ synapse/rest/client/v1/profile.py | 12 +++++++++ synapse/rest/client/v1/register.py | 2 ++ synapse/rest/client/v1/room.py | 48 ++++++++++++++++++++++++++++++++++ 9 files changed, 86 insertions(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index b0cb31a448..af21661d7c 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P[^/]*)") + def __init__(self, hs): + super(WhoisRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): "/admin/purge_history/(?P[^/]*)/(?P[^/]*)" ) + def __init__(self, hs): + super(PurgeHistoryRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 96b49b01f2..c2a8447860 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet): hs (synapse.server.HomeServer): """ self.hs = hs - self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_v1auth() self.txns = HttpTransactionStore() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8ac09419dc..09d0831594 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -36,6 +36,10 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): PATTERNS = client_path_patterns("/directory/room/(?P[^/]*)$") + def __init__(self, hs): + super(ClientDirectoryServer, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) @@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet): def __init__(self, hs): super(ClientDirectoryListServer, self).__init__(hs) self.store = hs.get_datastore() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 498bb9e18a..998b115bb9 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 + def __init__(self, hs): + super(EventStreamRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 36c3520567..113a49e539 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns class InitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/initialSync$") + def __init__(self, hs): + super(InitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b31e27f7b3..6c0eec8fb3 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -56,6 +56,7 @@ class LoginRestServlet(ClientV1RestServlet): self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): flows = [] @@ -260,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet): def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) self.sp_config = hs.config.saml2_config_path + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): @@ -329,6 +331,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 65c4e2ebef..355e82474b 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)/displayname") + def __init__(self, hs): + super(ProfileDisplaynameRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)/avatar_url") + def __init__(self, hs): + super(ProfileAvatarURLRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P[^/]*)") + def __init__(self, hs): + super(ProfileRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 2383b9df86..71d58c8e8d 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet): self.sessions = {} self.enable_registration = hs.config.enable_registration self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() self.direct_user_creation_max_duration = hs.config.user_creation_max_duration + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 866a1e9120..89c3895118 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -35,6 +35,10 @@ logger = logging.getLogger(__name__) class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here + def __init__(self, hs): + super(RoomCreateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) @@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomStateEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" @@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomSendEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") @@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): + def __init__(self, hs): + super(JoinRoomAliasServlet, self).__init__(hs) + self.handlers = hs.get_handlers() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/members$") + def __init__(self, hs): + super(RoomMemberListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/messages$") + def __init__(self, hs): + super(RoomMessageListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/state$") + def __init__(self, hs): + super(RoomStateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/initialSync$") + def __init__(self, hs): + super(RoomInitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet): def __init__(self, hs): super(RoomEventContext, self).__init__(hs) self.clock = hs.get_clock() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomForgetRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/forget") register_txn_path(self, PATTERNS, http_server) @@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomMembershipRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P[^/]*)/" @@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomRedactEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") register_txn_path(self, PATTERNS, http_server) @@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet): "/search$" ) + def __init__(self, hs): + super(SearchRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) -- cgit 1.5.1 From 4e1cebd56f9688d49a51929264c095356005f9a3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Aug 2016 15:31:44 +0100 Subject: Make synchrotron accept /events --- synapse/app/synchrotron.py | 36 ++++++++++++++++++++++++++++++++++-- synapse/handlers/__init__.py | 3 --- synapse/handlers/presence.py | 27 +++++++++++++++++++-------- synapse/rest/client/v1/events.py | 9 ++++----- synapse/server.py | 9 +++++++++ 5 files changed, 66 insertions(+), 18 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 48bc97636c..3dca1c37a0 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite from synapse.http.server import JsonResource from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v1 import events from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore @@ -89,17 +90,23 @@ class SynchrotronSlavedStore( get_presence_list_accepted = PresenceStore.__dict__[ "get_presence_list_accepted" ] + get_presence_list_observers_accepted = PresenceStore.__dict__[ + "get_presence_list_observers_accepted" + ] + UPDATE_SYNCING_USERS_MS = 10 * 1000 class SynchrotronPresence(object): def __init__(self, hs): + self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() self.store = hs.get_datastore() self.user_to_num_current_syncs = {} self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.clock = hs.get_clock() + self.notifier = hs.get_notifier() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = { @@ -124,6 +131,8 @@ class SynchrotronPresence(object): pass get_states = PresenceHandler.get_states.__func__ + get_state = PresenceHandler.get_state.__func__ + _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ @defer.inlineCallbacks @@ -194,19 +203,39 @@ class SynchrotronPresence(object): self._need_to_send_sync = False yield self._send_syncing_users_now() + @defer.inlineCallbacks + def notify_from_replication(self, states, stream_id): + parties = yield self._get_interested_parties( + states, calculate_remote_hosts=False + ) + room_ids_to_states, users_to_states, _ = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=users_to_states.keys() + ) + + @defer.inlineCallbacks def process_replication(self, result): stream = result.get("presence", {"rows": []}) + states = [] for row in stream["rows"]: ( position, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) = row - self.user_to_current_state[user_id] = UserPresenceState( + state = UserPresenceState( user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) + self.user_to_current_state[user_id] = state + states.append(state) + + if states and "position" in stream: + stream_id = int(stream["position"]) + yield self.notify_from_replication(states, stream_id) class SynchrotronTyping(object): @@ -266,10 +295,12 @@ class SynchrotronServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) sync.register_servlets(self, resource) + events.register_servlets(self, resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, }) root_resource = create_resource_tree(resources, Resource()) @@ -315,6 +346,7 @@ class SynchrotronServer(HomeServer): def expire_broken_caches(): store.who_forgot_in_room.invalidate_all() store.get_presence_list_accepted.invalidate_all() + store.get_presence_list_observers_accepted.invalidate_all() def notify_from_stream( result, stream_name, stream_key, room=None, user=None @@ -392,7 +424,7 @@ class SynchrotronServer(HomeServer): ) yield store.process_replication(result) typing_handler.process_replication(result) - presence_handler.process_replication(result) + yield presence_handler.process_replication(result) notify(result) except: logger.exception("Error replicating from %r", replication_url) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 1a50a2ec98..63d05f2531 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -19,7 +19,6 @@ from .room import ( ) from .room_member import RoomMemberHandler from .message import MessageHandler -from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler @@ -53,8 +52,6 @@ class Handlers(object): self.message_handler = MessageHandler(hs) self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) - self.event_stream_handler = EventStreamHandler(hs) - self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2293b5fdf7..6a1fe76c88 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -503,7 +503,7 @@ class PresenceHandler(object): defer.returnValue(states) @defer.inlineCallbacks - def _get_interested_parties(self, states): + def _get_interested_parties(self, states, calculate_remote_hosts=True): """Given a list of states return which entities (rooms, users, servers) are interested in the given states. @@ -526,14 +526,15 @@ class PresenceHandler(object): users_to_states.setdefault(state.user_id, []).append(state) hosts_to_states = {} - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - if not local_states: - continue + if calculate_remote_hosts: + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) + if not local_states: + continue - hosts = yield self.store.get_joined_hosts_for_room(room_id) - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) for user_id, states in users_to_states.items(): local_states = filter(lambda s: self.is_mine_id(s.user_id), states) @@ -565,6 +566,16 @@ class PresenceHandler(object): self._push_to_remotes(hosts_to_states) + @defer.inlineCallbacks + def notify_for_states(self, state, stream_id): + parties = yield self._get_interested_parties([state]) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] + ) + def _push_to_remotes(self, hosts_to_states): """Sends state updates to remote servers. diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 998b115bb9..701b6f549b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventStreamRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.event_stream_handler = hs.get_event_stream_handler() @defer.inlineCallbacks def on_GET(self, request): @@ -50,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet): if "room_id" in request.args: room_id = request.args["room_id"][0] - handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if "timeout" in request.args: @@ -61,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args - chunk = yield handler.get_stream( + chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -84,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventRestServlet, self).__init__(hs) self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() @defer.inlineCallbacks def on_GET(self, request, event_id): requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.event_handler - event = yield handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/server.py b/synapse/server.py index 6bb4988309..af3246504b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.room import RoomListHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler +from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier @@ -94,6 +95,8 @@ class HomeServer(object): 'auth_handler', 'device_handler', 'e2e_keys_handler', + 'event_handler', + 'event_stream_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -214,6 +217,12 @@ class HomeServer(object): def build_application_service_handler(self): return ApplicationServicesHandler(self) + def build_event_handler(self): + return EventHandler(self) + + def build_event_stream_handler(self): + return EventStreamHandler(self) + def build_event_sources(self): return EventSources(self) -- cgit 1.5.1 From e3e3fbc23aab45c50ca3c605568de10c8e04a518 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Wed, 17 Aug 2016 12:46:49 +0100 Subject: Initial empty implementation that just registers an API endpoint handler --- synapse/rest/__init__.py | 2 ++ synapse/rest/client/v2_alpha/thirdparty.py | 38 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 synapse/rest/client/v2_alpha/thirdparty.py (limited to 'synapse/rest/client') diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 14227f1cdb..2e0e6babef 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import ( report_event, openid, devices, + thirdparty, ) from synapse.http.server import JsonResource @@ -92,3 +93,4 @@ class ClientRestResource(JsonResource): report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource) + thirdparty.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py new file mode 100644 index 0000000000..9be88b2ba1 --- /dev/null +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from synapse.http.servlet import RestServlet +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class ThirdPartyUserServlet(RestServlet): + PATTERNS = client_v2_patterns("/3pu(/(?P[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyUserServlet, self).__init__() + pass + + def on_GET(self, request, protocol): + return (200, {"TODO":"TODO"}) + + +def register_servlets(hs, http_server): + ThirdPartyUserServlet(hs).register(http_server) -- cgit 1.5.1 From fa87c981e1efabe85a88144479b0dd7131b7da12 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Wed, 17 Aug 2016 13:15:06 +0100 Subject: Thread 3PU lookup through as far as the AS API object; which currently noƶps it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- synapse/appservice/api.py | 3 +++ synapse/handlers/appservice.py | 21 +++++++++++++++++++++ synapse/rest/client/v2_alpha/thirdparty.py | 11 +++++++++-- 3 files changed, 33 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6da6a1b62e..6e5f7dc404 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -71,6 +71,9 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) defer.returnValue(False) + def query_3pu(self, service, protocol, fields): + return False + @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): events = self._serialize(events) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 051ccdb380..69fd766613 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -120,6 +120,21 @@ class ApplicationServicesHandler(object): ) defer.returnValue(result) + @defer.inlineCallbacks + def query_3pu(self, protocol, fields): + services = yield self._get_services_for_3pn(protocol) + + # TODO(paul): scattergather + results = [] + for service in services: + result = yield self.appservice_api.query_3pu( + service, protocol, fields + ) + if result: + results.append(result) + + defer.returnValue(results) + @defer.inlineCallbacks def _get_services_for_event(self, event, restrict_to="", alias_list=None): """Retrieve a list of application services interested in this event. @@ -163,6 +178,12 @@ class ApplicationServicesHandler(object): ] defer.returnValue(interested_list) + @defer.inlineCallbacks + def _get_services_for_3pn(self, protocol): + # TODO(paul): Filter by protocol + services = yield self.store.get_app_services() + defer.returnValue(services) + @defer.inlineCallbacks def _is_unknown_user(self, user_id): if not self.is_mine_id(user_id): diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 9be88b2ba1..0180b73e9f 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -16,6 +16,8 @@ import logging +from twisted.internet import defer + from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -28,10 +30,15 @@ class ThirdPartyUserServlet(RestServlet): def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() - pass + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks def on_GET(self, request, protocol): - return (200, {"TODO":"TODO"}) + fields = {} # TODO + results = yield self.appservice_handler.query_3pu(protocol, fields) + + defer.returnValue((200, results)) def register_servlets(hs, http_server): -- cgit 1.5.1 From 38565827418b71e12b3ff37e1482ff71ffe170d9 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 14:06:02 +0100 Subject: Ensure that 3PU lookup request fields actually get passed in --- synapse/rest/client/v2_alpha/thirdparty.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 0180b73e9f..4b2a93f1bb 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -35,7 +35,11 @@ class ThirdPartyUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - fields = {} # TODO + fields = request.args + del fields["access_token"] + + # TODO(paul): Some type checking on the request args might be nice + # They should probably all be strings results = yield self.appservice_handler.query_3pu(protocol, fields) defer.returnValue((200, results)) -- cgit 1.5.1 From f3afd6ef1a44ef8b87a3f7257a5e42e69c75523e Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 15:53:01 +0100 Subject: Remove TODO note about request fields being strings - they're always strings --- synapse/rest/client/v2_alpha/thirdparty.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 4b2a93f1bb..bce104c545 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -38,8 +38,6 @@ class ThirdPartyUserServlet(RestServlet): fields = request.args del fields["access_token"] - # TODO(paul): Some type checking on the request args might be nice - # They should probably all be strings results = yield self.appservice_handler.query_3pu(protocol, fields) defer.returnValue((200, results)) -- cgit 1.5.1 From 06964c4a0adabf7d983cbd0d2c6d83eba6fcaf79 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 16:09:50 +0100 Subject: Copypasta the 3PU support code to also do 3PL --- synapse/appservice/api.py | 11 ++++++++++ synapse/handlers/appservice.py | 33 +++++++++++++++++++++++++++--- synapse/rest/client/v2_alpha/thirdparty.py | 20 ++++++++++++++++++ 3 files changed, 61 insertions(+), 3 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index e05570cc8b..4ccb5c43c1 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -82,6 +82,17 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_3pu to %s threw exception %s", uri, ex) defer.returnValue([]) + @defer.inlineCallbacks + def query_3pl(self, service, protocol, fields): + uri = service.url + ("/3pl/%s" % urllib.quote(protocol)) + response = None + try: + response = yield self.get_json(uri, fields) + defer.returnValue(response) + except Exception as ex: + logger.warning("query_3pl to %s threw exception %s", uri, ex) + defer.returnValue([]) + @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): events = self._serialize(events) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5ed694e711..72c36615df 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -34,11 +34,11 @@ def log_failure(failure): ) ) -def _is_valid_3pu_result(r): +def _is_valid_3pentity_result(r, field): if not isinstance(r, dict): return False - for k in ("userid", "protocol"): + for k in (field, "protocol"): if k not in r: return False if not isinstance(r[k], str): @@ -185,7 +185,34 @@ class ApplicationServicesHandler(object): if not isinstance(result, list): continue for r in result: - if _is_valid_3pu_result(r): + if _is_valid_3pentity_result(r, field="userid"): + ret.append(r) + else: + logger.warn("Application service returned an " + + "invalid result %r", r) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def query_3pl(self, protocol, fields): + services = yield self._get_services_for_3pn(protocol) + + deferreds = [] + for service in services: + deferreds.append(self.appservice_api.query_3pl( + service, protocol, fields + )) + + results = yield defer.DeferredList(deferreds, consumeErrors=True) + + ret = [] + for (success, result) in results: + if not success: + continue + if not isinstance(result, list): + continue + for r in result: + if _is_valid_3pentity_result(r, field="alias"): ret.append(r) else: logger.warn("Application service returned an " + diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index bce104c545..eec08425e6 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -43,5 +43,25 @@ class ThirdPartyUserServlet(RestServlet): defer.returnValue((200, results)) +class ThirdPartyLocationServlet(RestServlet): + PATTERNS = client_v2_patterns("/3pl(/(?P[^/]+))?$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyLocationServlet, self).__init__() + + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + fields = request.args + del fields["access_token"] + + results = yield self.appservice_handler.query_3pl(protocol, fields) + + defer.returnValue((200, results)) + + def register_servlets(hs, http_server): ThirdPartyUserServlet(hs).register(http_server) + ThirdPartyLocationServlet(hs).register(http_server) -- cgit 1.5.1 From 105ff162d4ae3776674cb1cbec6581e1511871d2 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 16:19:23 +0100 Subject: Authenticate 3PE lookup requests --- synapse/rest/client/v2_alpha/thirdparty.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index eec08425e6..d229e4b818 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -31,10 +31,13 @@ class ThirdPartyUserServlet(RestServlet): def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() + self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @defer.inlineCallbacks def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + fields = request.args del fields["access_token"] @@ -50,10 +53,13 @@ class ThirdPartyLocationServlet(RestServlet): def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() + self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @defer.inlineCallbacks def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + fields = request.args del fields["access_token"] -- cgit 1.5.1 From b515f844ee07c7d6aa1d7e56faa8b65d282e9341 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 18 Aug 2016 17:19:55 +0100 Subject: Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration --- synapse/appservice/api.py | 25 ++++++++++----------- synapse/handlers/appservice.py | 35 ++++++++---------------------- synapse/rest/client/v2_alpha/thirdparty.py | 9 ++++++-- synapse/types.py | 7 ++++++ 4 files changed, 34 insertions(+), 42 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d4cad1b1ed..dd5e762e0d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.types import ThirdPartyEntityKind import logging import urllib @@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(False) @defer.inlineCallbacks - def query_3pu(self, service, protocol, fields): - uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) - response = None - try: - response = yield self.get_json(uri, fields) - defer.returnValue(response) - except Exception as ex: - logger.warning("query_3pu to %s threw exception %s", uri, ex) - defer.returnValue([]) + def query_3pe(self, service, kind, protocol, fields): + if kind == ThirdPartyEntityKind.USER: + uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) + elif kind == ThirdPartyEntityKind.LOCATION: + uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) + else: + raise ValueError( + "Unrecognised 'kind' argument %r to query_3pe()", kind + ) - @defer.inlineCallbacks - def query_3pl(self, service, protocol, fields): - uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) - response = None try: response = yield self.get_json(uri, fields) defer.returnValue(response) except Exception as ex: - logger.warning("query_3pl to %s threw exception %s", uri, ex) + logger.warning("query_3pe to %s threw exception %s", uri, ex) defer.returnValue([]) @defer.inlineCallbacks diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 03452f6bb0..52c127d2c1 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure from synapse.util.logcontext import preserve_fn +from synapse.types import ThirdPartyEntityKind import logging @@ -169,37 +170,19 @@ class ApplicationServicesHandler(object): defer.returnValue(result) @defer.inlineCallbacks - def query_3pu(self, protocol, fields): + def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) results = yield defer.DeferredList([ - self.appservice_api.query_3pu(service, protocol, fields) + self.appservice_api.query_3pe(service, kind, protocol, fields) for service in services ], consumeErrors=True) - ret = [] - for (success, result) in results: - if not success: - continue - if not isinstance(result, list): - continue - for r in result: - if _is_valid_3pentity_result(r, field="userid"): - ret.append(r) - else: - logger.warn("Application service returned an " + - "invalid result %r", r) - - defer.returnValue(ret) - - @defer.inlineCallbacks - def query_3pl(self, protocol, fields): - services = yield self._get_services_for_3pn(protocol) - - results = yield defer.DeferredList([ - self.appservice_api.query_3pl(service, protocol, fields) - for service in services - ], consumeErrors=True) + required_field = ( + "userid" if kind == ThirdPartyEntityKind.USER else + "alias" if kind == ThirdPartyEntityKind.LOCATION else + None + ) ret = [] for (success, result) in results: @@ -208,7 +191,7 @@ class ApplicationServicesHandler(object): if not isinstance(result, list): continue for r in result: - if _is_valid_3pentity_result(r, field="alias"): + if _is_valid_3pentity_result(r, field=required_field): ret.append(r) else: logger.warn("Application service returned an " + diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index d229e4b818..9abca3a8ad 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -19,6 +19,7 @@ import logging from twisted.internet import defer from synapse.http.servlet import RestServlet +from synapse.types import ThirdPartyEntityKind from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet): fields = request.args del fields["access_token"] - results = yield self.appservice_handler.query_3pu(protocol, fields) + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.USER, protocol, fields + ) defer.returnValue((200, results)) @@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet): fields = request.args del fields["access_token"] - results = yield self.appservice_handler.query_3pl(protocol, fields) + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.LOCATION, protocol, fields + ) defer.returnValue((200, results)) diff --git a/synapse/types.py b/synapse/types.py index 5349b0c450..fd17ecbbe0 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) + + +# Some arbitrary constants used for internal API enumerations. Don't rely on +# exact values; always pass or compare symbolically +class ThirdPartyEntityKind(object): + USER = 'user' + LOCATION = 'location' -- cgit 1.5.1 From 0acdd0f1eafa962394fd2d1ca950186edf853653 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 18 Aug 2016 17:51:08 +0100 Subject: Use tuple comparison Hopefully easier to read --- synapse/rest/client/v2_alpha/notifications.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 4d84230e68..f1a48acf07 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -82,10 +82,9 @@ class NotificationsServlet(RestServlet): receipt = receipts_by_room[pa["room_id"]] returned_pa["read"] = ( - receipt["topological_ordering"] >= pa["topological_ordering"] or ( - receipt["topological_ordering"] == pa["topological_ordering"] and - receipt["stream_ordering"] >= pa["stream_ordering"] - ) + receipt["topological_ordering"], receipt["stream_ordering"] + ) >= ( + pa["topological_ordering"], pa["stream_ordering"] ) returned_push_actions.append(returned_pa) next_token = pa["stream_ordering"] -- cgit 1.5.1 From 4b31426a02d32b26b97bd04328426df1666f756d Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 23 Aug 2016 16:32:04 +0100 Subject: Pass through user-supplied content in /join/$room_id It was always intended to allow custom keys on the join event, but this has at some point been lost. Restore it. If the user specifies keys like "avatar_url" then they will be clobbered. --- synapse/handlers/room_member.py | 14 ++++++++++++-- synapse/rest/client/v1/room.py | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4709112a0c..8b17632fdc 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -59,10 +59,13 @@ class RoomMemberHandler(BaseHandler): prev_event_ids, txn_id=None, ratelimit=True, + content=None, ): + if content is None: + content = {} msg_handler = self.hs.get_handlers().message_handler - content = {"membership": membership} + content["membership"] = membership if requester.is_guest: content["kind"] = "guest" @@ -140,6 +143,7 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=None, third_party_signed=None, ratelimit=True, + content=None, ): key = (room_id,) @@ -153,6 +157,7 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=remote_room_hosts, third_party_signed=third_party_signed, ratelimit=ratelimit, + content=content, ) defer.returnValue(result) @@ -168,7 +173,11 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=None, third_party_signed=None, ratelimit=True, + content=None, ): + if content is None: + content = {} + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -218,7 +227,7 @@ class RoomMemberHandler(BaseHandler): if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) - content = {"membership": Membership.JOIN} + content["membership"] = Membership.JOIN profile = self.hs.get_handlers().profile_handler content["displayname"] = yield profile.get_displayname(target) @@ -272,6 +281,7 @@ class RoomMemberHandler(BaseHandler): txn_id=txn_id, ratelimit=ratelimit, prev_event_ids=latest_event_ids, + content=content, ) @defer.inlineCallbacks diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 89c3895118..0d81757010 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -268,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): action="join", txn_id=txn_id, remote_room_hosts=remote_room_hosts, + content=content, third_party_signed=content.get("third_party_signed", None), ) -- cgit 1.5.1 From 921913935176f5bd3df5e5b960d87c94a2adb304 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Aug 2016 15:23:39 +0100 Subject: Preserve some logcontexts --- synapse/app/homeserver.py | 3 ++- synapse/appservice/scheduler.py | 6 +++--- synapse/crypto/keyring.py | 36 ++++++++++++++++---------------- synapse/federation/federation_base.py | 7 ++++--- synapse/federation/federation_client.py | 17 +++++++++------ synapse/handlers/appservice.py | 8 +++---- synapse/handlers/federation.py | 28 +++++++++++++------------ synapse/handlers/message.py | 35 ++++++++++++++++++------------- synapse/handlers/typing.py | 12 +++++++---- synapse/notifier.py | 7 ++++++- synapse/push/push_tools.py | 9 ++++---- synapse/push/pusherpool.py | 12 ++++++----- synapse/rest/client/v2_alpha/register.py | 3 +-- synapse/storage/events.py | 16 ++++++++------ synapse/storage/stream.py | 6 +++--- synapse/util/async.py | 9 ++++---- synapse/util/logcontext.py | 15 +++++++++---- synapse/visibility.py | 6 +++--- 18 files changed, 136 insertions(+), 99 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 54f35900f8..4493d3b847 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -50,7 +50,7 @@ from synapse.api.urls import ( ) from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, logcontext_tracer from synapse.metrics import register_memory_metrics, get_metrics_for from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX @@ -449,6 +449,7 @@ def run(hs): # Uncomment to enable tracing of log context changes. # sys.settrace(logcontext_tracer) with LoggingContext("run"): + sys.settrace(logcontext_tracer) change_resource_limit(hs.config.soft_file_limit) if hs.config.gc_thresholds: gc.set_threshold(*hs.config.gc_thresholds) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 6450a12890..68a9de17b8 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -150,12 +150,12 @@ class _TransactionController(object): if service_is_up: sent = yield txn.send(self.as_api) if sent: - txn.complete(self.store) + yield txn.complete(self.store) else: - self._start_recoverer(service) + preserve_fn(self._start_recoverer)(service) except Exception as e: logger.exception(e) - self._start_recoverer(service) + preserve_fn(self._start_recoverer)(service) @defer.inlineCallbacks def on_recovered(self, recoverer): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 1735ca9345..d7211ee9b3 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -308,15 +308,15 @@ class Keyring(object): @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): - res = yield defer.gatherResults( + res = yield preserve_context_over_deferred(defer.gatherResults( [ - self.store.get_server_verify_keys( + preserve_fn(self.store.get_server_verify_keys)( server_name, key_ids ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(dict(res)) @@ -337,13 +337,13 @@ class Keyring(object): ) defer.returnValue({}) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - get_key(p_name, p_keys) + preserve_fn(get_key)(p_name, p_keys) for p_name, p_keys in self.perspective_servers.items() ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) union_of_keys = {} for result in results: @@ -383,13 +383,13 @@ class Keyring(object): defer.returnValue(keys) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - get_key(server_name, key_ids) + preserve_fn(get_key)(server_name, key_ids) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) merged = {} for result in results: @@ -466,9 +466,9 @@ class Keyring(object): for server_name, response_keys in processed_response.items(): keys.setdefault(server_name, {}).update(response_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ - self.store_keys( + preserve_fn(self.store_keys)( server_name=server_name, from_server=perspective_name, verify_keys=response_keys, @@ -476,7 +476,7 @@ class Keyring(object): for server_name, response_keys in keys.items() ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(keys) @@ -524,7 +524,7 @@ class Keyring(object): keys.update(response_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store_keys)( server_name=key_server_name, @@ -534,7 +534,7 @@ class Keyring(object): for key_server_name, verify_keys in keys.items() ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(keys) @@ -600,7 +600,7 @@ class Keyring(object): response_keys.update(verify_keys) response_keys.update(old_verify_keys) - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store.store_server_keys_json)( server_name=server_name, @@ -613,7 +613,7 @@ class Keyring(object): for key_id in updated_key_ids ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) results[server_name] = response_keys @@ -702,7 +702,7 @@ class Keyring(object): A deferred that completes when the keys are stored. """ # TODO(markjh): Store whether the keys have expired. - yield defer.gatherResults( + yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store.store_server_verify_key)( server_name, server_name, key.time_added, key @@ -710,4 +710,4 @@ class Keyring(object): for key_id, key in verify_keys.items() ], consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index da2f5e8cfd..2339cc9034 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash from synapse.api.errors import SynapseError from synapse.util import unwrapFirstError +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -102,10 +103,10 @@ class FederationBase(object): warn, pdu ) - valid_pdus = yield defer.gatherResults( + valid_pdus = yield preserve_context_over_deferred(defer.gatherResults( deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) if include_none: defer.returnValue(valid_pdus) @@ -129,7 +130,7 @@ class FederationBase(object): for pdu in pdus ] - deferreds = self.keyring.verify_json_objects_for_server([ + deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ (p.origin, p.get_pdu_json()) for p in redacted_pdus ]) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 9ba3151713..f2b3aceb49 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.events import FrozenEvent import synapse.metrics @@ -225,10 +226,10 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield defer.gatherResults( + pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( self._check_sigs_and_hashes(pdus), consumeErrors=True, - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) defer.returnValue(pdus) @@ -457,14 +458,16 @@ class FederationClient(FederationBase): batch = set(missing_events[i:i + batch_size]) deferreds = [ - self.get_pdu( + preserve_fn(self.get_pdu)( destinations=random_server_list(), event_id=e_id, ) for e_id in batch ] - res = yield defer.DeferredList(deferreds, consumeErrors=True) + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) for success, result in res: if success: signed_events.append(result) @@ -853,14 +856,16 @@ class FederationClient(FederationBase): return srvs deferreds = [ - self.get_pdu( + preserve_fn(self.get_pdu)( destinations=random_server_list(), event_id=e_id, ) for e_id, depth in ordered_missing[:limit - len(signed_events)] ] - res = yield defer.DeferredList(deferreds, consumeErrors=True) + res = yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) for (result, val), (e_id, _) in zip(res, ordered_missing): if result and val: signed_events.append(val) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index dd285452cd..306686a384 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -163,10 +163,10 @@ class ApplicationServicesHandler(object): def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield defer.DeferredList([ - self.appservice_api.query_3pe(service, kind, protocol, fields) + results = yield preserve_context_over_deferred(defer.DeferredList([ + preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) for service in services - ], consumeErrors=True) + ], consumeErrors=True)) ret = [] for (success, result) in results: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 328f8f4842..fe3092b14b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -26,7 +26,9 @@ from synapse.api.errors import ( from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred +) from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -361,9 +363,9 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch ) - results = yield defer.gatherResults( + results = yield preserve_context_over_deferred(defer.gatherResults( [ - self.replication_layer.get_pdu( + preserve_fn(self.replication_layer.get_pdu)( [dest], event_id, outlier=True, @@ -372,7 +374,7 @@ class FederationHandler(BaseHandler): for event_id in missing_auth - failed_to_fetch ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results}) required_auth.update( a_id for event in results for a_id, _ in event.auth_events @@ -552,10 +554,10 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) - states = yield defer.gatherResults([ - self.state_handler.resolve_state_groups(room_id, [e]) + states = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids - ]) + ])) states = dict(zip(event_ids, [s[1] for s in states])) for e_id, _ in sorted_extremeties_tuple: @@ -1166,9 +1168,9 @@ class FederationHandler(BaseHandler): a bunch of outliers, but not a chunk of individual events that depend on each other for state calculations. """ - contexts = yield defer.gatherResults( + contexts = yield preserve_context_over_deferred(defer.gatherResults( [ - self._prep_event( + preserve_fn(self._prep_event)( origin, ev_info["event"], state=ev_info.get("state"), @@ -1176,7 +1178,7 @@ class FederationHandler(BaseHandler): ) for ev_info in event_infos ] - ) + )) yield self.store.persist_events( [ @@ -1460,9 +1462,9 @@ class FederationHandler(BaseHandler): # Do auth conflict res. logger.info("Different auth: %s", different_auth) - different_events = yield defer.gatherResults( + different_events = yield preserve_context_over_deferred(defer.gatherResults( [ - self.store.get_event( + preserve_fn(self.store.get_event)( d, allow_none=True, allow_rejected=False, @@ -1471,7 +1473,7 @@ class FederationHandler(BaseHandler): if d in have_events and not have_events[d] ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index dc76d34a52..4c3cd9d12e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -28,7 +28,8 @@ from synapse.types import ( from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -502,15 +503,17 @@ class MessageHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield defer.gatherResults( - [ - self.store.get_recent_events_for_room( - event.room_id, - limit=limit, - end_token=room_end_token, - ), - deferred_room_state, - ] + (messages, token), current_state = yield preserve_context_over_deferred( + defer.gatherResults( + [ + preserve_fn(self.store.get_recent_events_for_room)( + event.room_id, + limit=limit, + end_token=room_end_token, + ), + deferred_room_state, + ] + ) ).addErrback(unwrapFirstError) messages = yield filter_events_for_client( @@ -719,9 +722,9 @@ class MessageHandler(BaseHandler): presence, receipts, (messages, token) = yield defer.gatherResults( [ - get_presence(), - get_receipts(), - self.store.get_recent_events_for_room( + preserve_fn(get_presence)(), + preserve_fn(get_receipts)(), + preserve_fn(self.store.get_recent_events_for_room)( room_id, limit=limit, end_token=now_token.room_key, @@ -755,6 +758,7 @@ class MessageHandler(BaseHandler): defer.returnValue(ret) + @measure_func("_create_new_client_event") @defer.inlineCallbacks def _create_new_client_event(self, builder, prev_event_ids=None): if prev_event_ids: @@ -806,6 +810,7 @@ class MessageHandler(BaseHandler): (event, context,) ) + @measure_func("handle_new_client_event") @defer.inlineCallbacks def handle_new_client_event( self, @@ -934,7 +939,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _notify(): yield run_on_reactor() - self.notifier.on_new_room_event( + yield self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) @@ -944,6 +949,6 @@ class MessageHandler(BaseHandler): # If invite, remove room_state from unsigned before sending. event.unsigned.pop("invite_room_state", None) - federation_handler.handle_new_event( + preserve_fn(federation_handler.handle_new_event)( event, destinations=destinations, ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 5589296c09..46181984c0 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -16,7 +16,9 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, +) from synapse.util.metrics import Measure from synapse.types import UserID @@ -169,13 +171,13 @@ class TypingHandler(object): deferreds = [] for domain in domains: if domain == self.server_name: - self._push_update_local( + preserve_fn(self._push_update_local)( room_id=room_id, user_id=user_id, typing=typing ) else: - deferreds.append(self.federation.send_edu( + deferreds.append(preserve_fn(self.federation.send_edu)( destination=domain, edu_type="m.typing", content={ @@ -185,7 +187,9 @@ class TypingHandler(object): }, )) - yield defer.DeferredList(deferreds, consumeErrors=True) + yield preserve_context_over_deferred( + defer.DeferredList(deferreds, consumeErrors=True) + ) @defer.inlineCallbacks def _recv_edu(self, origin, content): diff --git a/synapse/notifier.py b/synapse/notifier.py index c48024096d..b86648f5e4 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -19,7 +19,7 @@ from synapse.api.errors import AuthError from synapse.util.logutils import log_function from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.metrics import Measure from synapse.types import StreamToken from synapse.visibility import filter_events_for_client @@ -174,6 +174,7 @@ class Notifier(object): lambda: len(self.user_to_user_stream), ) + @preserve_fn def on_new_room_event(self, event, room_stream_id, max_room_stream_id, extra_users=[]): """ Used by handlers to inform the notifier something has happened @@ -195,6 +196,7 @@ class Notifier(object): self.notify_replication() + @preserve_fn def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. @@ -212,6 +214,7 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) + @preserve_fn def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. @@ -226,6 +229,7 @@ class Notifier(object): rooms=[event.room_id], ) + @preserve_fn def on_new_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend event wise. @@ -252,6 +256,7 @@ class Notifier(object): self.notify_replication() + @preserve_fn def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d555a33e9a..becb8ef1ae 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,14 +17,15 @@ from twisted.internet import defer from synapse.util.presentable_names import ( calculate_room_name, name_from_member_event ) +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @defer.inlineCallbacks def get_badge_count(store, user_id): - invites, joins = yield defer.gatherResults([ - store.get_invited_rooms_for_user(user_id), - store.get_rooms_for_user(user_id), - ], consumeErrors=True) + invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(store.get_invited_rooms_for_user)(user_id), + preserve_fn(store.get_rooms_for_user)(user_id), + ], consumeErrors=True)) my_receipts_by_room = yield store.get_receipts_for_user( user_id, "m.read", diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 54c0f1b849..3837be523d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -17,7 +17,7 @@ from twisted.internet import defer import pusher -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.async import run_on_reactor import logging @@ -130,10 +130,12 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - p.on_new_notifications(min_stream_id, max_stream_id) + preserve_fn(p.on_new_notifications)( + min_stream_id, max_stream_id + ) ) - yield defer.gatherResults(deferreds) + yield preserve_context_over_deferred(defer.gatherResults(deferreds)) except: logger.exception("Exception in pusher on_new_notifications") @@ -155,10 +157,10 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - p.on_new_receipts(min_stream_id, max_stream_id) + preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) ) - yield defer.gatherResults(deferreds) + yield preserve_context_over_deferred(defer.gatherResults(deferreds)) except: logger.exception("Exception in pusher on_new_receipts") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 943f5676a3..2121bd75ea 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet): # register the user's device device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id = self.device_handler.check_device_registered( + return self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) - return device_id @defer.inlineCallbacks def _do_guest_registration(self): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 97aef25321..57e5005285 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -20,7 +20,9 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS from synapse.events.utils import prune_event from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import preserve_fn, PreserveLoggingContext +from synapse.util.logcontext import ( + preserve_fn, PreserveLoggingContext, preserve_context_over_deferred +) from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes @@ -202,7 +204,7 @@ class EventsStore(SQLBaseStore): deferreds = [] for room_id, evs_ctxs in partitioned.items(): - d = self._event_persist_queue.add_to_queue( + d = preserve_fn(self._event_persist_queue.add_to_queue)( room_id, evs_ctxs, backfilled=backfilled, current_state=None, @@ -212,7 +214,9 @@ class EventsStore(SQLBaseStore): for room_id in partitioned.keys(): self._maybe_start_persisting(room_id) - return defer.gatherResults(deferreds, consumeErrors=True) + return preserve_context_over_deferred( + defer.gatherResults(deferreds, consumeErrors=True) + ) @defer.inlineCallbacks @log_function @@ -225,7 +229,7 @@ class EventsStore(SQLBaseStore): self._maybe_start_persisting(event.room_id) - yield deferred + yield preserve_context_over_deferred(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) @@ -1088,7 +1092,7 @@ class EventsStore(SQLBaseStore): if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] - res = yield defer.gatherResults( + res = yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self._get_event_from_row)( row["internal_metadata"], row["json"], row["redacts"], @@ -1097,7 +1101,7 @@ class EventsStore(SQLBaseStore): for row in rows ], consumeErrors=True - ) + )) defer.returnValue({ e.event.event_id: e diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 862c5c3ea1..0577a0525b 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -39,7 +39,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.storage.engines import PostgresEngine, Sqlite3Engine import logging @@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): - res = yield defer.gatherResults([ + res = yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(self.get_room_events_stream_for_room)( room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids - ]) + ])) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) diff --git a/synapse/util/async.py b/synapse/util/async.py index c84b23ff46..347fb1e380 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit): except StopIteration: pass - return defer.gatherResults([ + return preserve_context_over_deferred(defer.gatherResults([ preserve_fn(_concurrently_execute_inner)() for _ in xrange(limit) - ], consumeErrors=True).addErrback(unwrapFirstError) + ], consumeErrors=True)).addErrback(unwrapFirstError) class Linearizer(object): @@ -181,7 +181,8 @@ class Linearizer(object): self.key_to_defer[key] = new_defer if current_defer: - yield preserve_context_over_deferred(current_defer) + with PreserveLoggingContext(): + yield current_defer @contextmanager def _ctx_manager(): @@ -264,7 +265,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield defer.gatherResults(to_wait_on) + yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 7a87045f87..6c83eb213d 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs): return res -def preserve_context_over_deferred(deferred): +def preserve_context_over_deferred(deferred, context=None): """Given a deferred wrap it such that any callbacks added later to it will be invoked with the current context. """ - current_context = LoggingContext.current_context() - d = _PreservingContextDeferred(current_context) + if context is None: + context = LoggingContext.current_context() + d = _PreservingContextDeferred(context) deferred.chainDeferred(d) return d @@ -316,7 +317,13 @@ def preserve_fn(f): def g(*args, **kwargs): with PreserveLoggingContext(current): - return f(*args, **kwargs) + res = f(*args, **kwargs) + if isinstance(res, defer.Deferred): + return preserve_context_over_deferred( + res, context=LoggingContext.sentinel + ) + else: + return res return g diff --git a/synapse/visibility.py b/synapse/visibility.py index 948ad51772..cc12c0a23d 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.constants import Membership, EventTypes -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): given events events ([synapse.events.EventBase]): list of events to filter """ - forgotten = yield defer.gatherResults([ + forgotten = yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(store.who_forgot_in_room)( room_id, ) for room_id in frozenset(e.room_id for e in events) - ], consumeErrors=True) + ], consumeErrors=True)) # Set of membership event_ids that have been forgotten event_id_forgotten = frozenset( -- cgit 1.5.1