summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-07-13 13:50:57 +0100
committerErik Johnston <erik@matrix.org>2015-07-13 13:50:57 +0100
commit5989637f37b127f2a0c55ef0b1085ebd0d2928c1 (patch)
tree0480f8a77111f35aede6622e7ff0ed8329b6a6dd
parentComments (diff)
parentMerge pull request #196 from matrix-org/erikj/room_history (diff)
downloadsynapse-5989637f37b127f2a0c55ef0b1085ebd0d2928c1.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/receipts
Diffstat (limited to '')
-rw-r--r--AUTHORS.rst3
-rw-r--r--synapse/api/auth.py3
-rw-r--r--synapse/api/constants.py2
-rw-r--r--synapse/config/homeserver.py5
-rw-r--r--synapse/config/saml2.py54
-rw-r--r--synapse/crypto/keyring.py473
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/federation/federation_base.py125
-rw-r--r--synapse/federation/federation_client.py57
-rw-r--r--synapse/handlers/federation.py54
-rw-r--r--synapse/handlers/message.py66
-rw-r--r--synapse/handlers/register.py29
-rw-r--r--synapse/handlers/room.py1
-rw-r--r--synapse/handlers/sync.py48
-rw-r--r--synapse/python_dependencies.py1
-rw-r--r--synapse/rest/client/v1/login.py75
-rw-r--r--synapse/rest/client/v2_alpha/__init__.py2
-rw-r--r--synapse/rest/client/v2_alpha/keys.py276
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/end_to_end_keys.py125
-rw-r--r--synapse/storage/keys.py50
-rw-r--r--synapse/storage/schema/delta/21/end_to_end_keys.sql34
-rw-r--r--synapse/storage/state.py63
23 files changed, 1329 insertions, 221 deletions
diff --git a/AUTHORS.rst b/AUTHORS.rst
index d7224ff5de..54ced67000 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -42,3 +42,6 @@ Ivan Shapovalov <intelfx100 at gmail.com>
 Eric Myhre <hash at exultant.us>
  * Fix bug where ``media_store_path`` config option was ignored by v0 content
    repository API.
+
+Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
+ * Add SAML2 support for registration and logins.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 4da62e5d8d..1a25bf1086 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 AuthEventTypes = (
     EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
-    EventTypes.JoinRules,
+    EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
 )
 
 
@@ -575,6 +575,7 @@ class Auth(object):
         levels_to_check = [
             ("users_default", []),
             ("events_default", []),
+            ("state_default", []),
             ("ban", []),
             ("redact", []),
             ("kick", []),
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index d8a18ee87b..3e15e8a9d7 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -75,6 +75,8 @@ class EventTypes(object):
     Redaction = "m.room.redaction"
     Feedback = "m.room.message.feedback"
 
+    RoomHistoryVisibility = "m.room.history_visibility"
+
     # These are used for validation
     Message = "m.room.message"
     Topic = "m.room.topic"
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index fe0ccb6eb7..d77f045406 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -25,12 +25,13 @@ from .registration import RegistrationConfig
 from .metrics import MetricsConfig
 from .appservice import AppServiceConfig
 from .key import KeyConfig
+from .saml2 import SAML2Config
 
 
 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
-                       VoipConfig, RegistrationConfig,
-                       MetricsConfig, AppServiceConfig, KeyConfig,):
+                       VoipConfig, RegistrationConfig, MetricsConfig,
+                       AppServiceConfig, KeyConfig, SAML2Config, ):
     pass
 
 
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
new file mode 100644
index 0000000000..1532036876
--- /dev/null
+++ b/synapse/config/saml2.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 Ericsson
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class SAML2Config(Config):
+    """SAML2 Configuration
+    Synapse uses pysaml2 libraries for providing SAML2 support
+
+    config_path:      Path to the sp_conf.py configuration file
+    idp_redirect_url: Identity provider URL which will redirect
+                      the user back to /login/saml2 with proper info.
+
+    sp_conf.py file is something like:
+    https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
+
+    More information: https://pythonhosted.org/pysaml2/howto/config.html
+    """
+
+    def read_config(self, config):
+        saml2_config = config.get("saml2_config", None)
+        if saml2_config:
+            self.saml2_enabled = True
+            self.saml2_config_path = saml2_config["config_path"]
+            self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
+        else:
+            self.saml2_enabled = False
+            self.saml2_config_path = None
+            self.saml2_idp_redirect_url = None
+
+    def default_config(self, config_dir_path, server_name):
+        return """
+        # Enable SAML2 for registration and login. Uses pysaml2
+        # config_path:      Path to the sp_conf.py configuration file
+        # idp_redirect_url: Identity provider URL which will redirect
+        #                   the user back to /login/saml2 with proper info.
+        # See pysaml2 docs for format of config.
+        #saml2_config:
+        #   config_path: "%s/sp_conf.py"
+        #   idp_redirect_url: "http://%s/idp"
+        """ % (config_dir_path, server_name)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aff69c5f83..aa74d4d0cb 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
 from synapse.api.errors import SynapseError, Codes
 
 from synapse.util.retryutils import get_retry_limiter
+from synapse.util import unwrapFirstError
 
 from synapse.util.async import ObservableDeferred
 
 from OpenSSL import crypto
 
+from collections import namedtuple
 import urllib
 import hashlib
 import logging
