diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/rest/client/v1/admin.py | 8 | ||||
-rw-r--r-- | synapse/rest/client/v1/base.py | 1 | ||||
-rw-r--r-- | synapse/rest/client/v1/directory.py | 5 | ||||
-rw-r--r-- | synapse/rest/client/v1/events.py | 11 | ||||
-rw-r--r-- | synapse/rest/client/v1/initial_sync.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 162 | ||||
-rw-r--r-- | synapse/rest/client/v1/profile.py | 12 | ||||
-rw-r--r-- | synapse/rest/client/v1/register.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 49 |
9 files changed, 124 insertions, 130 deletions
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index b0cb31a448..af21661d7c 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(WhoisRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" ) + def __init__(self, hs): + super(PurgeHistoryRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 96b49b01f2..c2a8447860 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet): hs (synapse.server.HomeServer): """ self.hs = hs - self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_v1auth() self.txns = HttpTransactionStore() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8ac09419dc..09d0831594 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -36,6 +36,10 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") + def __init__(self, hs): + super(ClientDirectoryServer, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) @@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet): def __init__(self, hs): super(ClientDirectoryListServer, self).__init__(hs) self.store = hs.get_datastore() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 498bb9e18a..701b6f549b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 + def __init__(self, hs): + super(EventStreamRestServlet, self).__init__(hs) + self.event_stream_handler = hs.get_event_stream_handler() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( @@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet): if "room_id" in request.args: room_id = request.args["room_id"][0] - handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if "timeout" in request.args: @@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args - chunk = yield handler.get_stream( + chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventRestServlet, self).__init__(hs) self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() @defer.inlineCallbacks def on_GET(self, request, event_id): requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.event_handler - event = yield handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 36c3520567..113a49e539 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns class InitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/initialSync$") + def __init__(self, hs): + super(InitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674a..6c0eec8fb3 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url - self.cas_required_attributes = hs.config.cas_required_attributes - self.servername = hs.config.server_name - self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): flows = [] @@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) defer.returnValue(result) - # TODO Delete this after all CAS clients switch to token login instead - elif self.cas_enabled and (login_submission["type"] == - LoginRestServlet.CAS_TYPE): - uri = "%s/proxyValidate" % (self.cas_server_url,) - args = { - "ticket": login_submission["ticket"], - "service": login_submission["service"] - } - body = yield self.http_client.get_raw(uri, args) - result = yield self.do_cas_login(body) - defer.returnValue(result) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) defer.returnValue(result) @@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - @defer.inlineCallbacks - def do_cas_login(self, cas_response_body): - user, attributes = self.parse_cas_response(cas_response_body) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.auth_handler - registered_user_id = yield auth_handler.check_user_exists(user_id) - if registered_user_id: - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - registered_user_id - ) - ) - result = { - "user_id": registered_user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_jwt_login(self, login_submission): token = login_submission.get("token", None) @@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) - def _register_device(self, user_id, login_submission): """Register a device for a user. @@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet): def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) self.sp_config = hs.config.saml2_config_path + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): @@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) -# TODO Delete this after all CAS clients switch to token login instead -class CasRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas", releases=()) - - def __init__(self, hs): - super(CasRestServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - - def on_GET(self, request): - return (200, {"serverUrl": self.cas_server_url}) - - class CasRedirectServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) @@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes + self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request): @@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet): return urlparse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + user = None + attributes = None + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = (root[0].tag.endswith("authenticationSuccess")) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + if attributes is None: + raise Exception("CAS response does not contain attributes") + except Exception: + logger.error("Error parsing CAS response", exc_info=1) + raise LoginError(401, "Invalid CAS response", + errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError(401, "Unsuccessful CAS response", + errcode=Codes.UNAUTHORIZED) + return user, attributes def register_servlets(hs, http_server): @@ -512,5 +426,3 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - CasRestServlet(hs).register(http_server) - # TODO PasswordResetRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 65c4e2ebef..355e82474b 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") + def __init__(self, hs): + super(ProfileDisplaynameRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") + def __init__(self, hs): + super(ProfileAvatarURLRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(ProfileRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 2383b9df86..71d58c8e8d 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet): self.sessions = {} self.enable_registration = hs.config.enable_registration self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() self.direct_user_creation_max_duration = hs.config.user_creation_max_duration + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 866a1e9120..0d81757010 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -35,6 +35,10 @@ logger = logging.getLogger(__name__) class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here + def __init__(self, hs): + super(RoomCreateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) @@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomStateEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" @@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomSendEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") @@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): + def __init__(self, hs): + super(JoinRoomAliasServlet, self).__init__(hs) + self.handlers = hs.get_handlers() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -253,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): action="join", txn_id=txn_id, remote_room_hosts=remote_room_hosts, + content=content, third_party_signed=content.get("third_party_signed", None), ) @@ -296,6 +312,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") + def __init__(self, hs): + super(RoomMemberListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -322,6 +342,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") + def __init__(self, hs): + super(RoomMessageListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -351,6 +375,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") + def __init__(self, hs): + super(RoomStateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -368,6 +396,10 @@ class RoomStateRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") + def __init__(self, hs): + super(RoomInitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -388,6 +420,7 @@ class RoomEventContext(ClientV1RestServlet): def __init__(self, hs): super(RoomEventContext, self).__init__(hs) self.clock = hs.get_clock() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -424,6 +457,10 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomForgetRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") register_txn_path(self, PATTERNS, http_server) @@ -462,6 +499,10 @@ class RoomForgetRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomMembershipRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" @@ -542,6 +583,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomRedactEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") register_txn_path(self, PATTERNS, http_server) @@ -624,6 +669,10 @@ class SearchRestServlet(ClientV1RestServlet): "/search$" ) + def __init__(self, hs): + super(SearchRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) |