summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/__init__.py2
-rw-r--r--synapse/handlers/auth.py113
-rw-r--r--synapse/handlers/federation.py25
-rw-r--r--synapse/handlers/login.py83
-rw-r--r--synapse/handlers/message.py50
-rw-r--r--synapse/handlers/presence.py89
-rw-r--r--synapse/handlers/register.py10
-rw-r--r--synapse/handlers/sync.py21
8 files changed, 222 insertions, 171 deletions
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index dc5b6ef79d..8725c3c420 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -22,7 +22,6 @@ from .room import (
 from .message import MessageHandler
 from .events import EventStreamHandler, EventHandler
 from .federation import FederationHandler
-from .login import LoginHandler
 from .profile import ProfileHandler
 from .presence import PresenceHandler
 from .directory import DirectoryHandler
@@ -54,7 +53,6 @@ class Handlers(object):
         self.profile_handler = ProfileHandler(hs)
         self.presence_handler = PresenceHandler(hs)
         self.room_list_handler = RoomListHandler(hs)
-        self.login_handler = LoginHandler(hs)
         self.directory_handler = DirectoryHandler(hs)
         self.typing_notification_handler = TypingNotificationHandler(hs)
         self.admin_handler = AdminHandler(hs)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1ecf7fef17..98d99dd0a8 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
         self.sessions = {}
 
     @defer.inlineCallbacks
-    def check_auth(self, flows, clientdict, clientip=None):
+    def check_auth(self, flows, clientdict, clientip):
         """
         Takes a dictionary sent by the client in the login / registration
         protocol and handles the login flow.
 
+        As a side effect, this function fills in the 'creds' key on the user's
+        session with a map, which maps each auth-type (str) to the relevant
+        identity authenticated by that auth-type (mostly str, but for captcha, bool).
+
         Args:
-            flows: list of list of stages
-            authdict: The dictionary from the client root level, not the
-                      'auth' key: this method prompts for auth if none is sent.
+            flows (list): A list of login flows. Each flow is an ordered list of
+                          strings representing auth-types. At least one full
+                          flow must be completed in order for auth to be successful.
+            clientdict: The dictionary from the client root level, not the
+                        'auth' key: this method prompts for auth if none is sent.
+            clientip (str): The IP address of the client.
         Returns:
-            A tuple of authed, dict, dict where authed is true if the client
+            A tuple of (authed, dict, dict) where authed is true if the client
             has successfully completed an auth flow. If it is true, the first
             dict contains the authenticated credentials of each stage.
 
@@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
             del clientdict['auth']
             if 'session' in authdict:
                 sid = authdict['session']
-        sess = self._get_session_info(sid)
+        session = self._get_session_info(sid)
 
         if len(clientdict) > 0:
             # This was designed to allow the client to omit the parameters
@@ -87,20 +94,19 @@ class AuthHandler(BaseHandler):
             # on a home server.
             # Revisit: Assumimg the REST APIs do sensible validation, the data
             # isn't arbintrary.
-            sess['clientdict'] = clientdict
-            self._save_session(sess)
-            pass
-        elif 'clientdict' in sess:
-            clientdict = sess['clientdict']
+            session['clientdict'] = clientdict
+            self._save_session(session)
+        elif 'clientdict' in session:
+            clientdict = session['clientdict']
 
         if not authdict:
             defer.returnValue(
-                (False, self._auth_dict_for_flows(flows, sess), clientdict)
+                (False, self._auth_dict_for_flows(flows, session), clientdict)
             )
 
-        if 'creds' not in sess:
-            sess['creds'] = {}
-        creds = sess['creds']
+        if 'creds' not in session:
+            session['creds'] = {}
+        creds = session['creds']
 
         # check auth type currently being presented
         if 'type' in authdict:
@@ -109,15 +115,15 @@ class AuthHandler(BaseHandler):
             result = yield self.checkers[authdict['type']](authdict, clientip)
             if result:
                 creds[authdict['type']] = result
-                self._save_session(sess)
+                self._save_session(session)
 
         for f in flows:
             if len(set(f) - set(creds.keys())) == 0:
                 logger.info("Auth completed with creds: %r", creds)
-                self._remove_session(sess)
+                self._remove_session(session)
                 defer.returnValue((True, creds, clientdict))
 
-        ret = self._auth_dict_for_flows(flows, sess)
+        ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
         defer.returnValue((False, ret, clientdict))
 
@@ -151,22 +157,13 @@ class AuthHandler(BaseHandler):
         if "user" not in authdict or "password" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)
 
-        user = authdict["user"]
+        user_id = authdict["user"]
         password = authdict["password"]
-        if not user.startswith('@'):
-            user = UserID.create(user, self.hs.hostname).to_string()
+        if not user_id.startswith('@'):
+            user_id = UserID.create(user_id, self.hs.hostname).to_string()
 
-        user_info = yield self.store.get_user_by_id(user_id=user)
-        if not user_info:
-            logger.warn("Attempted to login as %s but they do not exist", user)
-            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
-        stored_hash = user_info["password_hash"]
-        if bcrypt.checkpw(password, stored_hash):
-            defer.returnValue(user)
-        else:
-            logger.warn("Failed password login for user %s", user)
-            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+        self._check_password(user_id, password)
+        defer.returnValue(user_id)
 
     @defer.inlineCallbacks
     def _check_recaptcha(self, authdict, clientip):
@@ -270,6 +267,58 @@ class AuthHandler(BaseHandler):
 
         return self.sessions[session_id]
 
+    @defer.inlineCallbacks
+    def login_with_password(self, user_id, password):
+        """
+        Authenticates the user with their username and password.
+
+        Used only by the v1 login API.
+
+        Args:
+            user_id (str): User ID
+            password (str): Password
+        Returns:
+            The access token for the user's session.
+        Raises:
+            StoreError if there was a problem storing the token.
+            LoginError if there was an authentication problem.
+        """
+        self._check_password(user_id, password)
+
+        reg_handler = self.hs.get_handlers().registration_handler
+        access_token = reg_handler.generate_token(user_id)
+        logger.info("Adding token %s for user %s", access_token, user_id)
+        yield self.store.add_access_token_to_user(user_id, access_token)
+        defer.returnValue(access_token)
+
+    def _check_password(self, user_id, password):
+        """Checks that user_id has passed password, raises LoginError if not."""
+        user_info = yield self.store.get_user_by_id(user_id=user_id)
+        if not user_info:
+            logger.warn("Attempted to login as %s but they do not exist", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+        stored_hash = user_info["password_hash"]
+        if not bcrypt.checkpw(password, stored_hash):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+    @defer.inlineCallbacks
+    def set_password(self, user_id, newpassword):
+        password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
+
+        yield self.store.user_set_password_hash(user_id, password_hash)
+        yield self.store.user_delete_access_tokens(user_id)
+        yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
+        yield self.store.flush_user(user_id)
+
+    @defer.inlineCallbacks
+    def add_threepid(self, user_id, medium, address, validated_at):
+        yield self.store.user_add_threepid(
+            user_id, medium, address, validated_at,
+            self.hs.get_clock().time_msec()
+        )
+
     def _save_session(self, session):
         # TODO: Persistent storage
         logger.debug("Saving session %s", session)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f7155fd8d3..1e3dccf5a8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -229,15 +229,15 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_server(self, server_name, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, None),
+            )
         )
 
-        events_and_states = zip(events, states)
-
-        def redact_disallowed(event_and_state):
-            event, state = event_and_state
-
+        def redact_disallowed(event, state):
             if not state:
                 return event
 
@@ -271,11 +271,10 @@ class FederationHandler(BaseHandler):
 
             return event
 
-        res = map(redact_disallowed, events_and_states)
-
-        logger.info("_filter_events_for_server %r", res)
-
-        defer.returnValue(res)
+        defer.returnValue([
+            redact_disallowed(e, event_to_state[e.event_id])
+            for e in events
+        ])
 
     @log_function
     @defer.inlineCallbacks
@@ -503,7 +502,7 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         states = yield defer.gatherResults([
-            self.state_handler.resolve_state_groups([e])
+            self.state_handler.resolve_state_groups(room_id, [e])
             for e in event_ids
         ])
         states = dict(zip(event_ids, [s[1] for s in states]))
diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py
deleted file mode 100644
index 91d87d503d..0000000000
--- a/synapse/handlers/login.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import BaseHandler
-from synapse.api.errors import LoginError, Codes
-
-import bcrypt
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class LoginHandler(BaseHandler):
-
-    def __init__(self, hs):
-        super(LoginHandler, self).__init__(hs)
-        self.hs = hs
-
-    @defer.inlineCallbacks
-    def login(self, user, password):
-        """Login as the specified user with the specified password.
-
-        Args:
-            user (str): The user ID.
-            password (str): The password.
-        Returns:
-            The newly allocated access token.
-        Raises:
-            StoreError if there was a problem storing the token.
-            LoginError if there was an authentication problem.
-        """
-        # TODO do this better, it can't go in __init__ else it cyclic loops
-        if not hasattr(self, "reg_handler"):
-            self.reg_handler = self.hs.get_handlers().registration_handler
-
-        # pull out the hash for this user if they exist
-        user_info = yield self.store.get_user_by_id(user_id=user)
-        if not user_info:
-            logger.warn("Attempted to login as %s but they do not exist", user)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-        stored_hash = user_info["password_hash"]
-        if bcrypt.checkpw(password, stored_hash):
-            # generate an access token and store it.
-            token = self.reg_handler._generate_token(user)
-            logger.info("Adding token %s for user %s", token, user)
-            yield self.store.add_access_token_to_user(user, token)
-            defer.returnValue(token)
-        else:
-            logger.warn("Failed password login for user %s", user)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-    @defer.inlineCallbacks
-    def set_password(self, user_id, newpassword, token_id=None):
-        password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
-
-        yield self.store.user_set_password_hash(user_id, password_hash)
-        yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
-        yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
-            user_id, token_id
-        )
-        yield self.store.flush_user(user_id)
-
-    @defer.inlineCallbacks
-    def add_threepid(self, user_id, medium, address, validated_at):
-        yield self.store.user_add_threepid(
-            user_id, medium, address, validated_at,
-            self.hs.get_clock().time_msec()
-        )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9d6d4f0978..f12465fa2c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -137,15 +137,15 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_id_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, user_id),
+            )
         )
 
-        events_and_states = zip(events, states)
-
-        def allowed(event_and_state):
-            event, state = event_and_state
-
+        def allowed(event, state):
             if event.type == EventTypes.RoomHistoryVisibility:
                 return True
 
@@ -175,10 +175,10 @@ class MessageHandler(BaseHandler):
 
             return True
 
-        events_and_states = filter(allowed, events_and_states)
         defer.returnValue([
-            ev
-            for ev, _ in events_and_states
+            event
+            for event in events
+            if allowed(event, event_id_to_state[event.event_id])
         ])
 
     @defer.inlineCallbacks
@@ -401,10 +401,14 @@ class MessageHandler(BaseHandler):
             except:
                 logger.exception("Failed to get snapshot")
 
-        yield defer.gatherResults(
-            [handle_room(e) for e in room_list],
-            consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        # Only do N rooms at once
+        n = 5
+        d_list = [handle_room(e) for e in room_list]
+        for i in range(0, len(d_list), n):
+            yield defer.gatherResults(
+                d_list[i:i + n],
+                consumeErrors=True
+            ).addErrback(unwrapFirstError)
 
         ret = {
             "rooms": rooms_ret,
@@ -456,20 +460,14 @@ class MessageHandler(BaseHandler):
 
         @defer.inlineCallbacks
         def get_presence():
-            presence_defs = yield defer.DeferredList(
-                [
-                    presence_handler.get_state(
-                        target_user=UserID.from_string(m.user_id),
-                        auth_user=auth_user,
-                        as_event=True,
-                        check_auth=False,
-                    )
-                    for m in room_members
-                ],
-                consumeErrors=True,
+            states = yield presence_handler.get_states(
+                target_users=[UserID.from_string(m.user_id) for m in room_members],
+                auth_user=auth_user,
+                as_event=True,
+                check_auth=False,
             )
 
-            defer.returnValue([p for success, p in presence_defs if success])
+            defer.returnValue(states.values())
 
         receipts_handler = self.hs.get_handlers().receipts_handler
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 341a516da2..e91e81831e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
+        """Get the current presence state of the given user.
+
+        Args:
+            target_user (UserID): The user whose presence we want
+            auth_user (UserID): The user requesting the presence, used for
+                checking if said user is allowed to see the persence of the
+                `target_user`
+            as_event (bool): Format the return as an event or not?
+            check_auth (bool): Perform the auth checks or not?
+
+        Returns:
+            dict: The presence state of the `target_user`, whose format depends
+            on the `as_event` argument.
+        """
         if self.hs.is_mine(target_user):
             if check_auth:
                 visible = yield self.is_presence_visible(
@@ -233,6 +247,81 @@ class PresenceHandler(BaseHandler):
             defer.returnValue(state)
 
     @defer.inlineCallbacks
+    def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
+        """A batched version of the `get_state` method that accepts a list of
+        `target_users`
+
+        Args:
+            target_users (list): The list of UserID's whose presence we want
+            auth_user (UserID): The user requesting the presence, used for
+                checking if said user is allowed to see the persence of the
+                `target_users`
+            as_event (bool): Format the return as an event or not?
+            check_auth (bool): Perform the auth checks or not?
+
+        Returns:
+            dict: A mapping from user -> presence_state
+        """
+        local_users, remote_users = partitionbool(
+            target_users,
+            lambda u: self.hs.is_mine(u)
+        )
+
+        if check_auth:
+            for user in local_users:
+                visible = yield self.is_presence_visible(
+                    observer_user=auth_user,
+                    observed_user=user
+                )
+
+                if not visible:
+                    raise SynapseError(404, "Presence information not visible")
+
+        results = {}
+        if local_users:
+            for user in local_users:
+                if user in self._user_cachemap:
+                    results[user] = self._user_cachemap[user].get_state()
+
+            local_to_user = {u.localpart: u for u in local_users}
+
+            states = yield self.store.get_presence_states(
+                [u.localpart for u in local_users if u not in results]
+            )
+
+            for local_part, state in states.items():
+                if state is None:
+                    continue
+                res = {"presence": state["state"]}
+                if "status_msg" in state and state["status_msg"]:
+                    res["status_msg"] = state["status_msg"]
+                results[local_to_user[local_part]] = res
+
+        for user in remote_users:
+            # TODO(paul): Have remote server send us permissions set
+            results[user] = self._get_or_offline_usercache(user).get_state()
+
+        for state in results.values():
+            if "last_active" in state:
+                state["last_active_ago"] = int(
+                    self.clock.time_msec() - state.pop("last_active")
+                )
+
+        if as_event:
+            for user, state in results.items():
+                content = state
+                content["user_id"] = user.to_string()
+
+                if "last_active" in content:
+                    content["last_active_ago"] = int(
+                        self._clock.time_msec() - content.pop("last_active")
+                    )
+
+                results[user] = {"type": "m.presence", "content": content}
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
     @log_function
     def set_state(self, target_user, auth_user, state):
         # return
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f81d75017d..39392d9fdd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -91,7 +91,7 @@ class RegistrationHandler(BaseHandler):
             user = UserID(localpart, self.hs.hostname)
             user_id = user.to_string()
 
-            token = self._generate_token(user_id)
+            token = self.generate_token(user_id)
             yield self.store.register(
                 user_id=user_id,
                 token=token,
@@ -111,7 +111,7 @@ class RegistrationHandler(BaseHandler):
                     user_id = user.to_string()
                     yield self.check_user_id_is_valid(user_id)
 
-                    token = self._generate_token(user_id)
+                    token = self.generate_token(user_id)
                     yield self.store.register(
                         user_id=user_id,
                         token=token,
@@ -161,7 +161,7 @@ class RegistrationHandler(BaseHandler):
                 400, "Invalid user localpart for this application service.",
                 errcode=Codes.EXCLUSIVE
             )
-        token = self._generate_token(user_id)
+        token = self.generate_token(user_id)
         yield self.store.register(
             user_id=user_id,
             token=token,
@@ -208,7 +208,7 @@ class RegistrationHandler(BaseHandler):
         user_id = user.to_string()
 
         yield self.check_user_id_is_valid(user_id)
-        token = self._generate_token(user_id)
+        token = self.generate_token(user_id)
         try:
             yield self.store.register(
                 user_id=user_id,
@@ -273,7 +273,7 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE
                 )
 
-    def _generate_token(self, user_id):
+    def generate_token(self, user_id):
         # urlsafe variant uses _ and - so use . as the separator and replace
         # all =s with .s so http clients don't quote =s when it is used as
         # query params.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6cff6230c1..7206ae23d7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -294,15 +294,15 @@ class SyncHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_id_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, user_id),
+            )
         )
 
-        events_and_states = zip(events, states)
-
-        def allowed(event_and_state):
-            event, state = event_and_state
-
+        def allowed(event, state):
             if event.type == EventTypes.RoomHistoryVisibility:
                 return True
 
@@ -331,10 +331,11 @@ class SyncHandler(BaseHandler):
                 return membership == Membership.INVITE
 
             return True
-        events_and_states = filter(allowed, events_and_states)
+
         defer.returnValue([
-            ev
-            for ev, _ in events_and_states
+            event
+            for event in events
+            if allowed(event, event_id_to_state[event.event_id])
         ])
 
     @defer.inlineCallbacks