@@ -38,6 +40,9 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
 class Keyring(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -49,141 +54,325 @@ class Keyring(object):
 
         self.key_downloads = {}
 
-    @defer.inlineCallbacks
     def verify_json_for_server(self, server_name, json_object):
-        logger.debug("Verifying for %s", server_name)
-        key_ids = signature_ids(json_object, server_name)
-        if not key_ids:
-            raise SynapseError(
-                400,
-                "Not signed with a supported algorithm",
-                Codes.UNAUTHORIZED,
-            )
-        try:
-            verify_key = yield self.get_server_verify_key(server_name, key_ids)
-        except IOError as e:
-            logger.warn(
-                "Got IOError when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                502,
-                "Error downloading keys for %s" % (server_name,),
-                Codes.UNAUTHORIZED,
-            )
-        except Exception as e:
-            logger.warn(
-                "Got Exception when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                401,
-                "No key for %s with id %s" % (server_name, key_ids),
-                Codes.UNAUTHORIZED,
-            )
+        return self.verify_json_objects_for_server(
+            [(server_name, json_object)]
+        )[0]
 
-        try:
-            verify_signed_json(json_object, server_name, verify_key)
-        except:
-            raise SynapseError(
-                401,
-                "Invalid signature for server %s with key %s:%s" % (
-                    server_name, verify_key.alg, verify_key.version
-                ),
-                Codes.UNAUTHORIZED,
+    def verify_json_objects_for_server(self, server_and_json):
+        """Bulk verfies signatures of json objects, bulk fetching keys as
+        necessary.
+
+        Args:
+            server_and_json (list): List of pairs of (server_name, json_object)
+
+        Returns:
+            list of deferreds indicating success or failure to verify each
+            json object's signature for the given server_name.
+        """
+        group_id_to_json = {}
+        group_id_to_group = {}
+        group_ids = []
+
+        next_group_id = 0
+        deferreds = {}
+
+        for server_name, json_object in server_and_json:
+            logger.debug("Verifying for %s", server_name)
+            group_id = next_group_id
+            next_group_id += 1
+            group_ids.append(group_id)
+
+            key_ids = signature_ids(json_object, server_name)
+            if not key_ids:
+                deferreds[group_id] = defer.fail(SynapseError(
+                    400,
+                    "Not signed with a supported algorithm",
+                    Codes.UNAUTHORIZED,
+                ))
+            else:
+                deferreds[group_id] = defer.Deferred()
+
+            group = KeyGroup(server_name, group_id, key_ids)
+
+            group_id_to_group[group_id] = group
+            group_id_to_json[group_id] = json_object
+
+        @defer.inlineCallbacks
+        def handle_key_deferred(group, deferred):
+            server_name = group.server_name
+            try:
+                _, _, key_id, verify_key = yield deferred
+            except IOError as e:
+                logger.warn(
+                    "Got IOError when downloading keys for %s: %s %s",
+                    server_name, type(e).__name__, str(e.message),
+                )
+                raise SynapseError(
+                    502,
+                    "Error downloading keys for %s" % (server_name,),
+                    Codes.UNAUTHORIZED,
+                )
+            except Exception as e:
+                logger.exception(
+                    "Got Exception when downloading keys for %s: %s %s",
+                    server_name, type(e).__name__, str(e.message),
+                )
+                raise SynapseError(
+                    401,
+                    "No key for %s with id %s" % (server_name, key_ids),
+                    Codes.UNAUTHORIZED,
+                )
+
+            json_object = group_id_to_json[group.group_id]
+
+            try:
+                verify_signed_json(json_object, server_name, verify_key)
+            except:
+                raise SynapseError(
+                    401,
+                    "Invalid signature for server %s with key %s:%s" % (
+                        server_name, verify_key.alg, verify_key.version
+                    ),
+                    Codes.UNAUTHORIZED,
+                )
+
+        server_to_deferred = {
+            server_name: defer.Deferred()
+            for server_name, _ in server_and_json
+        }
+
+        # We want to wait for any previous lookups to complete before
+        # proceeding.
+        wait_on_deferred = self.wait_for_previous_lookups(
+            [server_name for server_name, _ in server_and_json],
+            server_to_deferred,
+        )
+
+        # Actually start fetching keys.
+        wait_on_deferred.addBoth(
+            lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+        )
+
+        # When we've finished fetching all the keys for a given server_name,
+        # resolve the deferred passed to `wait_for_previous_lookups` so that
+        # any lookups waiting will proceed.
+        server_to_gids = {}
+
+        def remove_deferreds(res, server_name, group_id):
+            server_to_gids[server_name].discard(group_id)
+            if not server_to_gids[server_name]:
+                server_to_deferred.pop(server_name).callback(None)
+            return res
+
+        for g_id, deferred in deferreds.items():
+            server_name = group_id_to_group[g_id].server_name
+            server_to_gids.setdefault(server_name, set()).add(g_id)
+            deferred.addBoth(remove_deferreds, server_name, g_id)
+
+        # Pass those keys to handle_key_deferred so that the json object
+        # signatures can be verified
+        return [
+            handle_key_deferred(
+                group_id_to_group[g_id],
+                deferreds[g_id],
             )
+            for g_id in group_ids
+        ]
 
     @defer.inlineCallbacks
-    def get_server_verify_key(self, server_name, key_ids):
-        """Finds a verification key for the server with one of the key ids.
-        Trys to fetch the key from a trusted perspective server first.
+    def wait_for_previous_lookups(self, server_names, server_to_deferred):
+        """Waits for any previous key lookups for the given servers to finish.
+
         Args:
-            server_name(str): The name of the server to fetch a key for.
-            keys_ids (list of str): The key_ids to check for.
+            server_names (list): list of server_names we want to lookup
+            server_to_deferred (dict): server_name to deferred which gets
+                resolved once we've finished looking up keys for that server
+        """
+        while True:
+            wait_on = [
+                self.key_downloads[server_name]
+                for server_name in server_names
+                if server_name in self.key_downloads
+            ]
+            if wait_on:
+                yield defer.DeferredList(wait_on)
+            else:
+                break
+
+        for server_name, deferred in server_to_deferred:
+            self.key_downloads[server_name] = ObservableDeferred(deferred)
+
+    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
+        """Takes a dict of KeyGroups and tries to find at least one key for
+        each group.
         """
-        cached = yield self.store.get_server_verify_keys(server_name, key_ids)
 
-        if cached:
-            defer.returnValue(cached[0])
-            return
+        # These are functions that produce keys given a list of key ids
+        key_fetch_fns = (
+            self.get_keys_from_store,  # First try the local store
+            self.get_keys_from_perspectives,  # Then try via perspectives
+            self.get_keys_from_server,  # Then try directly
+        )
+
+        @defer.inlineCallbacks
+        def do_iterations():
+            merged_results = {}
+
+            missing_keys = {
+                group.server_name: key_id
+                for group in group_id_to_group.values()
+                for key_id in group.key_ids
+            }
+
+            for fn in key_fetch_fns:
+                results = yield fn(missing_keys.items())
+                merged_results.update(results)
+
+                # We now need to figure out which groups we have keys for
+                # and which we don't
+                missing_groups = {}
+                for group in group_id_to_group.values():
+                    for key_id in group.key_ids:
+                        if key_id in merged_results[group.server_name]:
+                            group_id_to_deferred[group.group_id].callback((
+                                group.group_id,
+                                group.server_name,
+                                key_id,
+                                merged_results[group.server_name][key_id],
+                            ))
+                            break
+                    else:
+                        missing_groups.setdefault(
+                            group.server_name, []
+                        ).append(group)
+
+                if not missing_groups:
+                    break
+
+                missing_keys = {
+                    server_name: set(
+                        key_id for group in groups for key_id in group.key_ids
+                    )
+                    for server_name, groups in missing_groups.items()
+                }
 
-        download = self.key_downloads.get(server_name)
+            for group in missing_groups.values():
+                group_id_to_deferred[group.group_id].errback(SynapseError(
+                    401,
+                    "No key for %s with id %s" % (
+                        group.server_name, group.key_ids,
+                    ),
+                    Codes.UNAUTHORIZED,
+                ))
 
-        if download is None:
-            download = self._get_server_verify_key_impl(server_name, key_ids)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
-            )
-            self.key_downloads[server_name] = download
+        def on_err(err):
+            for deferred in group_id_to_deferred.values():
+                if not deferred.called:
+                    deferred.errback(err)
 
-            @download.addBoth
-            def callback(ret):
-                del self.key_downloads[server_name]
-                return ret
+        do_iterations().addErrback(on_err)
 
-        r = yield download.observe()
-        defer.returnValue(r)
+        return group_id_to_deferred
 
     @defer.inlineCallbacks
-    def _get_server_verify_key_impl(self, server_name, key_ids):
-        keys = None
+    def get_keys_from_store(self, server_name_and_key_ids):
+        res = yield defer.gatherResults(
+            [
+                self.store.get_server_verify_keys(server_name, key_ids)
+                for server_name, key_ids in server_name_and_key_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        defer.returnValue(dict(zip(
+            [server_name for server_name, _ in server_name_and_key_ids],
+            res
+        )))
 
+    @defer.inlineCallbacks
+    def get_keys_from_perspectives(self, server_name_and_key_ids):
         @defer.inlineCallbacks
         def get_key(perspective_name, perspective_keys):
             try:
                 result = yield self.get_server_verify_key_v2_indirect(
-                    server_name, key_ids, perspective_name, perspective_keys
+                    server_name_and_key_ids, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
             except Exception as e:
-                logging.info(
-                    "Unable to getting key %r for %r from %r: %s %s",
-                    key_ids, server_name, perspective_name,
+                logger.exception(
+                    "Unable to get key from %r: %s %s",
+                    perspective_name,
                     type(e).__name__, str(e.message),
                 )
+                defer.returnValue({})
 
-        perspective_results = yield defer.gatherResults([
-            get_key(p_name, p_keys)
-            for p_name, p_keys in self.perspective_servers.items()
-        ])
+        results = yield defer.gatherResults(
+            [
+                get_key(p_name, p_keys)
+                for p_name, p_keys in self.perspective_servers.items()
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
-        for results in perspective_results:
-            if results is not None:
-                keys = results
+        union_of_keys = {}
+        for result in results:
+            for server_name, keys in result.items():
+                union_of_keys.setdefault(server_name, {}).update(keys)
 
-        limiter = yield get_retry_limiter(
-            server_name,
-            self.clock,
-            self.store,
-        )
+        defer.returnValue(union_of_keys)
 
-        with limiter:
-            if not keys:
+    @defer.inlineCallbacks
+    def get_keys_from_server(self, server_name_and_key_ids):
+        @defer.inlineCallbacks
+        def get_key(server_name, key_ids):
+            limiter = yield get_retry_limiter(
+                server_name,
+                self.clock,
+                self.store,
+            )
+            with limiter:
+                keys = None
                 try:
                     keys = yield self.get_server_verify_key_v2_direct(
                         server_name, key_ids
                     )
                 except Exception as e:
-                    logging.info(
+                    logger.info(
                         "Unable to getting key %r for %r directly: %s %s",
                         key_ids, server_name,
                         type(e).__name__, str(e.message),
                     )
 
-            if not keys:
-                keys = yield self.get_server_verify_key_v1_direct(
-                    server_name, key_ids
-                )
+                if not keys:
+                    keys = yield self.get_server_verify_key_v1_direct(
+                        server_name, key_ids
+                    )
+
+                    keys = {server_name: keys}
+
+            defer.returnValue(keys)
+
+        results = yield defer.gatherResults(
+            [
+                get_key(server_name, key_ids)
+                for server_name, key_ids in server_name_and_key_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
-        for key_id in key_ids:
-            if key_id in keys:
-                defer.returnValue(keys[key_id])
-                return
-        raise ValueError("No verification key found for given key ids")
+        merged = {}
+        for result in results:
+            merged.update(result)
+
+        defer.returnValue({
+            server_name: keys
+            for server_name, keys in merged.items()
+            if keys
+        })
 
     @defer.inlineCallbacks
-    def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+    def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
                                           perspective_name,
                                           perspective_keys):
         limiter = yield get_retry_limiter(
@@ -204,6 +393,7 @@ class Keyring(object):
                                 u"minimum_valid_until_ts": 0
                             } for key_id in key_ids
                         }
+                        for server_name, key_ids in server_names_and_key_ids
                     }
                 },
             )
@@ -243,23 +433,29 @@ class Keyring(object):
                     " server %r" % (perspective_name,)
                 )
 
-            response_keys = yield self.process_v2_response(
-                server_name, perspective_name, response
+            processed_response = yield self.process_v2_response(
+                perspective_name, response
             )
 
-            keys.update(response_keys)
+            for server_name, response_keys in processed_response.items():
+                keys.setdefault(server_name, {}).update(response_keys)
 
-        yield self.store_keys(
-            server_name=server_name,
-            from_server=perspective_name,
-            verify_keys=keys,
-        )
+        yield defer.gatherResults(
+            [
+                self.store_keys(
+                    server_name=server_name,
+                    from_server=perspective_name,
+                    verify_keys=response_keys,
+                )
+                for server_name, response_keys in keys.items()
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(keys)
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
-
         keys = {}
 
         for requested_key_id in key_ids:
@@ -295,25 +491,30 @@ class Keyring(object):
                 raise ValueError("TLS certificate not allowed by fingerprints")
 
             response_keys = yield self.process_v2_response(
-                server_name=server_name,
                 from_server=server_name,
-                requested_id=requested_key_id,
+                requested_ids=[requested_key_id],
                 response_json=response,
             )
 
             keys.update(response_keys)
 
-        yield self.store_keys(
-            server_name=server_name,
-            from_server=server_name,
-            verify_keys=keys,
-        )
+        yield defer.gatherResults(
+            [
+                self.store_keys(
+                    server_name=key_server_name,
+                    from_server=server_name,
+                    verify_keys=verify_keys,
+                )
+                for key_server_name, verify_keys in keys.items()
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(keys)
 
     @defer.inlineCallbacks
-    def process_v2_response(self, server_name, from_server, response_json,
-                            requested_id=None):
+    def process_v2_response(self, from_server, response_json,
+                            requested_ids=[]):
         time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
@@ -335,6 +536,8 @@ class Keyring(object):
                 verify_key.time_added = time_now_ms
                 old_verify_keys[key_id] = verify_key
 
+        results = {}
+        server_name = response_json["server_name"]
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
                 raise ValueError(
@@ -357,28 +560,31 @@ class Keyring(object):
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
         ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
 
-        updated_key_ids = set()
-        if requested_id is not None:
-            updated_key_ids.add(requested_id)
+        updated_key_ids = set(requested_ids)
         updated_key_ids.update(verify_keys)
         updated_key_ids.update(old_verify_keys)
 
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
 
-        for key_id in updated_key_ids:
-            yield self.store.store_server_keys_json(
-                server_name=server_name,
-                key_id=key_id,
-                from_server=server_name,
-                ts_now_ms=time_now_ms,
-                ts_expires_ms=ts_valid_until_ms,
-                key_json_bytes=signed_key_json_bytes,
-            )
+        yield defer.gatherResults(
+            [
+                self.store.store_server_keys_json(
+                    server_name=server_name,
+                    key_id=key_id,
+                    from_server=server_name,
+                    ts_now_ms=time_now_ms,
+                    ts_expires_ms=ts_valid_until_ms,
+                    key_json_bytes=signed_key_json_bytes,
+                )
+                for key_id in updated_key_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
-        defer.returnValue(response_keys)
+        results[server_name] = response_keys
 
-        raise ValueError("No verification key found for given key ids")
+        defer.returnValue(results)
 
     @defer.inlineCallbacks
     def get_server_verify_key_v1_direct(self, server_name, key_ids):
@@ -462,8 +668,13 @@ class Keyring(object):
         Returns:
             A deferred that completes when the keys are stored.
         """
-        for key_id, key in verify_keys.items():
-            # TODO(markjh): Store whether the keys have expired.
-            yield self.store.store_server_verify_key(
-                server_name, server_name, key.time_added, key
-            )
+        # TODO(markjh): Store whether the keys have expired.
+        yield defer.gatherResults(
+            [
+                self.store.store_server_verify_key(
+                    server_name, server_name, key.time_added, key
+                )
+                for key_id, key in verify_keys.items()
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 1aa952150e..7bd78343f0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -74,6 +74,8 @@ def prune_event(event):
         )
     elif event_type == EventTypes.Aliases:
         add_fields("aliases")
+    elif event_type == EventTypes.RoomHistoryVisibility:
+        add_fields("history_visibility")
 
     allowed_fields = {
         k: v
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 299493af91..bdfa247604 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
 
 class FederationBase(object):
     @defer.inlineCallbacks
-    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
+    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
+                                       include_none=False):
         """Takes a list of PDUs and checks the signatures and hashs of each
         one. If a PDU fails its signature check then we check if we have it in
         the database and if not then request if from the originating server of
@@ -50,84 +51,108 @@ class FederationBase(object):
         Returns:
             Deferred : A list of PDUs that have valid signatures and hashes.
         """
+        deferreds = self._check_sigs_and_hashes(pdus)
 
-        signed_pdus = []
+        def callback(pdu):
+            return pdu
 
-        @defer.inlineCallbacks
-        def do(pdu):
-            try:
-                new_pdu = yield self._check_sigs_and_hash(pdu)
-                signed_pdus.append(new_pdu)
-            except SynapseError:
-                # FIXME: We should handle signature failures more gracefully.
+        def errback(failure, pdu):
+            failure.trap(SynapseError)
+            return None
 
+        def try_local_db(res, pdu):
+            if not res:
                 # Check local db.
-                new_pdu = yield self.store.get_event(
+                return self.store.get_event(
                     pdu.event_id,
                     allow_rejected=True,
                     allow_none=True,
                 )
-                if new_pdu:
-                    signed_pdus.append(new_pdu)
-                    return
-
-                # Check pdu.origin
-                if pdu.origin != origin:
-                    try:
-                        new_pdu = yield self.get_pdu(
-                            destinations=[pdu.origin],
-                            event_id=pdu.event_id,
-                            outlier=outlier,
-                            timeout=10000,
-                        )
-
-                        if new_pdu:
-                            signed_pdus.append(new_pdu)
-                            return
-                    except:
-                        pass
-
+            return res
+
+        def try_remote(res, pdu):
+            if not res and pdu.origin != origin:
+                return self.get_pdu(
+                    destinations=[pdu.origin],
+                    event_id=pdu.event_id,
+                    outlier=outlier,
+                    timeout=10000,
+                ).addErrback(lambda e: None)
+            return res
+
+        def warn(res, pdu):
+            if not res:
                 logger.warn(
                     "Failed to find copy of %s with valid signature",
                     pdu.event_id,
                 )
+            return res
+
+        for pdu, deferred in zip(pdus, deferreds):
+            deferred.addCallbacks(
+                callback, errback, errbackArgs=[pdu]
+            ).addCallback(
+                try_local_db, pdu
+            ).addCallback(
+                try_remote, pdu
+            ).addCallback(
+                warn, pdu
+            )
 
-        yield defer.gatherResults(
-            [do(pdu) for pdu in pdus],
+        valid_pdus = yield defer.gatherResults(
+            deferreds,
             consumeErrors=True
         ).addErrback(unwrapFirstError)
 
-        defer.returnValue(signed_pdus)
+        if include_none:
+            defer.returnValue(valid_pdus)
+        else:
+            defer.returnValue([p for p in valid_pdus if p])
 
-    @defer.inlineCallbacks
     def _check_sigs_and_hash(self, pdu):
-        """Throws a SynapseError if the PDU does not have the correct
+        return self._check_sigs_and_hashes([pdu])[0]
+
+    def _check_sigs_and_hashes(self, pdus):
+        """Throws a SynapseError if a PDU does not have the correct
         signatures.
 
         Returns:
             FrozenEvent: Either the given event or it redacted if it failed the
             content hash check.
         """
-        # Check signatures are correct.
-        redacted_event = prune_event(pdu)
-        redacted_pdu_json = redacted_event.get_pdu_json()
 
-        try:
-            yield self.keyring.verify_json_for_server(
-                pdu.origin, redacted_pdu_json
-            )
-        except SynapseError:
+        redacted_pdus = [
+            prune_event(pdu)
+            for pdu in pdus
+        ]
+
+        deferreds = self.keyring.verify_json_objects_for_server([
+            (p.origin, p.get_pdu_json())
+            for p in redacted_pdus
+        ])
+
+        def callback(_, pdu, redacted):
+            if not check_event_content_hash(pdu):
+                logger.warn(
+                    "Event content has been tampered, redacting %s: %s",
+                    pdu.event_id, pdu.get_pdu_json()
+                )
+                return redacted
+            return pdu
+
+        def errback(failure, pdu):
+            failure.trap(SynapseError)
             logger.warn(
                 "Signature check failed for %s",
                 pdu.event_id,
             )
-            raise
+            return failure
 
-        if not check_event_content_hash(pdu):
-            logger.warn(
-                "Event content has been tampered, redacting.",
-                pdu.event_id,
+        for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+            deferred.addCallbacks(
+                callback, errback,
+                callbackArgs=[pdu, redacted],
+                errbackArgs=[pdu],
             )
-            defer.returnValue(redacted_event)
 
-        defer.returnValue(pdu)
+        return deferreds
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d3b46b24c1..7736d14fb5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -30,6 +30,7 @@ import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
 
+import copy
 import itertools
 import logging
 import random
@@ -167,7 +168,7 @@ class FederationClient(FederationBase):
 
         # FIXME: We should handle signature failures more gracefully.
         pdus[:] = yield defer.gatherResults(
-            [self._check_sigs_and_hash(pdu) for pdu in pdus],
+            self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
         ).addErrback(unwrapFirstError)
 
@@ -230,7 +231,7 @@ class FederationClient(FederationBase):
                         pdu = pdu_list[0]
 
                         # Check signatures are correct.
-                        pdu = yield self._check_sigs_and_hash(pdu)
+                        pdu = yield self._check_sigs_and_hashes([pdu])[0]
 
                         break
 
@@ -327,6 +328,9 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     def make_join(self, destinations, room_id, user_id):
         for destination in destinations:
+            if destination == self.server_name:
+                continue
+
             try:
                 ret = yield self.transport_layer.make_join(
                     destination, room_id, user_id
@@ -353,6 +357,9 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     def send_join(self, destinations, pdu):
         for destination in destinations:
+            if destination == self.server_name:
+                continue
+
             try:
                 time_now = self._clock.time_msec()
                 _, content = yield self.transport_layer.send_join(
@@ -374,17 +381,39 @@ class FederationClient(FederationBase):
                     for p in content.get("auth_chain", [])
                 ]
 
-                signed_state, signed_auth = yield defer.gatherResults(
-                    [
-                        self._check_sigs_and_hash_and_fetch(
-                            destination, state, outlier=True
-                        ),
-                        self._check_sigs_and_hash_and_fetch(
-                            destination, auth_chain, outlier=True
-                        )
-                    ],
-                    consumeErrors=True
-                ).addErrback(unwrapFirstError)
+                pdus = {
+                    p.event_id: p
+                    for p in itertools.chain(state, auth_chain)
+                }
+
+                valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+                    destination, pdus.values(),
+                    outlier=True,
+                )
+
+                valid_pdus_map = {
+                    p.event_id: p
+                    for p in valid_pdus
+                }
+
+                # NB: We *need* to copy to ensure that we don't have multiple
+                # references being passed on, as that causes... issues.
+                signed_state = [
+                    copy.copy(valid_pdus_map[p.event_id])
+                    for p in state
+                    if p.event_id in valid_pdus_map
+                ]
+
+                signed_auth = [
+                    valid_pdus_map[p.event_id]
+                    for p in auth_chain
+                    if p.event_id in valid_pdus_map
+                ]
+
+                # NB: We *need* to copy to ensure that we don't have multiple
+                # references being passed on, as that causes... issues.
+                for s in signed_state:
+                    s.internal_metadata = copy.deepcopy(s.internal_metadata)
 
                 auth_chain.sort(key=lambda e: e.depth)
 
@@ -396,7 +425,7 @@ class FederationClient(FederationBase):
             except CodeMessageException:
                 raise
             except Exception as e:
-                logger.warn(
+                logger.exception(
                     "Failed to send_join via %s: %s",
                     destination, e.message
                 )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b5d882fd65..d7f197f247 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
 )
 from synapse.types import UserID
 
+from synapse.events.utils import prune_event
+
 from synapse.util.retryutils import NotRetryingDestination
 
 from twisted.internet import defer
@@ -222,6 +224,56 @@ class FederationHandler(BaseHandler):
                     "user_joined_room", user=user, room_id=event.room_id
                 )
 
+    @defer.inlineCallbacks
+    def _filter_events_for_server(self, server_name, room_id, events):
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def redact_disallowed(event_and_state):
+            event, state = event_and_state
+
+            if not state:
+                return event
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history:
+                visibility = history.content.get("history_visibility", "shared")
+                if visibility in ["invited", "joined"]:
+                    # We now loop through all state events looking for
+                    # membership states for the requesting server to determine
+                    # if the server is either in the room or has been invited
+                    # into the room.
+                    for ev in state.values():
+                        if ev.type != EventTypes.Member:
+                            continue
+                        try:
+                            domain = UserID.from_string(ev.state_key).domain
+                        except:
+                            continue
+
+                        if domain != server_name:
+                            continue
+
+                        memtype = ev.membership
+                        if memtype == Membership.JOIN:
+                            return event
+                        elif memtype == Membership.INVITE:
+                            if visibility == "invited":
+                                return event
+                    else:
+                        return prune_event(event)
+
+            return event
+
+        res = map(redact_disallowed, events_and_states)
+
+        logger.info("_filter_events_for_server %r", res)
+
+        defer.returnValue(res)
+
     @log_function
     @defer.inlineCallbacks
     def backfill(self, dest, room_id, limit, extremities=[]):
@@ -882,6 +934,8 @@ class FederationHandler(BaseHandler):
             limit
         )
 
+        events = yield self._filter_events_for_server(origin, room_id, events)
+
         defer.returnValue(events)
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7c1d6b5489..9d6d4f0978 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
             "room_key", next_key
         )
 
+        if not events:
+            defer.returnValue({
+                "chunk": [],
+                "start": pagin_config.from_token.to_string(),
+                "end": next_token.to_string(),
+            })
+
+        events = yield self._filter_events_for_client(user_id, room_id, events)
+
         time_now = self.clock.time_msec()
 
         chunk = {
             "chunk": [
-                serialize_event(e, time_now, as_client_event) for e in events
+                serialize_event(e, time_now, as_client_event)
+                for e in events
             ],
             "start": pagin_config.from_token.to_string(),
             "end": next_token.to_string(),
@@ -126,6 +136,52 @@ class MessageHandler(BaseHandler):
         defer.returnValue(chunk)
 
     @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, room_id, events):
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def allowed(event_and_state):
+            event, state = event_and_state
+
+            if event.type == EventTypes.RoomHistoryVisibility:
+                return True
+
+            membership_ev = state.get((EventTypes.Member, user_id), None)
+            if membership_ev:
+                membership = membership_ev.membership
+            else:
+                membership = Membership.LEAVE
+
+            if membership == Membership.JOIN:
+                return True
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history:
+                visibility = history.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
+
+            if visibility == "public":
+                return True
+            elif visibility == "shared":
+                return True
+            elif visibility == "joined":
+                return membership == Membership.JOIN
+            elif visibility == "invited":
+                return membership == Membership.INVITE
+
+            return True
+
+        events_and_states = filter(allowed, events_and_states)
+        defer.returnValue([
+            ev
+            for ev, _ in events_and_states
+        ])
+
+    @defer.inlineCallbacks
     def create_and_send_event(self, event_dict, ratelimit=True,
                               client=None, txn_id=None):
         """ Given a dict from a client, create and handle a new event.
@@ -321,6 +377,10 @@ class MessageHandler(BaseHandler):
                     ]
                 ).addErrback(unwrapFirstError)
 
+                messages = yield self._filter_events_for_client(
+                    user_id, event.room_id, messages
+                )
+
                 start_token = now_token.copy_and_replace("room_key", token[0])
                 end_token = now_token.copy_and_replace("room_key", token[1])
                 time_now = self.clock.time_msec()
@@ -426,6 +486,10 @@ class MessageHandler(BaseHandler):
             consumeErrors=True,
         ).addErrback(unwrapFirstError)
 
+        messages = yield self._filter_events_for_client(
+            user_id, room_id, messages
+        )
+
         start_token = now_token.copy_and_replace("room_key", token[0])
         end_token = now_token.copy_and_replace("room_key", token[1])
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7b68585a17..a1288b4252 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -193,6 +193,35 @@ class RegistrationHandler(BaseHandler):
             logger.info("Valid captcha entered from %s", ip)
 
     @defer.inlineCallbacks
+    def register_saml2(self, localpart):
+        """
+        Registers email_id as SAML2 Based Auth.
+        """
+        if urllib.quote(localpart) != localpart:
+            raise SynapseError(
+                400,
+                "User ID must only contain characters which do not"
+                " require URL encoding."
+                )
+        user = UserID(localpart, self.hs.hostname)
+        user_id = user.to_string()
+
+        yield self.check_user_id_is_valid(user_id)
+        token = self._generate_token(user_id)
+        try:
+            yield self.store.register(
+                user_id=user_id,
+                token=token,
+                password_hash=None
+            )
+            yield self.distributor.fire("registered_user", user)
+        except Exception, e:
+            yield self.store.add_access_token_to_user(user_id, token)
+            # Ignore Registration errors
+            logger.exception(e)
+        defer.returnValue((user_id, token))
+
+    @defer.inlineCallbacks
     def register_email(self, threepidCreds):
         """
         Registers emails with an identity server.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 4bd027d9bb..891707df44 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -213,6 +213,7 @@ class RoomCreationHandler(BaseHandler):
                 "events": {
                     EventTypes.Name: 100,
                     EventTypes.PowerLevels: 100,
+                    EventTypes.RoomHistoryVisibility: 100,
                 },
                 "events_default": 0,
                 "state_default": 50,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bd8c603681..6cff6230c1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -293,6 +293,51 @@ class SyncHandler(BaseHandler):
         ))
 
     @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, room_id, events):
+        states = yield self.store.get_state_for_events(
+            room_id, [e.event_id for e in events],
+        )
+
+        events_and_states = zip(events, states)
+
+        def allowed(event_and_state):
+            event, state = event_and_state
+
+            if event.type == EventTypes.RoomHistoryVisibility:
+                return True
+
+            membership_ev = state.get((EventTypes.Member, user_id), None)
+            if membership_ev:
+                membership = membership_ev.membership
+            else:
+                membership = Membership.LEAVE
+
+            if membership == Membership.JOIN:
+                return True
+
+            history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+            if history:
+                visibility = history.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
+
+            if visibility == "public":
+                return True
+            elif visibility == "shared":
+                return True
+            elif visibility == "joined":
+                return membership == Membership.JOIN
+            elif visibility == "invited":
+                return membership == Membership.INVITE
+
+            return True
+        events_and_states = filter(allowed, events_and_states)
+        defer.returnValue([
+            ev
+            for ev, _ in events_and_states
+        ])
+
+    @defer.inlineCallbacks
     def load_filtered_recents(self, room_id, sync_config, now_token,
                               since_token=None):
         limited = True
@@ -313,6 +358,9 @@ class SyncHandler(BaseHandler):
             (room_key, _) = keys
             end_key = "s" + room_key.split('-')[-1]
             loaded_recents = sync_config.filter.filter_room_events(events)
+            loaded_recents = yield self._filter_events_for_client(
+                sync_config.user.to_string(), room_id, loaded_recents,
+            )
             loaded_recents.extend(recents)
             recents = loaded_recents
             if len(events) <= load_limit:
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index daaa61b5f2..115bee8c41 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -32,6 +32,7 @@ REQUIREMENTS = {
     "pydenticon": ["pydenticon"],
     "ujson": ["ujson"],
     "blist": ["blist"],
+    "pysaml2": ["saml2"],
 }
 CONDITIONAL_REQUIREMENTS = {
     "web_client": {
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index b2257b749d..998d4d44c6 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -20,14 +20,32 @@ from synapse.types import UserID
 from base import ClientV1RestServlet, client_path_pattern
 
 import simplejson as json
+import urllib
+
+import logging
+from saml2 import BINDING_HTTP_POST
+from saml2 import config
+from saml2.client import Saml2Client
+
+
+logger = logging.getLogger(__name__)
 
 
 class LoginRestServlet(ClientV1RestServlet):
     PATTERN = client_path_pattern("/login$")
     PASS_TYPE = "m.login.password"
+    SAML2_TYPE = "m.login.saml2"
+
+    def __init__(self, hs):
+        super(LoginRestServlet, self).__init__(hs)
+        self.idp_redirect_url = hs.config.saml2_idp_redirect_url
+        self.saml2_enabled = hs.config.saml2_enabled
 
     def on_GET(self, request):
-        return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
+        flows = [{"type": LoginRestServlet.PASS_TYPE}]
+        if self.saml2_enabled:
+            flows.append({"type": LoginRestServlet.SAML2_TYPE})
+        return (200, {"flows": flows})
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
             if login_submission["type"] == LoginRestServlet.PASS_TYPE:
                 result = yield self.do_password_login(login_submission)
                 defer.returnValue(result)
+            elif self.saml2_enabled and (login_submission["type"] ==
+                                         LoginRestServlet.SAML2_TYPE):
+                relay_state = ""
+                if "relay_state" in login_submission:
+                    relay_state = "&RelayState="+urllib.quote(
+                                  login_submission["relay_state"])
+                result = {
+                    "uri": "%s%s" % (self.idp_redirect_url, relay_state)
+                }
+                defer.returnValue((200, result))
             else:
                 raise SynapseError(400, "Bad login type.")
         except KeyError:
@@ -94,6 +122,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
             )
 
 
+class SAML2RestServlet(ClientV1RestServlet):
+    PATTERN = client_path_pattern("/login/saml2")
+
+    def __init__(self, hs):
+        super(SAML2RestServlet, self).__init__(hs)
+        self.sp_config = hs.config.saml2_config_path
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        saml2_auth = None
+        try:
+            conf = config.SPConfig()
+            conf.load_file(self.sp_config)
+            SP = Saml2Client(conf)
+            saml2_auth = SP.parse_authn_request_response(
+                request.args['SAMLResponse'][0], BINDING_HTTP_POST)
+        except Exception, e:        # Not authenticated
+            logger.exception(e)
+        if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
+            username = saml2_auth.name_id.text
+            handler = self.handlers.registration_handler
+            (user_id, token) = yield handler.register_saml2(username)
+            # Forward to the RelayState callback along with ava
+            if 'RelayState' in request.args:
+                request.redirect(urllib.unquote(
+                                 request.args['RelayState'][0]) +
+                                 '?status=authenticated&access_token=' +
+                                 token + '&user_id=' + user_id + '&ava=' +
+                                 urllib.quote(json.dumps(saml2_auth.ava)))
+                request.finish()
+                defer.returnValue(None)
+            defer.returnValue((200, {"status": "authenticated",
+                                     "user_id": user_id, "token": token,
+                                     "ava": saml2_auth.ava}))
+        elif 'RelayState' in request.args:
+            request.redirect(urllib.unquote(
+                             request.args['RelayState'][0]) +
+                             '?status=not_authenticated')
+            request.finish()
+            defer.returnValue(None)
+        defer.returnValue((200, {"status": "not_authenticated"}))
+
+
 def _parse_json(request):
     try:
         content = json.loads(request.content.read())
@@ -106,4 +177,6 @@ def _parse_json(request):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    if hs.config.saml2_enabled:
+        SAML2RestServlet(hs).register(http_server)
     # TODO PasswordResetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 231e1dd97a..33f961e898 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -20,6 +20,7 @@ from . import (
     register,
     auth,
     receipts,
+    keys,
 )
 
 from synapse.http.server import JsonResource
@@ -40,3 +41,4 @@ class ClientV2AlphaRestResource(JsonResource):
         register.register_servlets(hs, client_resource)
         auth.register_servlets(hs, client_resource)
         receipts.register_servlets(hs, client_resource)
+        keys.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
new file mode 100644
index 0000000000..f031267751
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -0,0 +1,276 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from syutil.jsonutil import encode_canonical_json
+
+from ._base import client_v2_pattern
+
+import simplejson as json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class KeyUploadServlet(RestServlet):
+    """
+    POST /keys/upload/<device_id> HTTP/1.1
+    Content-Type: application/json
+
+    {
+      "device_keys": {
+        "user_id": "<user_id>",
+        "device_id": "<device_id>",
+        "valid_until_ts": <millisecond_timestamp>,
+        "algorithms": [
+          "m.olm.curve25519-aes-sha256",
+        ]
+        "keys": {
+          "<algorithm>:<device_id>": "<key_base64>",
+        },
+        "signatures:" {
+          "<user_id>" {
+            "<algorithm>:<device_id>": "<signature_base64>"
+      } } },
+      "one_time_keys": {
+        "<algorithm>:<key_id>": "<key_base64>"
+      },
+    }
+    """
+    PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
+
+    def __init__(self, hs):
+        super(KeyUploadServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        user_id = auth_user.to_string()
+        # TODO: Check that the device_id matches that in the authentication
+        # or derive the device_id from the authentication instead.
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        time_now = self.clock.time_msec()
+
+        # TODO: Validate the JSON to make sure it has the right keys.
+        device_keys = body.get("device_keys", None)
+        if device_keys:
+            logger.info(
+                "Updating device_keys for device %r for user %r at %d",
+                device_id, auth_user, time_now
+            )
+            # TODO: Sign the JSON with the server key
+            yield self.store.set_e2e_device_keys(
+                user_id, device_id, time_now,
+                encode_canonical_json(device_keys)
+            )
+
+        one_time_keys = body.get("one_time_keys", None)
+        if one_time_keys:
+            logger.info(
+                "Adding %d one_time_keys for device %r for user %r at %d",
+                len(one_time_keys), device_id, user_id, time_now
+            )
+            key_list = []
+            for key_id, key_json in one_time_keys.items():
+                algorithm, key_id = key_id.split(":")
+                key_list.append((
+                    algorithm, key_id, encode_canonical_json(key_json)
+                ))
+
+            yield self.store.add_e2e_one_time_keys(
+                user_id, device_id, time_now, key_list
+            )
+
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+        defer.returnValue((200, {"one_time_key_counts": result}))
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        user_id = auth_user.to_string()
+
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+        defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class KeyQueryServlet(RestServlet):
+    """
+    GET /keys/query/<user_id> HTTP/1.1
+
+    GET /keys/query/<user_id>/<device_id> HTTP/1.1
+
+    POST /keys/query HTTP/1.1
+    Content-Type: application/json
+    {
+      "device_keys": {
+        "<user_id>": ["<device_id>"]
+    } }
+
+    HTTP/1.1 200 OK
+    {
+      "device_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "user_id": "<user_id>", // Duplicated to be signed
+            "device_id": "<device_id>", // Duplicated to be signed
+            "valid_until_ts": <millisecond_timestamp>,
+            "algorithms": [ // List of supported algorithms
+              "m.olm.curve25519-aes-sha256",
+            ],
+            "keys": { // Must include a ed25519 signing key
+              "<algorithm>:<key_id>": "<key_base64>",
+            },
+            "signatures:" {
+              // Must be signed with device's ed25519 key
+              "<user_id>/<device_id>": {
+                "<algorithm>:<key_id>": "<signature_base64>"
+              }
+              // Must be signed by this server.
+              "<server_name>": {
+                "<algorithm>:<key_id>": "<signature_base64>"
+    } } } } } }
+    """
+
+    PATTERN = client_v2_pattern(
+        "/keys/query(?:"
+        "/(?P<user_id>[^/]*)(?:"
+        "/(?P<device_id>[^/]*)"
+        ")?"
+        ")?"
+    )
+
+    def __init__(self, hs):
+        super(KeyQueryServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id, device_id):
+        logger.debug("onPOST")
+        yield self.auth.get_user_by_req(request)
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        query = []
+        for user_id, device_ids in body.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
+        defer.returnValue(self.json_result(request, results))
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        auth_user_id = auth_user.to_string()
+        if not user_id:
+            user_id = auth_user_id
+        if not device_id:
+            device_id = None
+        # Returns a map of user_id->device_id->json_bytes.
+        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
+        defer.returnValue(self.json_result(request, results))
+
+    def json_result(self, request, results):
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+        return (200, {"device_keys": json_result})
+
+
+class OneTimeKeyServlet(RestServlet):
+    """
+    GET /keys/take/<user-id>/<device-id>/<algorithm> HTTP/1.1
+
+    POST /keys/take HTTP/1.1
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": "<algorithm>"
+    } } }
+
+    HTTP/1.1 200 OK
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "<algorithm>:<key_id>": "<key_base64>"
+    } } } }
+
+    """
+    PATTERN = client_v2_pattern(
+        "/keys/take(?:/?|(?:/"
+        "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
+        ")?)"
+    )
+
+    def __init__(self, hs):
+        super(OneTimeKeyServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, device_id, algorithm):
+        yield self.auth.get_user_by_req(request)
+        results = yield self.store.take_e2e_one_time_keys(
+            [(user_id, device_id, algorithm)]
+        )
+        defer.returnValue(self.json_result(request, results))
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id, device_id, algorithm):
+        yield self.auth.get_user_by_req(request)
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        query = []
+        for user_id, device_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                query.append((user_id, device_id, algorithm))
+        results = yield self.store.take_e2e_one_time_keys(query)
+        defer.returnValue(self.json_result(request, results))
+
+    def json_result(self, request, results):
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, keys in device_keys.items():
+                for key_id, json_bytes in keys.items():
+                    json_result.setdefault(user_id, {})[device_id] = {
+                        key_id: json.loads(json_bytes)
+                    }
+        return (200, {"one_time_keys": json_result})
+
+
+def register_servlets(hs, http_server):
+    KeyUploadServlet(hs).register(http_server)
+    KeyQueryServlet(hs).register(http_server)
+    OneTimeKeyServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2bc88a7954..71d5d92500 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -37,6 +37,7 @@ from .rejections import RejectionsStore
 from .state import StateStore
 from .signatures import SignatureStore
 from .filtering import FilteringStore
+from .end_to_end_keys import EndToEndKeyStore
 
 from .receipts import ReceiptsStore
 
@@ -77,6 +78,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 ApplicationServiceTransactionStore,
                 EventsStore,
                 ReceiptsStore,
+                EndToEndKeyStore,
                 ):
 
     def __init__(self, hs):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
new file mode 100644
index 0000000000..99dc864e46
--- /dev/null
+++ b/synapse/storage/end_to_end_keys.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _base import SQLBaseStore
+
+
+class EndToEndKeyStore(SQLBaseStore):
+    def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
+        return self._simple_upsert(
+            table="e2e_device_keys_json",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            values={
+                "ts_added_ms": time_now,
+                "key_json": json_bytes,
+            }
+        )
+
+    def get_e2e_device_keys(self, query_list):
+        """Fetch a list of device keys.
+        Args:
+            query_list(list): List of pairs of user_ids and device_ids.
+        Returns:
+            Dict mapping from user-id to dict mapping from device_id to
+            key json byte strings.
+        """
+        def _get_e2e_device_keys(txn):
+            result = {}
+            for user_id, device_id in query_list:
+                user_result = result.setdefault(user_id, {})
+                keyvalues = {"user_id": user_id}
+                if device_id:
+                    keyvalues["device_id"] = device_id
+                rows = self._simple_select_list_txn(
+                    txn, table="e2e_device_keys_json",
+                    keyvalues=keyvalues,
+                    retcols=["device_id", "key_json"]
+                )
+                for row in rows:
+                    user_result[row["device_id"]] = row["key_json"]
+            return result
+        return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+
+    def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+        def _add_e2e_one_time_keys(txn):
+            for (algorithm, key_id, json_bytes) in key_list:
+                self._simple_upsert_txn(
+                    txn, table="e2e_one_time_keys_json",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                        "key_id": key_id,
+                    },
+                    values={
+                        "ts_added_ms": time_now,
+                        "key_json": json_bytes,
+                    }
+                )
+        return self.runInteraction(
+            "add_e2e_one_time_keys", _add_e2e_one_time_keys
+        )
+
+    def count_e2e_one_time_keys(self, user_id, device_id):
+        """ Count the number of one time keys the server has for a device
+        Returns:
+            Dict mapping from algorithm to number of keys for that algorithm.
+        """
+        def _count_e2e_one_time_keys(txn):
+            sql = (
+                "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
+                " WHERE user_id = ? AND device_id = ?"
+                " GROUP BY algorithm"
+            )
+            txn.execute(sql, (user_id, device_id))
+            result = {}
+            for algorithm, key_count in txn.fetchall():
+                result[algorithm] = key_count
+            return result
+        return self.runInteraction(
+            "count_e2e_one_time_keys", _count_e2e_one_time_keys
+        )
+
+    def take_e2e_one_time_keys(self, query_list):
+        """Take a list of one time keys out of the database"""
+        def _take_e2e_one_time_keys(txn):
+            sql = (
+                "SELECT key_id, key_json FROM e2e_one_time_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " LIMIT 1"
+            )
+            result = {}
+            delete = []
+            for user_id, device_id, algorithm in query_list:
+                user_result = result.setdefault(user_id, {})
+                device_result = user_result.setdefault(device_id, {})
+                txn.execute(sql, (user_id, device_id, algorithm))
+                for key_id, key_json in txn.fetchall():
+                    device_result[algorithm + ":" + key_id] = key_json
+                    delete.append((user_id, device_id, algorithm, key_id))
+            sql = (
+                "DELETE FROM e2e_one_time_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " AND key_id = ?"
+            )
+            for user_id, device_id, algorithm, key_id in delete:
+                txn.execute(sql, (user_id, device_id, algorithm, key_id))
+            return result
+        return self.runInteraction(
+            "take_e2e_one_time_keys", _take_e2e_one_time_keys
+        )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 5bdf497b93..940a5f7e08 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore
+from _base import SQLBaseStore, cached
 
 from twisted.internet import defer
 
@@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
             desc="store_server_certificate",
         )
 
+    @cached()
+    @defer.inlineCallbacks
+    def get_all_server_verify_keys(self, server_name):
+        rows = yield self._simple_select_list(
+            table="server_signature_keys",
+            keyvalues={
+                "server_name": server_name,
+            },
+            retcols=["key_id", "verify_key"],
+            desc="get_all_server_verify_keys",
+        )
+
+        defer.returnValue({
+            row["key_id"]: decode_verify_key_bytes(
+                row["key_id"], str(row["verify_key"])
+            )
+            for row in rows
+        })
+
     @defer.inlineCallbacks
     def get_server_verify_keys(self, server_name, key_ids):
         """Retrieve the NACL verification key for a given server for the given
