diff --git a/AUTHORS.rst b/AUTHORS.rst
index 8711a6ae5c..3dcb1c2a89 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com>
Niklas Riekenbrauck <nikriek at gmail dot.com>
* Add JWT support for registration and login
+
+Christoph Witzany <christoph at web.crofting.com>
+ * Add LDAP support for authentication
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index acf74c8761..9a80ac39ec 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -30,13 +30,14 @@ from .saml2 import SAML2Config
from .cas import CasConfig
from .password import PasswordConfig
from .jwt import JWTConfig
+from .ldap import LDAPConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
- JWTConfig, PasswordConfig,):
+ JWTConfig, LDAPConfig, PasswordConfig,):
pass
diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py
new file mode 100644
index 0000000000..9c14593a99
--- /dev/null
+++ b/synapse/config/ldap.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 Niklas Riekenbrauck
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class LDAPConfig(Config):
+ def read_config(self, config):
+ ldap_config = config.get("ldap_config", None)
+ if ldap_config:
+ self.ldap_enabled = ldap_config.get("enabled", False)
+ self.ldap_server = ldap_config["server"]
+ self.ldap_port = ldap_config["port"]
+ self.ldap_tls = ldap_config.get("tls", False)
+ self.ldap_search_base = ldap_config["search_base"]
+ self.ldap_search_property = ldap_config["search_property"]
+ self.ldap_email_property = ldap_config["email_property"]
+ self.ldap_full_name_property = ldap_config["full_name_property"]
+ else:
+ self.ldap_enabled = False
+ self.ldap_server = None
+ self.ldap_port = None
+ self.ldap_tls = False
+ self.ldap_search_base = None
+ self.ldap_search_property = None
+ self.ldap_email_property = None
+ self.ldap_full_name_property = None
+
+ def default_config(self, **kwargs):
+ return """\
+ # ldap_config:
+ # enabled: true
+ # server: "ldap://localhost"
+ # port: 389
+ # tls: false
+ # search_base: "ou=Users,dc=example,dc=com"
+ # search_property: "cn"
+ # email_property: "email"
+ # full_name_property: "givenName"
+ """
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d5d6faa85f..7a13a8b11c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -49,6 +49,21 @@ class AuthHandler(BaseHandler):
self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
+ self.ldap_enabled = hs.config.ldap_enabled
+ self.ldap_server = hs.config.ldap_server
+ self.ldap_port = hs.config.ldap_port
+ self.ldap_tls = hs.config.ldap_tls
+ self.ldap_search_base = hs.config.ldap_search_base
+ self.ldap_search_property = hs.config.ldap_search_property
+ self.ldap_email_property = hs.config.ldap_email_property
+ self.ldap_full_name_property = hs.config.ldap_full_name_property
+
+ if self.ldap_enabled is True:
+ import ldap
+ logger.info("Import ldap version: %s", ldap.__version__)
+
+ self.hs = hs # FIXME better possibility to access registrationHandler later?
+
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
@@ -215,8 +230,10 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
- user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
- self._check_password(user_id, password, password_hash)
+ if not (yield self._check_password(user_id, password)):
+ logger.warn("Failed password login for user %s", user_id)
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
defer.returnValue(user_id)
@defer.inlineCallbacks
@@ -340,8 +357,10 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
- user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
- self._check_password(user_id, password, password_hash)
+
+ if not (yield self._check_password(user_id, password)):
+ logger.warn("Failed password login for user %s", user_id)
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
@@ -407,11 +426,60 @@ class AuthHandler(BaseHandler):
else:
defer.returnValue(user_infos.popitem())
- def _check_password(self, user_id, password, stored_hash):
- """Checks that user_id has passed password, raises LoginError if not."""
- if not self.validate_hash(password, stored_hash):
- logger.warn("Failed password login for user %s", user_id)
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+ @defer.inlineCallbacks
+ def _check_password(self, user_id, password):
+ defer.returnValue(
+ not (
+ (yield self._check_ldap_password(user_id, password))
+ or
+ (yield self._check_local_password(user_id, password))
+ ))
+
+ @defer.inlineCallbacks
+ def _check_local_password(self, user_id, password):
+ try:
+ user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+ defer.returnValue(not self.validate_hash(password, password_hash))
+ except LoginError:
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def _check_ldap_password(self, user_id, password):
+ if self.ldap_enabled is not True:
+ logger.debug("LDAP not configured")
+ defer.returnValue(False)
+
+ import ldap
+
+ logger.info("Authenticating %s with LDAP" % user_id)
+ try:
+ ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
+ logger.debug("Connecting LDAP server at %s" % ldap_url)
+ l = ldap.initialize(ldap_url)
+ if self.ldap_tls:
+ logger.debug("Initiating TLS")
+ self._connection.start_tls_s()
+
+ local_name = UserID.from_string(user_id).localpart
+
+ dn = "%s=%s, %s" % (
+ self.ldap_search_property,
+ local_name,
+ self.ldap_search_base)
+ logger.debug("DN for LDAP authentication: %s" % dn)
+
+ l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
+
+ if not (yield self.does_user_exist(user_id)):
+ handler = self.hs.get_handlers().registration_handler
+ user_id, access_token = (
+ yield handler.register(localpart=local_name)
+ )
+
+ defer.returnValue(True)
+ except ldap.LDAPError, e:
+ logger.warn("LDAP error: %s", e)
+ defer.returnValue(False)
@defer.inlineCallbacks
def issue_access_token(self, user_id):
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 707ddd248a..cfc728a038 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore):
"_get_current_state_for_key"
]
+ get_event = DataStore.get_event.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
@@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore):
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
- result["backfilled"] = self._backfill_id_gen.get_current_token()
+ result["backfill"] = self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
@@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore):
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
-
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self._invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f69f1cdad4..46cf93ff87 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -112,7 +112,7 @@ class StreamIdGenerator(object):
self._current + self._step * (n + 1),
self._step
)
- self._current += n
+ self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids.append(next_id)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 0f525a8943..983caafe8a 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
- self.assertEqual(master_result, slaved_result)
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)
+ self.assertEqual(master_result, slaved_result)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 9af62702b3..baa4a26eb5 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
[join3]
)
+ @defer.inlineCallbacks
+ def test_redactions(self):
+ yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+ msg = yield self.persist(
+ type="m.room.message", msgtype="m.text", body="Hello"
+ )
+ yield self.replicate()
+ yield self.check("get_event", [msg.event_id], msg)
+
+ redaction = yield self.persist(
+ type="m.room.redaction", redacts=msg.event_id
+ )
+ yield self.replicate()
+
+ msg_dict = msg.get_dict()
+ msg_dict["content"] = {}
+ msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+ msg_dict["unsigned"]["redacted_because"] = redaction
+ redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+ yield self.check("get_event", [msg.event_id], redacted)
+
+ @defer.inlineCallbacks
+ def test_backfilled_redactions(self):
+ yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+ msg = yield self.persist(
+ type="m.room.message", msgtype="m.text", body="Hello"
+ )
+ yield self.replicate()
+ yield self.check("get_event", [msg.event_id], msg)
+
+ redaction = yield self.persist(
+ type="m.room.redaction", redacts=msg.event_id, backfill=True
+ )
+ yield self.replicate()
+
+ msg_dict = msg.get_dict()
+ msg_dict["content"] = {}
+ msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+ msg_dict["unsigned"]["redacted_because"] = redaction
+ redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+ yield self.check("get_event", [msg.event_id], redacted)
+
event_id = 0
@defer.inlineCallbacks
def persist(
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
state=None, reset_state=False, backfill=False,
- depth=None, prev_events=[], auth_events=[], prev_state=[],
+ depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
**content
):
"""
@@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_dict["state_key"] = key
event_dict["prev_state"] = prev_state
+ if redacts is not None:
+ event_dict["redacts"] = redacts
+
event = FrozenEvent(event_dict, internal_metadata_dict=internal)
self.event_id += 1
|