diff options
-rw-r--r-- | AUTHORS.rst | 3 | ||||
-rw-r--r-- | synapse/config/homeserver.py | 3 | ||||
-rw-r--r-- | synapse/config/ldap.py | 52 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 86 | ||||
-rw-r--r-- | synapse/replication/slave/storage/events.py | 4 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 2 | ||||
-rw-r--r-- | tests/replication/slave/storage/_base.py | 2 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 51 |
8 files changed, 188 insertions, 15 deletions
diff --git a/AUTHORS.rst b/AUTHORS.rst index 8711a6ae5c..3dcb1c2a89 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com> Niklas Riekenbrauck <nikriek at gmail dot.com> * Add JWT support for registration and login + +Christoph Witzany <christoph at web.crofting.com> + * Add LDAP support for authentication diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index acf74c8761..9a80ac39ec 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -30,13 +30,14 @@ from .saml2 import SAML2Config from .cas import CasConfig from .password import PasswordConfig from .jwt import JWTConfig +from .ldap import LDAPConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, PasswordConfig,): + JWTConfig, LDAPConfig, PasswordConfig,): pass diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py new file mode 100644 index 0000000000..9c14593a99 --- /dev/null +++ b/synapse/config/ldap.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 Niklas Riekenbrauck +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class LDAPConfig(Config): + def read_config(self, config): + ldap_config = config.get("ldap_config", None) + if ldap_config: + self.ldap_enabled = ldap_config.get("enabled", False) + self.ldap_server = ldap_config["server"] + self.ldap_port = ldap_config["port"] + self.ldap_tls = ldap_config.get("tls", False) + self.ldap_search_base = ldap_config["search_base"] + self.ldap_search_property = ldap_config["search_property"] + self.ldap_email_property = ldap_config["email_property"] + self.ldap_full_name_property = ldap_config["full_name_property"] + else: + self.ldap_enabled = False + self.ldap_server = None + self.ldap_port = None + self.ldap_tls = False + self.ldap_search_base = None + self.ldap_search_property = None + self.ldap_email_property = None + self.ldap_full_name_property = None + + def default_config(self, **kwargs): + return """\ + # ldap_config: + # enabled: true + # server: "ldap://localhost" + # port: 389 + # tls: false + # search_base: "ou=Users,dc=example,dc=com" + # search_property: "cn" + # email_property: "email" + # full_name_property: "givenName" + """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d6faa85f..7a13a8b11c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -49,6 +49,21 @@ class AuthHandler(BaseHandler): self.sessions = {} self.INVALID_TOKEN_HTTP_STATUS = 401 + self.ldap_enabled = hs.config.ldap_enabled + self.ldap_server = hs.config.ldap_server + self.ldap_port = hs.config.ldap_port + self.ldap_tls = hs.config.ldap_tls + self.ldap_search_base = hs.config.ldap_search_base + self.ldap_search_property = hs.config.ldap_search_property + self.ldap_email_property = hs.config.ldap_email_property + self.ldap_full_name_property = hs.config.ldap_full_name_property + + if self.ldap_enabled is True: + import ldap + logger.info("Import ldap version: %s", ldap.__version__) + + self.hs = hs # FIXME better possibility to access registrationHandler later? + @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ @@ -215,8 +230,10 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + if not (yield self._check_password(user_id, password)): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks @@ -340,8 +357,10 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) + + if not (yield self._check_password(user_id, password)): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) @@ -407,11 +426,60 @@ class AuthHandler(BaseHandler): else: defer.returnValue(user_infos.popitem()) - def _check_password(self, user_id, password, stored_hash): - """Checks that user_id has passed password, raises LoginError if not.""" - if not self.validate_hash(password, stored_hash): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + @defer.inlineCallbacks + def _check_password(self, user_id, password): + defer.returnValue( + not ( + (yield self._check_ldap_password(user_id, password)) + or + (yield self._check_local_password(user_id, password)) + )) + + @defer.inlineCallbacks + def _check_local_password(self, user_id, password): + try: + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(not self.validate_hash(password, password_hash)) + except LoginError: + defer.returnValue(False) + + @defer.inlineCallbacks + def _check_ldap_password(self, user_id, password): + if self.ldap_enabled is not True: + logger.debug("LDAP not configured") + defer.returnValue(False) + + import ldap + + logger.info("Authenticating %s with LDAP" % user_id) + try: + ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) + logger.debug("Connecting LDAP server at %s" % ldap_url) + l = ldap.initialize(ldap_url) + if self.ldap_tls: + logger.debug("Initiating TLS") + self._connection.start_tls_s() + + local_name = UserID.from_string(user_id).localpart + + dn = "%s=%s, %s" % ( + self.ldap_search_property, + local_name, + self.ldap_search_base) + logger.debug("DN for LDAP authentication: %s" % dn) + + l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + + if not (yield self.does_user_exist(user_id)): + handler = self.hs.get_handlers().registration_handler + user_id, access_token = ( + yield handler.register(localpart=local_name) + ) + + defer.returnValue(True) + except ldap.LDAPError, e: + logger.warn("LDAP error: %s", e) + defer.returnValue(False) @defer.inlineCallbacks def issue_access_token(self, user_id): diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 707ddd248a..cfc728a038 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore): "_get_current_state_for_key" ] + get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( @@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore): def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() - result["backfilled"] = self._backfill_id_gen.get_current_token() + result["backfill"] = self._backfill_id_gen.get_current_token() return result def process_replication(self, result): @@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore): position = row[0] internal = json.loads(row[1]) event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) self._invalidate_caches_for_event( event, backfilled, reset_state=position in state_resets diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f69f1cdad4..46cf93ff87 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ class StreamIdGenerator(object): self._current + self._step * (n + 1), self._step ) - self._current += n + self._current += n * self._step for next_id in next_ids: self._unfinished_ids.append(next_id) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 0f525a8943..983caafe8a 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): def check(self, method, args, expected_result=None): master_result = yield getattr(self.master_store, method)(*args) slaved_result = yield getattr(self.slaved_store, method)(*args) - self.assertEqual(master_result, slaved_result) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) + self.assertEqual(master_result, slaved_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 9af62702b3..baa4a26eb5 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): [join3] ) + @defer.inlineCallbacks + def test_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + + @defer.inlineCallbacks + def test_backfilled_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id, backfill=True + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + event_id = 0 @defer.inlineCallbacks def persist( self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], + depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, **content ): """ @@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_dict["state_key"] = key event_dict["prev_state"] = prev_state + if redacts is not None: + event_dict["redacts"] = redacts + event = FrozenEvent(event_dict, internal_metadata_dict=internal) self.event_id += 1 |