@@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
         Returns:
             (list of VerifyKey): The verification keys.
         """
-        sql = (
-            "SELECT key_id, verify_key FROM server_signature_keys"
-            " WHERE server_name = ?"
-            " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
-        )
-
-        rows = yield self._execute_and_decode(
-            "get_server_verify_keys", sql, server_name, *key_ids
-        )
-
-        keys = []
-        for row in rows:
-            key_id = row["key_id"]
-            key_bytes = row["verify_key"]
-            key = decode_verify_key_bytes(key_id, str(key_bytes))
-            keys.append(key)
-        defer.returnValue(keys)
+        keys = yield self.get_all_server_verify_keys(server_name)
+        defer.returnValue({
+            k: keys[k]
+            for k in key_ids
+            if k in keys and keys[k]
+        })
 
+    @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
                                 verify_key):
         """Stores a NACL verification key for the given server.
@@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
             ts_now_ms (int): The time now in milliseconds
             verification_key (VerifyKey): The NACL verify key.
         """
-        return self._simple_upsert(
+        yield self._simple_upsert(
             table="server_signature_keys",
             keyvalues={
                 "server_name": server_name,
@@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
+        self.get_all_server_verify_keys.invalidate(server_name)
+
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
         """Stores the JSON bytes for a set of keys from a server
