summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/handlers/auth.py279
-rw-r--r--synapse/storage/events.py44
3 files changed, 222 insertions, 103 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 41745170a1..6dbe8fc7e7 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.18.0"
+__version__ = "0.18.1"
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6986930c0d..3933ce171a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -31,6 +31,7 @@ import simplejson
 
 try:
     import ldap3
+    import ldap3.core.exceptions
 except ImportError:
     ldap3 = None
     pass
@@ -504,6 +505,144 @@ class AuthHandler(BaseHandler):
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)
         defer.returnValue(user_id)
 
+    def _ldap_simple_bind(self, server, localpart, password):
+        """ Attempt a simple bind with the credentials
+            given by the user against the LDAP server.
+
+            Returns True, LDAP3Connection
+                if the bind was successful
+            Returns False, None
+                if an error occured
+        """
+
+        try:
+            # bind with the the local users ldap credentials
+            bind_dn = "{prop}={value},{base}".format(
+                prop=self.ldap_attributes['uid'],
+                value=localpart,
+                base=self.ldap_base
+            )
+            conn = ldap3.Connection(server, bind_dn, password)
+            logger.debug(
+                "Established LDAP connection in simple bind mode: %s",
+                conn
+            )
+
+            if self.ldap_start_tls:
+                conn.start_tls()
+                logger.debug(
+                    "Upgraded LDAP connection in simple bind mode through StartTLS: %s",
+                    conn
+                )
+
+            if conn.bind():
+                # GOOD: bind okay
+                logger.debug("LDAP Bind successful in simple bind mode.")
+                return True, conn
+
+            # BAD: bind failed
+            logger.info(
+                "Binding against LDAP failed for '%s' failed: %s",
+                localpart, conn.result['description']
+            )
+            conn.unbind()
+            return False, None
+
+        except ldap3.core.exceptions.LDAPException as e:
+            logger.warn("Error during LDAP authentication: %s", e)
+            return False, None
+
+    def _ldap_authenticated_search(self, server, localpart, password):
+        """ Attempt to login with the preconfigured bind_dn
+            and then continue searching and filtering within
+            the base_dn
+
+            Returns (True, LDAP3Connection)
+                if a single matching DN within the base was found
+                that matched the filter expression, and with which
+                a successful bind was achieved
+
+                The LDAP3Connection returned is the instance that was used to
+                verify the password not the one using the configured bind_dn.
+            Returns (False, None)
+                if an error occured
+        """
+
+        try:
+            conn = ldap3.Connection(
+                server,
+                self.ldap_bind_dn,
+                self.ldap_bind_password
+            )
+            logger.debug(
+                "Established LDAP connection in search mode: %s",
+                conn
+            )
+
+            if self.ldap_start_tls:
+                conn.start_tls()
+                logger.debug(
+                    "Upgraded LDAP connection in search mode through StartTLS: %s",
+                    conn
+                )
+
+            if not conn.bind():
+                logger.warn(
+                    "Binding against LDAP with `bind_dn` failed: %s",
+                    conn.result['description']
+                )
+                conn.unbind()
+                return False, None
+
+            # construct search_filter like (uid=localpart)
+            query = "({prop}={value})".format(
+                prop=self.ldap_attributes['uid'],
+                value=localpart
+            )
+            if self.ldap_filter:
+                # combine with the AND expression
+                query = "(&{query}{filter})".format(
+                    query=query,
+                    filter=self.ldap_filter
+                )
+            logger.debug(
+                "LDAP search filter: %s",
+                query
+            )
+            conn.search(
+                search_base=self.ldap_base,
+                search_filter=query
+            )
+
+            if len(conn.response) == 1:
+                # GOOD: found exactly one result
+                user_dn = conn.response[0]['dn']
+                logger.debug('LDAP search found dn: %s', user_dn)
+
+                # unbind and simple bind with user_dn to verify the password
+                # Note: do not use rebind(), for some reason it did not verify
+                #       the password for me!
+                conn.unbind()
+                return self._ldap_simple_bind(server, localpart, password)
+            else:
+                # BAD: found 0 or > 1 results, abort!
+                if len(conn.response) == 0:
+                    logger.info(
+                        "LDAP search returned no results for '%s'",
+                        localpart
+                    )
+                else:
+                    logger.info(
+                        "LDAP search returned too many (%s) results for '%s'",
+                        len(conn.response), localpart
+                    )
+                conn.unbind()
+                return False, None
+
+        except ldap3.core.exceptions.LDAPException as e:
+            logger.warn("Error during LDAP authentication: %s", e)
+            return False, None
+
     @defer.inlineCallbacks
     def _check_ldap_password(self, user_id, password):
         """ Attempt to authenticate a user against an LDAP Server
@@ -516,106 +655,62 @@ class AuthHandler(BaseHandler):
         if not ldap3 or not self.ldap_enabled:
             defer.returnValue(False)
 
-        if self.ldap_mode not in LDAPMode.LIST:
-            raise RuntimeError(
-                'Invalid ldap mode specified: {mode}'.format(
-                    mode=self.ldap_mode
-                )
-            )
+        localpart = UserID.from_string(user_id).localpart
 
         try:
             server = ldap3.Server(self.ldap_uri)
             logger.debug(
-                "Attempting ldap connection with %s",
+                "Attempting LDAP connection with %s",
                 self.ldap_uri
             )
 
-            localpart = UserID.from_string(user_id).localpart
             if self.ldap_mode == LDAPMode.SIMPLE:
-                # bind with the the local users ldap credentials
-                bind_dn = "{prop}={value},{base}".format(
-                    prop=self.ldap_attributes['uid'],
-                    value=localpart,
-                    base=self.ldap_base
+                result, conn = self._ldap_simple_bind(
+                    server=server, localpart=localpart, password=password
                 )
-                conn = ldap3.Connection(server, bind_dn, password)
                 logger.debug(
-                    "Established ldap connection in simple mode: %s",
+                    'LDAP authentication method simple bind returned: %s (conn: %s)',
+                    result,
                     conn
                 )
-
-                if self.ldap_start_tls:
-                    conn.start_tls()
-                    logger.debug(
-                        "Upgraded ldap connection in simple mode through StartTLS: %s",
-                        conn
-                    )
-
-                conn.bind()
-
+                if not result:
+                    defer.returnValue(False)
             elif self.ldap_mode == LDAPMode.SEARCH:
-                # connect with preconfigured credentials and search for local user
-                conn = ldap3.Connection(
-                    server,
-                    self.ldap_bind_dn,
-                    self.ldap_bind_password
+                result, conn = self._ldap_authenticated_search(
+                    server=server, localpart=localpart, password=password
                 )
                 logger.debug(
-                    "Established ldap connection in search mode: %s",
+                    'LDAP auth method authenticated search returned: %s (conn: %s)',
+                    result,
                     conn
                 )
-
-                if self.ldap_start_tls:
-                    conn.start_tls()
-                    logger.debug(
-                        "Upgraded ldap connection in search mode through StartTLS: %s",
-                        conn
+                if not result:
+                    defer.returnValue(False)
+            else:
+                raise RuntimeError(
+                    'Invalid LDAP mode specified: {mode}'.format(
+                        mode=self.ldap_mode
                     )
-
-                conn.bind()
-
-                # find matching dn
-                query = "({prop}={value})".format(
-                    prop=self.ldap_attributes['uid'],
-                    value=localpart
                 )
-                if self.ldap_filter:
-                    query = "(&{query}{filter})".format(
-                        query=query,
-                        filter=self.ldap_filter
-                    )
-                logger.debug("ldap search filter: %s", query)
-                result = conn.search(self.ldap_base, query)
-
-                if result and len(conn.response) == 1:
-                    # found exactly one result
-                    user_dn = conn.response[0]['dn']
-                    logger.debug('ldap search found dn: %s', user_dn)
-
-                    # unbind and reconnect, rebind with found dn
-                    conn.unbind()
-                    conn = ldap3.Connection(
-                        server,
-                        user_dn,
-                        password,
-                        auto_bind=True
-                    )
-                else:
-                    # found 0 or > 1 results, abort!
-                    logger.warn(
-                        "ldap search returned unexpected (%d!=1) amount of results",
-                        len(conn.response)
-                    )
-                    defer.returnValue(False)
 
-            logger.info(
-                "User authenticated against ldap server: %s",
-                conn
-            )
+            try:
+                logger.info(
+                    "User authenticated against LDAP server: %s",
+                    conn
+                )
+            except NameError:
+                logger.warn("Authentication method yielded no LDAP connection, aborting!")
+                defer.returnValue(False)
+
+            # check if user with user_id exists
+            if (yield self.check_user_exists(user_id)):
+                # exists, authentication complete
+                conn.unbind()
+                defer.returnValue(True)
 
-            # check for existing account, if none exists, create one
-            if not (yield self.check_user_exists(user_id)):
-                # query user metadata for account creation
+            else:
+                # does not exist, fetch metadata for account creation from
+                # existing ldap connection
                 query = "({prop}={value})".format(
                     prop=self.ldap_attributes['uid'],
                     value=localpart
@@ -626,9 +721,12 @@ class AuthHandler(BaseHandler):
                         filter=query,
                         user_filter=self.ldap_filter
                     )
-                logger.debug("ldap registration filter: %s", query)
+                logger.debug(
+                    "ldap registration filter: %s",
+                    query
+                )
 
-                result = conn.search(
+                conn.search(
                     search_base=self.ldap_base,
                     search_filter=query,
                     attributes=[
@@ -651,20 +749,27 @@ class AuthHandler(BaseHandler):
                     # TODO: bind email, set displayname with data from ldap directory
 
                     logger.info(
-                        "ldap registration successful: %d: %s (%s, %)",
+                        "Registration based on LDAP data was successful: %d: %s (%s, %)",
                         user_id,
                         localpart,
                         name,
                         mail
                     )
+
+                    defer.returnValue(True)
                 else:
-                    logger.warn(
-                        "ldap registration failed: unexpected (%d!=1) amount of results",
-                        len(conn.response)
-                    )
+                    if len(conn.response) == 0:
+                        logger.warn("LDAP registration failed, no result.")
+                    else:
+                        logger.warn(
+                            "LDAP registration failed, too many results (%s)",
+                            len(conn.response)
+                        )
+
                     defer.returnValue(False)
 
-            defer.returnValue(True)
+            defer.returnValue(False)
+
         except ldap3.core.exceptions.LDAPException as e:
             logger.warn("Error during ldap authentication: %s", e)
             defer.returnValue(False)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6dc46fa50f..6cf9d1176d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1355,39 +1355,53 @@ class EventsStore(SQLBaseStore):
             min_stream_id = rows[-1][0]
             event_ids = [row[1] for row in rows]
 
-            events = self._get_events_txn(txn, event_ids)
+            rows_to_update = []
 
-            rows = []
-            for event in events:
-                try:
-                    event_id = event.event_id
-                    origin_server_ts = event.origin_server_ts
-                except (KeyError, AttributeError):
-                    # If the event is missing a necessary field then
-                    # skip over it.
-                    continue
+            chunks = [
+                event_ids[i:i + 100]
+                for i in xrange(0, len(event_ids), 100)
+            ]
+            for chunk in chunks:
+                ev_rows = self._simple_select_many_txn(
+                    txn,
+                    table="event_json",
+                    column="event_id",
+                    iterable=chunk,
+                    retcols=["event_id", "json"],
+                    keyvalues={},
+                )
 
-                rows.append((origin_server_ts, event_id))
+                for row in ev_rows:
+                    event_id = row["event_id"]
+                    event_json = json.loads(row["json"])
+                    try:
+                        origin_server_ts = event_json["origin_server_ts"]
+                    except (KeyError, AttributeError):
+                        # If the event is missing a necessary field then
+                        # skip over it.
+                        continue
+
+                    rows_to_update.append((origin_server_ts, event_id))
 
             sql = (
                 "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
             )
 
-            for index in range(0, len(rows), INSERT_CLUMP_SIZE):
-                clump = rows[index:index + INSERT_CLUMP_SIZE]
+            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
+                clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
                 txn.executemany(sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
                 "max_stream_id_exclusive": min_stream_id,
-                "rows_inserted": rows_inserted + len(rows)
+                "rows_inserted": rows_inserted + len(rows_to_update)
             }
 
             self._background_update_progress_txn(
                 txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
             )
 
-            return len(rows)
+            return len(rows_to_update)
 
         result = yield self.runInteraction(
             self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn