diff options
author | Mark Haines <mark.haines@matrix.org> | 2015-10-13 10:33:00 +0100 |
---|---|---|
committer | Mark Haines <mark.haines@matrix.org> | 2015-10-13 10:33:00 +0100 |
commit | f96b48067065b12654f60a7d8849eebeb5338a73 (patch) | |
tree | 1b9edd18e4e8c4f2218756b501292fde1fcb252e | |
parent | Start spliting out the rooms into joined and invited in v2 sync (diff) | |
parent | Merge pull request #299 from stevenhammerton/sh-cas-required-attribute (diff) | |
download | synapse-f96b48067065b12654f60a7d8849eebeb5338a73.tar.xz |
Merge branch 'develop' into markjh/v2_sync_api
-rw-r--r-- | AUTHORS.rst | 5 | ||||
-rw-r--r-- | synapse/api/errors.py | 1 | ||||
-rw-r--r-- | synapse/config/cas.py | 43 | ||||
-rw-r--r-- | synapse/config/homeserver.py | 3 | ||||
-rw-r--r-- | synapse/config/server.py | 1 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 32 | ||||
-rw-r--r-- | synapse/handlers/message.py | 13 | ||||
-rw-r--r-- | synapse/http/client.py | 66 | ||||
-rw-r--r-- | synapse/push/__init__.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v1/initial_sync.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 106 |
11 files changed, 242 insertions, 34 deletions
diff --git a/AUTHORS.rst b/AUTHORS.rst index 54ced67000..58a67c6b12 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -44,4 +44,7 @@ Eric Myhre <hash at exultant.us> repository API. Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com> - * Add SAML2 support for registration and logins. + * Add SAML2 support for registration and login. + +Steven Hammerton <steven.hammerton at openmarket.com> + * Add CAS support for registration and login. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index ee3045268f..d1356eb4d9 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -47,7 +47,6 @@ class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes.""" def __init__(self, code, msg): - logger.info("%s: %s, %s", type(self).__name__, code, msg) super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) self.code = code self.msg = msg diff --git a/synapse/config/cas.py b/synapse/config/cas.py new file mode 100644 index 0000000000..d268680729 --- /dev/null +++ b/synapse/config/cas.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class CasConfig(Config): + """Cas Configuration + + cas_server_url: URL of CAS server + """ + + def read_config(self, config): + cas_config = config.get("cas_config", None) + if cas_config: + self.cas_enabled = True + self.cas_server_url = cas_config["server_url"] + self.cas_required_attributes = cas_config.get("required_attributes", {}) + else: + self.cas_enabled = False + self.cas_server_url = None + self.cas_required_attributes = {} + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # Enable CAS for registration and login. + #cas_config: + # server_url: "https://cas-server.com" + # #required_attributes: + # # name: value + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index d77f045406..3039f3c0bf 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -26,12 +26,13 @@ from .metrics import MetricsConfig from .appservice import AppServiceConfig from .key import KeyConfig from .saml2 import SAML2Config +from .cas import CasConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, - AppServiceConfig, KeyConfig, SAML2Config, ): + AppServiceConfig, KeyConfig, SAML2Config, CasConfig): pass diff --git a/synapse/config/server.py b/synapse/config/server.py index 4d12d49857..5c2d6bfeab 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,6 +26,7 @@ class ServerConfig(Config): self.soft_file_limit = config["soft_file_limit"] self.daemonize = config.get("daemonize") self.print_pidfile = config.get("print_pidfile") + self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.listeners = config.get("listeners", []) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 793b3fcd8b..484f719253 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -296,6 +296,38 @@ class AuthHandler(BaseHandler): defer.returnValue((user_id, access_token, refresh_token)) @defer.inlineCallbacks + def login_with_cas_user_id(self, user_id): + """ + Authenticates the user with the given user ID, + intended to have been captured from a CAS response + + Args: + user_id (str): User ID + Returns: + A tuple of: + The user's ID. + The access token for the user's session. + The refresh token for the user's session. + Raises: + StoreError if there was a problem storing the token. + LoginError if there was an authentication problem. + """ + user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) + + logger.info("Logging in user %s", user_id) + access_token = yield self.issue_access_token(user_id) + refresh_token = yield self.issue_refresh_token(user_id) + defer.returnValue((user_id, access_token, refresh_token)) + + @defer.inlineCallbacks + def does_user_exist(self, user_id): + try: + yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(True) + except LoginError: + defer.returnValue(False) + + @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): """Checks to see if a user with the given id exists. Will check case insensitively, but will throw if there are multiple inexact matches. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 30949ff7a6..b70258697b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -324,7 +324,8 @@ class MessageHandler(BaseHandler): ) @defer.inlineCallbacks - def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True): + def snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -335,17 +336,19 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config used to determine how many messages *PER ROOM* to return. as_client_event (bool): True to get events in client-server format. + include_archived (bool): True to get rooms that the user has left Returns: A list of dicts with "room_id" and "membership" keys for all rooms the user is currently invited or joined in on. Rooms where the user is joined on, may return a "messages" key with messages, depending on the specified PaginationConfig. """ + memberships = [Membership.INVITE, Membership.JOIN] + if include_archived: + memberships.append(Membership.LEAVE) + room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, - membership_list=[ - Membership.INVITE, Membership.JOIN, Membership.LEAVE - ] + user_id=user_id, membership_list=memberships ) user = UserID.from_string(user_id) diff --git a/synapse/http/client.py b/synapse/http/client.py index 0933388c04..9a5869abee 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -67,7 +67,9 @@ class SimpleHttpClient(object): connectTimeout=15, contextFactory=hs.get_http_client_context_factory() ) - self.version_string = hs.version_string + self.user_agent = hs.version_string + if hs.config.user_agent_suffix: + self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) def request(self, method, uri, *args, **kwargs): # A small wrapper around self.agent.request() so we can easily attach @@ -112,7 +114,7 @@ class SimpleHttpClient(object): uri.encode("ascii"), headers=Headers({ b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.version_string], + b"User-Agent": [self.user_agent], }), bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) @@ -131,7 +133,8 @@ class SimpleHttpClient(object): "POST", uri.encode("ascii"), headers=Headers({ - "Content-Type": ["application/json"] + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], }), bodyProducer=FileBodyProducer(StringIO(json_str)) ) @@ -157,16 +160,40 @@ class SimpleHttpClient(object): On a non-2xx HTTP response. The response body will be used as the error message. """ + body = yield self.get_raw(uri, args) + defer.returnValue(json.loads(body)) + + @defer.inlineCallbacks + def put_json(self, uri, json_body, args={}): + """ Puts some json to the given URI. + + Args: + uri (str): The URI to request, not including query parameters + json_body (dict): The JSON to put in the HTTP body, + args (dict): A dictionary used to create query strings, defaults to + None. + **Note**: The value of each key is assumed to be an iterable + and *not* a string. + Returns: + Deferred: Succeeds when we get *any* 2xx HTTP response, with the + HTTP body as JSON. + Raises: + On a non-2xx HTTP response. + """ if len(args): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) + json_str = encode_canonical_json(json_body) + response = yield self.request( - "GET", + "PUT", uri.encode("ascii"), headers=Headers({ - b"User-Agent": [self.version_string], - }) + b"User-Agent": [self.user_agent], + "Content-Type": ["application/json"] + }), + bodyProducer=FileBodyProducer(StringIO(json_str)) ) body = yield preserve_context_over_fn(readBody, response) @@ -180,46 +207,39 @@ class SimpleHttpClient(object): raise CodeMessageException(response.code, body) @defer.inlineCallbacks - def put_json(self, uri, json_body, args={}): - """ Puts some json to the given URI. + def get_raw(self, uri, args={}): + """ Gets raw text from the given URI. Args: uri (str): The URI to request, not including query parameters - json_body (dict): The JSON to put in the HTTP body, args (dict): A dictionary used to create query strings, defaults to None. **Note**: The value of each key is assumed to be an iterable and *not* a string. Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the - HTTP body as JSON. + HTTP body at text. Raises: - On a non-2xx HTTP response. + On a non-2xx HTTP response. The response body will be used as the + error message. """ if len(args): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) - json_str = encode_canonical_json(json_body) - response = yield self.request( - "PUT", + "GET", uri.encode("ascii"), headers=Headers({ - b"User-Agent": [self.version_string], - "Content-Type": ["application/json"] - }), - bodyProducer=FileBodyProducer(StringIO(json_str)) + b"User-Agent": [self.user_agent], + }) ) body = yield preserve_context_over_fn(readBody, response) if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) + defer.returnValue(body) else: - # NB: This is explicitly not json.loads(body)'d because the contract - # of CodeMessageException is a *string* message. Callers can always - # load it into JSON if they want. raise CodeMessageException(response.code, body) @@ -241,7 +261,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): bodyProducer=FileBodyProducer(StringIO(query_bytes)), headers=Headers({ b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.version_string], + b"User-Agent": [self.user_agent], }) ) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index f1952b5a0f..0e0c61dec8 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -186,7 +186,7 @@ class Pusher(object): if not display_name: return False return re.search( - "\b%s\b" % re.escape(display_name), ev['content']['body'], + r"\b%s\b" % re.escape(display_name), ev['content']['body'], flags=re.IGNORECASE ) is not None diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index bac68cc29f..52c7943400 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -29,10 +29,12 @@ class InitialSyncRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler + include_archived = request.args.get("archived", None) == ["true"] content = yield handler.snapshot_all_rooms( user_id=user.to_string(), pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + include_archived=include_archived, ) defer.returnValue((200, content)) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e580f71964..2e3e4f39f3 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -15,7 +15,8 @@ from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, LoginError, Codes +from synapse.http.client import SimpleHttpClient from synapse.types import UserID from base import ClientV1RestServlet, client_path_pattern @@ -27,6 +28,8 @@ from saml2 import BINDING_HTTP_POST from saml2 import config from saml2.client import Saml2Client +import xml.etree.ElementTree as ET + logger = logging.getLogger(__name__) @@ -35,16 +38,23 @@ class LoginRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login$") PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" + CAS_TYPE = "m.login.cas" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.saml2_enabled = hs.config.saml2_enabled + self.cas_enabled = hs.config.cas_enabled + self.cas_server_url = hs.config.cas_server_url + self.cas_required_attributes = hs.config.cas_required_attributes + self.servername = hs.config.server_name def on_GET(self, request): flows = [{"type": LoginRestServlet.PASS_TYPE}] if self.saml2_enabled: flows.append({"type": LoginRestServlet.SAML2_TYPE}) + if self.cas_enabled: + flows.append({"type": LoginRestServlet.CAS_TYPE}) return (200, {"flows": flows}) def on_OPTIONS(self, request): @@ -67,6 +77,19 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + elif self.cas_enabled and (login_submission["type"] == + LoginRestServlet.CAS_TYPE): + # TODO: get this from the homeserver rather than creating a new one for + # each request + http_client = SimpleHttpClient(self.hs) + uri = "%s/proxyValidate" % (self.cas_server_url,) + args = { + "ticket": login_submission["ticket"], + "service": login_submission["service"] + } + body = yield http_client.get_raw(uri, args) + result = yield self.do_cas_login(body) + defer.returnValue(result) else: raise SynapseError(400, "Bad login type.") except KeyError: @@ -100,6 +123,74 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + @defer.inlineCallbacks + def do_cas_login(self, cas_response_body): + user, attributes = self.parse_cas_response(cas_response_body) + + for required_attribute, required_value in self.cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = ( + yield auth_handler.login_with_cas_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + else: + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + def parse_cas_response(self, cas_response_body): + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not root[0].tag.endswith("authenticationSuccess"): + raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in attribute tags + # to the full URL of the namespace. + # See (https://docs.python.org/2/library/xml.etree.elementtree.html) + # We don't care about namespace here and it will always be encased in + # curly braces, so we remove them. + if "}" in attribute.tag: + attributes[attribute.tag.split("}")[1]] = attribute.text + else: + attributes[attribute.tag] = attribute.text + if user is None or attributes is None: + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + + return (user, attributes) + class LoginFallbackRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/fallback$") @@ -174,6 +265,17 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) +class CasRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/login/cas") + + def __init__(self, hs): + super(CasRestServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + + def on_GET(self, request): + return (200, {"serverUrl": self.cas_server_url}) + + def _parse_json(request): try: content = json.loads(request.content.read()) @@ -188,4 +290,6 @@ def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.saml2_enabled: SAML2RestServlet(hs).register(http_server) + if hs.config.cas_enabled: + CasRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server) |