@@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
                 "ts_valid_until_ms": ts_expires_ms,
                 "key_json": buffer(key_json_bytes),
             },
+            desc="store_server_keys_json",
         )
 
     def get_server_keys_json(self, server_keys):
diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql
new file mode 100644
index 0000000000..8b4a380d11
--- /dev/null
+++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql
@@ -0,0 +1,34 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
+    user_id TEXT NOT NULL, -- The user these keys are for.
+    device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
+    ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
+    key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
+    CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
+);
+
+
+CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
+    user_id TEXT NOT NULL, -- The user this one-time key is for.
+    device_id TEXT NOT NULL, -- The device this one-time key is for.
+    algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
+    key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+    ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
+    key_json TEXT NOT NULL, -- The key as a JSON blob.
+    CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
+);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f2b17f29ea..d7844edee3 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -92,11 +92,11 @@ class StateStore(SQLBaseStore):
         defer.returnValue(dict(state_list))
 
     @cached(num_args=1)
-    def _fetch_events_for_group(self, state_group, events):
+    def _fetch_events_for_group(self, key, events):
         return self._get_events(
             events, get_prev_content=False
         ).addCallback(
-            lambda evs: (state_group, evs)
+            lambda evs: (key, evs)
         )
 
     def _store_state_groups_txn(self, txn, event, context):
@@ -194,6 +194,65 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
+    @defer.inlineCallbacks
+    def get_state_for_events(self, room_id, event_ids):
+        def f(txn):
+            groups = set()
+            event_to_group = {}
+            for event_id in event_ids:
+                # TODO: Remove this loop.
+                group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="event_to_state_groups",
+                    keyvalues={"event_id": event_id},
+                    retcol="state_group",
+                    allow_none=True,
+                )
+                if group:
+                    event_to_group[event_id] = group
+                    groups.add(group)
+
+            group_to_state_ids = {}
+            for group in groups:
+                state_ids = self._simple_select_onecol_txn(
+                    txn,
+                    table="state_groups_state",
+                    keyvalues={"state_group": group},
+                    retcol="event_id",
+                )
+
+                group_to_state_ids[group] = state_ids
+
+            return event_to_group, group_to_state_ids
+
+        res = yield self.runInteraction(
+            "annotate_events_with_state_groups",
+            f,
+        )
+
+        event_to_group, group_to_state_ids = res
+
+        state_list = yield defer.gatherResults(
+            [
+                self._fetch_events_for_group(group, vals)
+                for group, vals in group_to_state_ids.items()
+            ],
+            consumeErrors=True,
+        )
+
+        state_dict = {
+            group: {
+                (ev.type, ev.state_key): ev
+                for ev in state
+            }
+            for group, state in state_list
+        }
+
+        defer.returnValue([
+            state_dict.get(event_to_group.get(event, None), None)
+            for event in event_ids
+        ])
+
 
 def _make_group_id(clock):
     return str(int(clock.time_msec())) + random_string(5)