summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_identity.py116
-rw-r--r--tests/handlers/test_profile.py13
-rw-r--r--tests/handlers/test_register.py21
-rw-r--r--tests/handlers/test_stats.py6
-rw-r--r--tests/handlers/test_user_directory.py135
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py2
-rw-r--r--tests/rest/client/test_identity.py145
-rw-r--r--tests/rest/client/test_retention.py2
-rw-r--r--tests/rest/client/test_room_access_rules.py727
-rw-r--r--tests/rest/client/v2_alpha/test_register.py205
-rw-r--r--tests/rulecheck/__init__.py14
-rw-r--r--tests/rulecheck/test_domainrulecheck.py334
-rw-r--r--tests/storage/test_main.py4
-rw-r--r--tests/storage/test_profile.py8
-rw-r--r--tests/test_types.py22
15 files changed, 1719 insertions, 35 deletions
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..0ab0356109
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        account.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        self.address = "test@test"
+        self.is_server_name = "testis"
+        self.is_server_url = "https://testis"
+        self.rewritten_is_url = "https://int.testis"
+
+        config = self.default_config()
+        config["trusted_third_party_id_servers"] = [self.is_server_name]
+        config["rewrite_identity_server_urls"] = {
+            self.is_server_url: self.rewritten_is_url
+        }
+
+        mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+        mock_http_client.get_json.side_effect = defer.succeed({})
+        mock_http_client.post_json_get_json.return_value = defer.succeed(
+            {"address": self.address, "medium": "email"}
+        )
+
+        self.hs = self.setup_test_homeserver(
+            config=config, simple_http_client=mock_http_client
+        )
+
+        mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+        mock_blacklisting_http_client.get_json.side_effect = defer.succeed({})
+        mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+            {"address": self.address, "medium": "email"}
+        )
+
+        # TODO: This class does not use a singleton to get it's http client
+        # This should be fixed for easier testing
+        # https://github.com/matrix-org/synapse-dinsic/issues/26
+        self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+            mock_blacklisting_http_client
+        )
+
+        return self.hs
+
+    def prepare(self, reactor, clock, hs):
+        self.user_id = self.register_user("kermit", "monkey")
+
+    def test_rewritten_id_server(self):
+        """
+        Tests that, when validating a 3PID association while rewriting the IS's server
+        name:
+        * the bind request is done against the rewritten hostname
+        * the original, non-rewritten, server name is stored in the database
+        """
+        handler = self.hs.get_handlers().identity_handler
+        post_json_get_json = handler.blacklisting_http_client.post_json_get_json
+        store = self.hs.get_datastore()
+
+        creds = {"sid": "123", "client_secret": "some_secret"}
+
+        # Make sure processing the mocked response goes through.
+        data = self.get_success(
+            handler.bind_threepid(
+                client_secret=creds["client_secret"],
+                sid=creds["sid"],
+                mxid=self.user_id,
+                id_server=self.is_server_name,
+                use_v2=False,
+            )
+        )
+        self.assertEqual(data.get("address"), self.address)
+
+        # Check that the request was done against the rewritten server name.
+        post_json_get_json.assert_called_once_with(
+            "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,),
+            {
+                "sid": creds["sid"],
+                "client_secret": creds["client_secret"],
+                "mxid": self.user_id,
+            },
+            headers={},
+        )
+
+        # Check that the original server name is saved in the database instead of the
+        # rewritten one.
+        id_servers = self.get_success(
+            store.get_id_servers_user_bound(self.user_id, "email", self.address)
+        )
+        self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c6e..a1f4bde347 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -63,14 +63,12 @@ class ProfileTestCase(unittest.TestCase):
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        yield self.store.create_profile(self.frank.localpart)
-
         self.handler = hs.get_profile_handler()
         self.hs = hs
 
     @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
 
         displayname = yield self.handler.get_displayname(self.frank)
 
@@ -109,7 +107,7 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
 
         self.assertEquals(
             (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
@@ -152,8 +150,7 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield self.store.create_profile("caroline")
-        yield self.store.set_profile_displayname("caroline", "Caroline")
+        yield self.store.set_profile_displayname("caroline", "Caroline", 1)
 
         response = yield self.query_handlers["profile"](
             {"user_id": "@caroline:test", "field": "displayname"}
@@ -164,7 +161,7 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_my_avatar(self):
         yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+            self.frank.localpart, "http://my.server/me.png", 1
         )
 
         avatar_url = yield self.handler.get_avatar_url(self.frank)
@@ -206,7 +203,7 @@ class ProfileTestCase(unittest.TestCase):
 
         # Setting displayname for the first time is allowed
         yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+            self.frank.localpart, "http://my.server/me.png", 1
         )
 
         self.assertEquals(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index ca32f993a3..2a377a4eb9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, ResourceLimitError, SynapseError
 from synapse.handlers.register import RegistrationHandler
+from synapse.rest.client.v2_alpha.register import _map_email_to_displayname
 from synapse.types import RoomAlias, UserID, create_requester
 
 from .. import unittest
@@ -266,6 +267,26 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
+    def test_email_to_displayname_mapping(self):
+        """Test that custom emails are mapped to new user displaynames correctly"""
+        self._check_mapping(
+            "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+        )
+
+        self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+        self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+        # Multibyte unicode characters
+        self._check_mapping(
+            "j\u030a\u0065an-poppy.seed@example.com",
+            "J\u030a\u0065an-Poppy Seed [Example]",
+        )
+
+    def _check_mapping(self, i, expected):
+        result = _map_email_to_displayname(i)
+        self.assertEqual(result, expected)
+
     async def get_or_create_user(
         self, requester, localpart, displayname, password_hash=None
     ):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..8e6b0b7536 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -21,8 +21,12 @@ from tests import unittest
 
 # The expected number of state events in a fresh public room.
 EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
 # The expected number of state events in a fresh private room.
-EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+#
+# Note: we increase this by 1 on the dinsic branch as we send
+# a "im.vector.room.access_rules" state event into new private rooms
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
 
 
 class StatsRoomTests(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c15bce5bef..0c5cdbd33a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
 from synapse.storage.roommember import ProfileInfo
 
 from tests import unittest
@@ -460,3 +460,136 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
         self.render(request)
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        account_validity.register_servlets,
+        synapse.rest.client.v2_alpha.user_directory.register_servlets,
+        account.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+
+        # Set accounts to expire after a week
+        config["account_validity"] = {
+            "enabled": True,
+            "period": 604800000,  # Time in ms for 1 week
+        }
+        return config
+
+    def prepare(self, reactor, clock, hs):
+        super(UserInfoTestCase, self).prepare(reactor, clock, hs)
+        self.store = hs.get_datastore()
+        self.handler = hs.get_user_directory_handler()
+
+    def test_user_info(self):
+        """Test /users/info for local users from the Client-Server API"""
+        user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+        # Request info about each user from user_three
+        request, channel = self.make_request(
+            "POST",
+            path="/_matrix/client/unstable/users/info",
+            content={"user_ids": [user_one, user_two, user_three]},
+            access_token=user_three_token,
+            shorthand=False,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+
+        # Check the state of user_one matches
+        user_one_info = channel.json_body[user_one]
+        self.assertTrue(user_one_info["deactivated"])
+        self.assertFalse(user_one_info["expired"])
+
+        # Check the state of user_two matches
+        user_two_info = channel.json_body[user_two]
+        self.assertFalse(user_two_info["deactivated"])
+        self.assertTrue(user_two_info["expired"])
+
+        # Check the state of user_three matches
+        user_three_info = channel.json_body[user_three]
+        self.assertFalse(user_three_info["deactivated"])
+        self.assertFalse(user_three_info["expired"])
+
+    def test_user_info_federation(self):
+        """Test that /users/info can be called from the Federation API, and
+        and that we can query remote users from the Client-Server API
+        """
+        user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+        # Request information about our local users from the perspective of a remote server
+        request, channel = self.make_request(
+            "POST",
+            path="/_matrix/federation/unstable/users/info",
+            content={"user_ids": [user_one, user_two, user_three]},
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+
+        # Check the state of user_one matches
+        user_one_info = channel.json_body[user_one]
+        self.assertTrue(user_one_info["deactivated"])
+        self.assertFalse(user_one_info["expired"])
+
+        # Check the state of user_two matches
+        user_two_info = channel.json_body[user_two]
+        self.assertFalse(user_two_info["deactivated"])
+        self.assertTrue(user_two_info["expired"])
+
+        # Check the state of user_three matches
+        user_three_info = channel.json_body[user_three]
+        self.assertFalse(user_three_info["deactivated"])
+        self.assertFalse(user_three_info["expired"])
+
+    def setup_test_users(self):
+        """Create an admin user and three test users, each with a different state"""
+
+        # Create an admin user to expire other users with
+        self.register_user("admin", "adminpassword", admin=True)
+        admin_token = self.login("admin", "adminpassword")
+
+        # Create three users
+        user_one = self.register_user("alice", "pass")
+        user_one_token = self.login("alice", "pass")
+        user_two = self.register_user("bob", "pass")
+        user_three = self.register_user("carl", "pass")
+        user_three_token = self.login("carl", "pass")
+
+        # Deactivate user_one
+        self.deactivate(user_one, user_one_token)
+
+        # Expire user_two
+        self.expire(user_two, admin_token)
+
+        # Do nothing to user_three
+
+        return user_one, user_two, user_three, user_three_token
+
+    def expire(self, user_id_to_expire, admin_tok):
+        url = "/_matrix/client/unstable/admin/account_validity/validity"
+        request_data = {
+            "user_id": user_id_to_expire,
+            "expiration_ts": 0,
+            "enable_renewal_emails": False,
+        }
+        request, channel = self.make_request(
+            "POST", url, request_data, access_token=admin_tok
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def deactivate(self, user_id, tok):
+        request_data = {
+            "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+            "erase": False,
+        }
+        request, channel = self.make_request(
+            "POST", "account/deactivate", request_data, access_token=tok
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 562397cdda..95c93d80e5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -92,7 +92,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         self.agent = MatrixFederationAgent(
             reactor=self.reactor,
-            tls_client_options_factory=self.tls_factory,
+            tls_client_options_factory=FederationPolicyForHTTPS(config),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=self.well_known_resolver,
         )
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..4224b0a92e 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
 
 import json
 
+from mock import Mock
+
+from twisted.internet import defer
+
 import synapse.rest.admin
 from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
 
 from tests import unittest
 
 
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+    """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
 
     servlets = [
+        account.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
         login.register_servlets,
@@ -32,24 +39,111 @@ class IdentityTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         config = self.default_config()
+        config["trusted_third_party_id_servers"] = ["testis"]
         config["enable_3pid_lookup"] = False
         self.hs = self.setup_test_homeserver(config=config)
 
         return self.hs
 
+    def prepare(self, reactor, clock, hs):
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+    def test_3pid_invite_disabled(self):
+        request, channel = self.make_request(
+            b"POST", "/createRoom", b"{}", access_token=self.tok
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        room_id = channel.json_body["room_id"]
+
+        params = {
+            "id_server": "testis",
+            "medium": "email",
+            "address": "test@example.com",
+        }
+        request_data = json.dumps(params)
+        request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+        request, channel = self.make_request(
+            b"POST", request_url, request_data, access_token=self.tok
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+
     def test_3pid_lookup_disabled(self):
-        self.hs.config.enable_3pid_lookup = False
+        url = (
+            "/_matrix/client/unstable/account/3pid/lookup"
+            "?id_server=testis&medium=email&address=foo@bar.baz"
+        )
+        request, channel = self.make_request("GET", url, access_token=self.tok)
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+
+    def test_3pid_bulk_lookup_disabled(self):
+        url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+        data = {
+            "id_server": "testis",
+            "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+        }
+        request_data = json.dumps(data)
+        request, channel = self.make_request(
+            "POST", url, request_data, access_token=self.tok
+        )
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+    """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+
+    servlets = [
+        account.register_servlets,
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
 
-        self.register_user("kermit", "monkey")
-        tok = self.login("kermit", "monkey")
+    def make_homeserver(self, reactor, clock):
 
+        config = self.default_config()
+        config["enable_3pid_lookup"] = True
+        config["trusted_third_party_id_servers"] = ["testis"]
+
+        mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+        mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
+        mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+        self.hs = self.setup_test_homeserver(
+            config=config, simple_http_client=mock_http_client
+        )
+
+        # TODO: This class does not use a singleton to get it's http client
+        # This should be fixed for easier testing
+        # https://github.com/matrix-org/synapse-dinsic/issues/26
+        self.hs.get_handlers().identity_handler.http_client = mock_http_client
+
+        return self.hs
+
+    def prepare(self, reactor, clock, hs):
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+    def test_3pid_invite_enabled(self):
         request, channel = self.make_request(
-            b"POST", "/createRoom", b"{}", access_token=tok
+            b"POST", "/createRoom", b"{}", access_token=self.tok
         )
         self.render(request)
         self.assertEquals(channel.result["code"], b"200", channel.result)
         room_id = channel.json_body["room_id"]
 
+        # Replace the blacklisting SimpleHttpClient with our mock
+        self.hs.get_room_member_handler().simple_http_client = Mock(
+            spec=["get_json", "post_json_get_json"]
+        )
+        self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed(
+            (200, "{}")
+        )
+
         params = {
             "id_server": "testis",
             "medium": "email",
@@ -58,7 +152,44 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         request_data = json.dumps(params)
         request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
         request, channel = self.make_request(
-            b"POST", request_url, request_data, access_token=tok
+            b"POST", request_url, request_data, access_token=self.tok
         )
         self.render(request)
-        self.assertEquals(channel.result["code"], b"403", channel.result)
+
+        get_json = self.hs.get_handlers().identity_handler.http_client.get_json
+        get_json.assert_called_once_with(
+            "https://testis/_matrix/identity/api/v1/lookup",
+            {"address": "test@example.com", "medium": "email"},
+        )
+
+    def test_3pid_lookup_enabled(self):
+        url = (
+            "/_matrix/client/unstable/account/3pid/lookup"
+            "?id_server=testis&medium=email&address=foo@bar.baz"
+        )
+        request, channel = self.make_request("GET", url, access_token=self.tok)
+        self.render(request)
+
+        get_json = self.hs.get_simple_http_client().get_json
+        get_json.assert_called_once_with(
+            "https://testis/_matrix/identity/api/v1/lookup",
+            {"address": "foo@bar.baz", "medium": "email"},
+        )
+
+    def test_3pid_bulk_lookup_enabled(self):
+        url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+        data = {
+            "id_server": "testis",
+            "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+        }
+        request_data = json.dumps(data)
+        request, channel = self.make_request(
+            "POST", url, request_data, access_token=self.tok
+        )
+        self.render(request)
+
+        post_json = self.hs.get_simple_http_client().post_json_get_json
+        post_json.assert_called_once_with(
+            "https://testis/_matrix/identity/api/v1/bulk_lookup",
+            {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
+        )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 95475bb651..9e549d8a91 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         config = self.default_config()
+        config["default_room_version"] = "1"
         config["retention"] = {
             "enabled": True,
             "default_policy": {
@@ -203,6 +204,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         config = self.default_config()
+        config["default_room_version"] = "1"
         config["retention"] = {
             "enabled": True,
         }
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..7da0ef4e18
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,727 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import random
+import string
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.third_party_rules.access_rules import (
+    ACCESS_RULE_DIRECT,
+    ACCESS_RULE_RESTRICTED,
+    ACCESS_RULE_UNRESTRICTED,
+    ACCESS_RULES_TYPE,
+)
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+
+        config["third_party_event_rules"] = {
+            "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+            "config": {
+                "domains_forbidden_when_restricted": ["forbidden_domain"],
+                "id_server": "testis",
+            },
+        }
+        config["trusted_third_party_id_servers"] = ["testis"]
+
+        def send_invite(destination, room_id, event_id, pdu):
+            return defer.succeed(pdu)
+
+        def get_json(uri, args={}, headers=None):
+            address_domain = args["address"].split("@")[1]
+            return defer.succeed({"hs": address_domain})
+
+        def post_json_get_json(uri, post_json, args={}, headers=None):
+            token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+            return defer.succeed(
+                {
+                    "token": token,
+                    "public_keys": [
+                        {
+                            "public_key": "serverpublickey",
+                            "key_validity_url": "https://testis/pubkey/isvalid",
+                        },
+                        {
+                            "public_key": "phemeralpublickey",
+                            "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+                        },
+                    ],
+                    "display_name": "f...@b...",
+                }
+            )
+
+        mock_federation_client = Mock(spec=["send_invite"])
+        mock_federation_client.send_invite.side_effect = send_invite
+
+        mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+        # Mocking the response for /info on the IS API.
+        mock_http_client.get_json.side_effect = get_json
+        # Mocking the response for /store-invite on the IS API.
+        mock_http_client.post_json_get_json.side_effect = post_json_get_json
+        self.hs = self.setup_test_homeserver(
+            config=config,
+            federation_client=mock_federation_client,
+            simple_http_client=mock_http_client,
+        )
+
+        # TODO: This class does not use a singleton to get it's http client
+        # This should be fixed for easier testing
+        # https://github.com/matrix-org/synapse-dinsic/issues/26
+        self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+            mock_http_client
+        )
+
+        return self.hs
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        self.restricted_room = self.create_room()
+        self.unrestricted_room = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+        self.direct_rooms = [
+            self.create_room(direct=True),
+            self.create_room(direct=True),
+            self.create_room(direct=True),
+        ]
+
+        self.invitee_id = self.register_user("invitee", "test")
+        self.invitee_tok = self.login("invitee", "test")
+
+        self.helper.invite(
+            room=self.direct_rooms[0],
+            src=self.user_id,
+            targ=self.invitee_id,
+            tok=self.tok,
+        )
+
+    def test_create_room_no_rule(self):
+        """Tests that creating a room with no rule will set the default value."""
+        room_id = self.create_room()
+        rule = self.current_rule_in_room(room_id)
+
+        self.assertEqual(rule, ACCESS_RULE_RESTRICTED)
+
+    def test_create_room_direct_no_rule(self):
+        """Tests that creating a direct room with no rule will set the default value."""
+        room_id = self.create_room(direct=True)
+        rule = self.current_rule_in_room(room_id)
+
+        self.assertEqual(rule, ACCESS_RULE_DIRECT)
+
+    def test_create_room_valid_rule(self):
+        """Tests that creating a room with a valid rule will set the right value."""
+        room_id = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+        rule = self.current_rule_in_room(room_id)
+
+        self.assertEqual(rule, ACCESS_RULE_UNRESTRICTED)
+
+    def test_create_room_invalid_rule(self):
+        """Tests that creating a room with an invalid rule will set fail."""
+        self.create_room(rule=ACCESS_RULE_DIRECT, expected_code=400)
+
+    def test_create_room_direct_invalid_rule(self):
+        """Tests that creating a direct room with an invalid rule will fail.
+        """
+        self.create_room(direct=True, rule=ACCESS_RULE_RESTRICTED, expected_code=400)
+
+    def test_public_room(self):
+        """Tests that it's not possible to have a room with the public join rule and an
+        access rule that's not restricted.
+        """
+        # Creating a room with the public_chat preset should succeed and set the access
+        # rule to restricted.
+        preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+        self.assertEqual(
+            self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED
+        )
+
+        # Creating a room with the public join rule in its initial state should succeed
+        # and set the access rule to restricted.
+        init_state_room_id = self.create_room(
+            initial_state=[
+                {
+                    "type": "m.room.join_rules",
+                    "content": {"join_rule": JoinRules.PUBLIC},
+                }
+            ]
+        )
+        self.assertEqual(
+            self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED
+        )
+
+        # Changing access rule to unrestricted should fail.
+        self.change_rule_in_room(
+            preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+        )
+        self.change_rule_in_room(
+            init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+        )
+
+        # Changing access rule to direct should fail.
+        self.change_rule_in_room(preset_room_id, ACCESS_RULE_DIRECT, expected_code=403)
+        self.change_rule_in_room(
+            init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403
+        )
+
+        # Changing join rule to public in an unrestricted room should fail.
+        self.change_join_rule_in_room(
+            self.unrestricted_room, JoinRules.PUBLIC, expected_code=403
+        )
+        # Changing join rule to public in an direct room should fail.
+        self.change_join_rule_in_room(
+            self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+        )
+
+        # Creating a new room with the public_chat preset and an access rule that isn't
+        # restricted should fail.
+        self.create_room(
+            preset=RoomCreationPreset.PUBLIC_CHAT,
+            rule=ACCESS_RULE_UNRESTRICTED,
+            expected_code=400,
+        )
+        self.create_room(
+            preset=RoomCreationPreset.PUBLIC_CHAT,
+            rule=ACCESS_RULE_DIRECT,
+            expected_code=400,
+        )
+
+        # Creating a room with the public join rule in its initial state and an access
+        # rule that isn't restricted should fail.
+        self.create_room(
+            initial_state=[
+                {
+                    "type": "m.room.join_rules",
+                    "content": {"join_rule": JoinRules.PUBLIC},
+                }
+            ],
+            rule=ACCESS_RULE_UNRESTRICTED,
+            expected_code=400,
+        )
+        self.create_room(
+            initial_state=[
+                {
+                    "type": "m.room.join_rules",
+                    "content": {"join_rule": JoinRules.PUBLIC},
+                }
+            ],
+            rule=ACCESS_RULE_DIRECT,
+            expected_code=400,
+        )
+
+    def test_restricted(self):
+        """Tests that in restricted mode we're unable to invite users from blacklisted
+        servers but can invite other users.
+        """
+        # We can't invite a user from a forbidden HS.
+        self.helper.invite(
+            room=self.restricted_room,
+            src=self.user_id,
+            targ="@test:forbidden_domain",
+            tok=self.tok,
+            expect_code=403,
+        )
+
+        # We can invite a user which HS isn't forbidden.
+        self.helper.invite(
+            room=self.restricted_room,
+            src=self.user_id,
+            targ="@test:allowed_domain",
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+        self.send_threepid_invite(
+            address="test@forbidden_domain",
+            room_id=self.restricted_room,
+            expected_code=403,
+        )
+
+        # We can send a 3PID invite to an address that is mapped to an HS that's not
+        # forbidden.
+        self.send_threepid_invite(
+            address="test@allowed_domain",
+            room_id=self.restricted_room,
+            expected_code=200,
+        )
+
+    def test_direct(self):
+        """Tests that, in direct mode, other users than the initial two can't be invited,
+        but the following scenario works:
+          * invited user joins the room
+          * invited user leaves the room
+          * room creator re-invites invited user
+        Also tests that a user from a HS that's in the list of forbidden domains (to use
+        in restricted mode) can be invited.
+        """
+        not_invited_user = "@not_invited:forbidden_domain"
+
+        # We can't invite a new user to the room.
+        self.helper.invite(
+            room=self.direct_rooms[0],
+            src=self.user_id,
+            targ=not_invited_user,
+            tok=self.tok,
+            expect_code=403,
+        )
+
+        # The invited user can join the room.
+        self.helper.join(
+            room=self.direct_rooms[0],
+            user=self.invitee_id,
+            tok=self.invitee_tok,
+            expect_code=200,
+        )
+
+        # The invited user can leave the room.
+        self.helper.leave(
+            room=self.direct_rooms[0],
+            user=self.invitee_id,
+            tok=self.invitee_tok,
+            expect_code=200,
+        )
+
+        # The invited user can be re-invited to the room.
+        self.helper.invite(
+            room=self.direct_rooms[0],
+            src=self.user_id,
+            targ=self.invitee_id,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        # If we're alone in the room and have always been the only member, we can invite
+        # someone.
+        self.helper.invite(
+            room=self.direct_rooms[1],
+            src=self.user_id,
+            targ=not_invited_user,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        # Disable the 3pid invite ratelimiter
+        burst = self.hs.config.rc_third_party_invite.burst_count
+        per_second = self.hs.config.rc_third_party_invite.per_second
+        self.hs.config.rc_third_party_invite.burst_count = 10
+        self.hs.config.rc_third_party_invite.per_second = 0.1
+
+        # We can't send a 3PID invite to a room that already has two members.
+        self.send_threepid_invite(
+            address="test@allowed_domain",
+            room_id=self.direct_rooms[0],
+            expected_code=403,
+        )
+
+        # We can't send a 3PID invite to a room that already has a pending invite.
+        self.send_threepid_invite(
+            address="test@allowed_domain",
+            room_id=self.direct_rooms[1],
+            expected_code=403,
+        )
+
+        # We can send a 3PID invite to a room in which we've always been the only member.
+        self.send_threepid_invite(
+            address="test@forbidden_domain",
+            room_id=self.direct_rooms[2],
+            expected_code=200,
+        )
+
+        # We can send a 3PID invite to a room in which there's a 3PID invite.
+        self.send_threepid_invite(
+            address="test@forbidden_domain",
+            room_id=self.direct_rooms[2],
+            expected_code=403,
+        )
+
+        self.hs.config.rc_third_party_invite.burst_count = burst
+        self.hs.config.rc_third_party_invite.per_second = per_second
+
+    def test_unrestricted(self):
+        """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+        only change the power level of users that wouldn't be forbidden in restricted
+        mode.
+        """
+        # We can invite
+        self.helper.invite(
+            room=self.unrestricted_room,
+            src=self.user_id,
+            targ="@test:forbidden_domain",
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        self.helper.invite(
+            room=self.unrestricted_room,
+            src=self.user_id,
+            targ="@test:not_forbidden_domain",
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+        self.send_threepid_invite(
+            address="test@forbidden_domain",
+            room_id=self.unrestricted_room,
+            expected_code=200,
+        )
+
+        # We can send a 3PID invite to an address that is mapped to an HS that's not
+        # forbidden.
+        self.send_threepid_invite(
+            address="test@allowed_domain",
+            room_id=self.unrestricted_room,
+            expected_code=200,
+        )
+
+        # We can send a power level event that doesn't redefine the default PL or set a
+        # non-default PL for a user that would be forbidden in restricted mode.
+        self.helper.send_state(
+            room_id=self.unrestricted_room,
+            event_type=EventTypes.PowerLevels,
+            body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        # We can't send a power level event that redefines the default PL and doesn't set
+        # a non-default PL for a user that would be forbidden in restricted mode.
+        self.helper.send_state(
+            room_id=self.unrestricted_room,
+            event_type=EventTypes.PowerLevels,
+            body={
+                "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+                "users_default": 10,
+            },
+            tok=self.tok,
+            expect_code=403,
+        )
+
+        # We can't send a power level event that doesn't redefines the default PL but sets
+        # a non-default PL for a user that would be forbidden in restricted mode.
+        self.helper.send_state(
+            room_id=self.unrestricted_room,
+            event_type=EventTypes.PowerLevels,
+            body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+            tok=self.tok,
+            expect_code=403,
+        )
+
+    def test_change_rules(self):
+        """Tests that we can only change the current rule from restricted to
+        unrestricted.
+        """
+        # We can change the rule from restricted to unrestricted.
+        self.change_rule_in_room(
+            room_id=self.restricted_room,
+            new_rule=ACCESS_RULE_UNRESTRICTED,
+            expected_code=200,
+        )
+
+        # We can't change the rule from restricted to direct.
+        self.change_rule_in_room(
+            room_id=self.restricted_room, new_rule=ACCESS_RULE_DIRECT, expected_code=403
+        )
+
+        # We can't change the rule from unrestricted to restricted.
+        self.change_rule_in_room(
+            room_id=self.unrestricted_room,
+            new_rule=ACCESS_RULE_RESTRICTED,
+            expected_code=403,
+        )
+
+        # We can't change the rule from unrestricted to direct.
+        self.change_rule_in_room(
+            room_id=self.unrestricted_room,
+            new_rule=ACCESS_RULE_DIRECT,
+            expected_code=403,
+        )
+
+        # We can't change the rule from direct to restricted.
+        self.change_rule_in_room(
+            room_id=self.direct_rooms[0],
+            new_rule=ACCESS_RULE_RESTRICTED,
+            expected_code=403,
+        )
+
+        # We can't change the rule from direct to unrestricted.
+        self.change_rule_in_room(
+            room_id=self.direct_rooms[0],
+            new_rule=ACCESS_RULE_UNRESTRICTED,
+            expected_code=403,
+        )
+
+    def test_change_room_avatar(self):
+        """Tests that changing the room avatar is always allowed unless the room is a
+        direct chat, in which case it's forbidden.
+        """
+
+        avatar_content = {
+            "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+            "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+        }
+
+        self.helper.send_state(
+            room_id=self.restricted_room,
+            event_type=EventTypes.RoomAvatar,
+            body=avatar_content,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        self.helper.send_state(
+            room_id=self.unrestricted_room,
+            event_type=EventTypes.RoomAvatar,
+            body=avatar_content,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        self.helper.send_state(
+            room_id=self.direct_rooms[0],
+            event_type=EventTypes.RoomAvatar,
+            body=avatar_content,
+            tok=self.tok,
+            expect_code=403,
+        )
+
+    def test_change_room_name(self):
+        """Tests that changing the room name is always allowed unless the room is a direct
+        chat, in which case it's forbidden.
+        """
+
+        name_content = {"name": "My super room"}
+
+        self.helper.send_state(
+            room_id=self.restricted_room,
+            event_type=EventTypes.Name,
+            body=name_content,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        self.helper.send_state(
+            room_id=self.unrestricted_room,
+            event_type=EventTypes.Name,
+            body=name_content,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        self.helper.send_state(
+            room_id=self.direct_rooms[0],
+            event_type=EventTypes.Name,
+            body=name_content,
+            tok=self.tok,
+            expect_code=403,
+        )
+
+    def test_change_room_topic(self):
+        """Tests that changing the room topic is always allowed unless the room is a
+        direct chat, in which case it's forbidden.
+        """
+
+        topic_content = {"topic": "Welcome to this room"}
+
+        self.helper.send_state(
+            room_id=self.restricted_room,
+            event_type=EventTypes.Topic,
+            body=topic_content,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        self.helper.send_state(
+            room_id=self.unrestricted_room,
+            event_type=EventTypes.Topic,
+            body=topic_content,
+            tok=self.tok,
+            expect_code=200,
+        )
+
+        self.helper.send_state(
+            room_id=self.direct_rooms[0],
+            event_type=EventTypes.Topic,
+            body=topic_content,
+            tok=self.tok,
+            expect_code=403,
+        )
+
+    def test_revoke_3pid_invite_direct(self):
+        """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+        confuse the revokation as a new 3PID invite.
+        """
+        invite_token = "sometoken"
+
+        invite_body = {
+            "display_name": "ker...@exa...",
+            "public_keys": [
+                {
+                    "key_validity_url": "https://validity_url",
+                    "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+                },
+                {
+                    "key_validity_url": "https://validity_url",
+                    "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+                },
+            ],
+            "key_validity_url": "https://validity_url",
+            "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+        }
+
+        self.send_state_with_state_key(
+            room_id=self.direct_rooms[1],
+            event_type=EventTypes.ThirdPartyInvite,
+            state_key=invite_token,
+            body=invite_body,
+            tok=self.tok,
+        )
+
+        self.send_state_with_state_key(
+            room_id=self.direct_rooms[1],
+            event_type=EventTypes.ThirdPartyInvite,
+            state_key=invite_token,
+            body={},
+            tok=self.tok,
+        )
+
+        invite_token = "someothertoken"
+
+        self.send_state_with_state_key(
+            room_id=self.direct_rooms[1],
+            event_type=EventTypes.ThirdPartyInvite,
+            state_key=invite_token,
+            body=invite_body,
+            tok=self.tok,
+        )
+
+    def create_room(
+        self,
+        direct=False,
+        rule=None,
+        preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+        initial_state=None,
+        expected_code=200,
+    ):
+        content = {"is_direct": direct, "preset": preset}
+
+        if rule:
+            content["initial_state"] = [
+                {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+            ]
+
+        if initial_state:
+            if "initial_state" not in content:
+                content["initial_state"] = []
+
+            content["initial_state"] += initial_state
+
+        request, channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/createRoom",
+            json.dumps(content),
+            access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+        if expected_code == 200:
+            return channel.json_body["room_id"]
+
+    def current_rule_in_room(self, room_id):
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+            access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, 200, channel.result)
+        return channel.json_body["rule"]
+
+    def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+        data = {"rule": new_rule}
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+            json.dumps(data),
+            access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+    def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+        data = {"join_rule": new_join_rule}
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+            json.dumps(data),
+            access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+    def send_threepid_invite(self, address, room_id, expected_code=200):
+        params = {"id_server": "testis", "medium": "email", "address": address}
+
+        request, channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/rooms/%s/invite" % room_id,
+            json.dumps(params),
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+    def send_state_with_state_key(
+        self, room_id, event_type, state_key, body, tok, expect_code=200
+    ):
+        path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+            room_id,
+            event_type,
+            state_key,
+        )
+
+        request, channel = self.make_request(
+            "PUT", path, json.dumps(body), access_token=tok
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, expect_code, channel.result)
+
+        return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7deaf5b24a..ceca4041e1 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
 import json
 import os
 
+from mock import Mock
+
 import pkg_resources
 
+from twisted.internet import defer
+
 import synapse.rest.admin
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes
@@ -87,14 +91,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid password")
 
-    def test_POST_bad_username(self):
-        request_data = json.dumps({"username": 777, "password": "monkey"})
-        request, channel = self.make_request(b"POST", self.url, request_data)
-        self.render(request)
-
-        self.assertEquals(channel.result["code"], b"400", channel.result)
-        self.assertEquals(channel.json_body["error"], "Invalid username")
-
     def test_POST_user_valid(self):
         user_id = "@kermit:test"
         device_id = "frogfone"
@@ -303,6 +299,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertIsNotNone(channel.json_body.get("sid"))
 
 
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+    servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+    def make_homeserver(self, reactor, clock):
+
+        self.url = b"/_matrix/client/r0/register"
+
+        config = self.default_config()
+        config["enable_registration"] = True
+        config["show_users_in_user_directory"] = False
+        config["replicate_user_profiles_to"] = ["fakeserver"]
+
+        mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+        mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+        self.hs = self.setup_test_homeserver(
+            config=config, simple_http_client=mock_http_client
+        )
+
+        return self.hs
+
+    def test_profile_hidden(self):
+        user_id = self.register_user("kermit", "monkey")
+
+        post_json = self.hs.get_simple_http_client().post_json_get_json
+
+        # We expect post_json_get_json to have been called twice: once with the original
+        # profile and once with the None profile resulting from the request to hide it
+        # from the user directory.
+        self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+        # Get the args (and not kwargs) passed to post_json.
+        args = post_json.call_args[0]
+        # Make sure the last call was attempting to replicate profiles.
+        split_uri = args[0].split("/")
+        self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+        # Make sure the last profile update was overriding the user's profile to None.
+        self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
 class AccountValidityTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -312,6 +349,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
         logout.register_servlets,
         account_validity.register_servlets,
+        account.register_servlets,
     ]
 
     def make_homeserver(self, reactor, clock):
@@ -437,6 +475,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.client.v1.profile.register_servlets,
+        synapse.rest.client.v1.room.register_servlets,
+        synapse.rest.client.v2_alpha.user_directory.register_servlets,
+        login.register_servlets,
+        register.register_servlets,
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        account_validity.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+
+        # Set accounts to expire after a week
+        config["enable_registration"] = True
+        config["account_validity"] = {
+            "enabled": True,
+            "period": 604800000,  # Time in ms for 1 week
+        }
+        config["replicate_user_profiles_to"] = "test.is"
+
+        # Mock homeserver requests to an identity server
+        mock_http_client = Mock(spec=["post_json_get_json"])
+        mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+        self.hs = self.setup_test_homeserver(
+            config=config, simple_http_client=mock_http_client
+        )
+
+        return self.hs
+
+    def test_expired_user_in_directory(self):
+        """Test that an expired user is hidden in the user directory"""
+        # Create an admin user to search the user directory
+        admin_id = self.register_user("admin", "adminpassword", admin=True)
+        admin_tok = self.login("admin", "adminpassword")
+
+        # Ensure the admin never expires
+        url = "/_matrix/client/unstable/admin/account_validity/validity"
+        params = {
+            "user_id": admin_id,
+            "expiration_ts": 999999999999,
+            "enable_renewal_emails": False,
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(
+            b"POST", url, request_data, access_token=admin_tok
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Mock the homeserver's HTTP client
+        post_json = self.hs.get_simple_http_client().post_json_get_json
+
+        # Create a user
+        username = "kermit"
+        user_id = self.register_user(username, "monkey")
+        self.login(username, "monkey")
+        self.get_success(
+            self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+        )
+
+        # Check that a full profile for this user is replicated
+        self.assertIsNotNone(post_json.call_args, post_json.call_args)
+        payload = post_json.call_args[0][1]
+        batch = payload.get("batch")
+
+        self.assertIsNotNone(batch, batch)
+        self.assertEquals(len(batch), 1, batch)
+
+        replicated_user_id = list(batch.keys())[0]
+        self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+        # There was replicated information about our user
+        # Check that it's not None
+        replicated_content = batch[user_id]
+        self.assertIsNotNone(replicated_content)
+
+        # Expire the user
+        url = "/_matrix/client/unstable/admin/account_validity/validity"
+        params = {
+            "user_id": user_id,
+            "expiration_ts": 0,
+            "enable_renewal_emails": False,
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(
+            b"POST", url, request_data, access_token=admin_tok
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Wait for the background job to run which hides expired users in the directory
+        self.reactor.advance(60 * 60 * 1000)
+
+        # Check if the homeserver has replicated the user's profile to the identity server
+        self.assertIsNotNone(post_json.call_args, post_json.call_args)
+        payload = post_json.call_args[0][1]
+        batch = payload.get("batch")
+
+        self.assertIsNotNone(batch, batch)
+        self.assertEquals(len(batch), 1, batch)
+
+        replicated_user_id = list(batch.keys())[0]
+        self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+        # There was replicated information about our user
+        # Check that it's None, signifying that the user should be removed from the user
+        # directory because they were expired
+        replicated_content = batch[user_id]
+        self.assertIsNone(replicated_content)
+
+        # Now renew the user, and check they get replicated again to the identity server
+        url = "/_matrix/client/unstable/admin/account_validity/validity"
+        params = {
+            "user_id": user_id,
+            "expiration_ts": 99999999999,
+            "enable_renewal_emails": False,
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(
+            b"POST", url, request_data, access_token=admin_tok
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.pump(10)
+        self.reactor.advance(10)
+        self.pump()
+
+        # Check if the homeserver has replicated the user's profile to the identity server
+        post_json = self.hs.get_simple_http_client().post_json_get_json
+        self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+        payload = post_json.call_args[0][1]
+        batch = payload.get("batch")
+        self.assertNotEquals(batch, None, batch)
+        self.assertEquals(len(batch), 1, batch)
+        replicated_user_id = list(batch.keys())[0]
+        self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+        # There was replicated information about our user
+        # Check that it's not None, signifying that the user is back in the user
+        # directory
+        replicated_content = batch[user_id]
+        self.assertIsNotNone(replicated_content)
+
+
 class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -587,7 +774,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             "POST", "account/deactivate", request_data, access_token=tok
         )
         self.render(request)
-        self.assertEqual(request.code, 200)
+        self.assertEqual(request.code, 200, channel.result)
 
         self.reactor.advance(datetime.timedelta(days=8).total_seconds())
 
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..1accc70dc9
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request, render
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+    def test_allowed(self):
+        config = {
+            "default": False,
+            "domain_mapping": {
+                "source_one": ["target_one", "target_two"],
+                "source_two": ["target_two"],
+            },
+            "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+        }
+        check = DomainRuleChecker(config)
+        self.assertTrue(
+            check.user_may_invite(
+                "test:source_one", "test:target_one", None, "room", False
+            )
+        )
+        self.assertTrue(
+            check.user_may_invite(
+                "test:source_one", "test:target_two", None, "room", False
+            )
+        )
+        self.assertTrue(
+            check.user_may_invite(
+                "test:source_two", "test:target_two", None, "room", False
+            )
+        )
+
+        # User can invite internal user to a published room
+        self.assertTrue(
+            check.user_may_invite(
+                "test:source_one", "test1:target_one", None, "room", False, True
+            )
+        )
+
+        # User can invite external user to a non-published room
+        self.assertTrue(
+            check.user_may_invite(
+                "test:source_one", "test:target_two", None, "room", False, False
+            )
+        )
+
+    def test_disallowed(self):
+        config = {
+            "default": True,
+            "domain_mapping": {
+                "source_one": ["target_one", "target_two"],
+                "source_two": ["target_two"],
+                "source_four": [],
+            },
+        }
+        check = DomainRuleChecker(config)
+        self.assertFalse(
+            check.user_may_invite(
+                "test:source_one", "test:target_three", None, "room", False
+            )
+        )
+        self.assertFalse(
+            check.user_may_invite(
+                "test:source_two", "test:target_three", None, "room", False
+            )
+        )
+        self.assertFalse(
+            check.user_may_invite(
+                "test:source_two", "test:target_one", None, "room", False
+            )
+        )
+        self.assertFalse(
+            check.user_may_invite(
+                "test:source_four", "test:target_one", None, "room", False
+            )
+        )
+
+        # User cannot invite external user to a published room
+        self.assertTrue(
+            check.user_may_invite(
+                "test:source_one", "test:target_two", None, "room", False, True
+            )
+        )
+
+    def test_default_allow(self):
+        config = {
+            "default": True,
+            "domain_mapping": {
+                "source_one": ["target_one", "target_two"],
+                "source_two": ["target_two"],
+            },
+        }
+        check = DomainRuleChecker(config)
+        self.assertTrue(
+            check.user_may_invite(
+                "test:source_three", "test:target_one", None, "room", False
+            )
+        )
+
+    def test_default_deny(self):
+        config = {
+            "default": False,
+            "domain_mapping": {
+                "source_one": ["target_one", "target_two"],
+                "source_two": ["target_two"],
+            },
+        }
+        check = DomainRuleChecker(config)
+        self.assertFalse(
+            check.user_may_invite(
+                "test:source_three", "test:target_one", None, "room", False
+            )
+        )
+
+    def test_config_parse(self):
+        config = {
+            "default": False,
+            "domain_mapping": {
+                "source_one": ["target_one", "target_two"],
+                "source_two": ["target_two"],
+            },
+        }
+        self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+    def test_config_parse_failure(self):
+        config = {
+            "domain_mapping": {
+                "source_one": ["target_one", "target_two"],
+                "source_two": ["target_two"],
+            }
+        }
+        self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    hijack_auth = False
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config["trusted_third_party_id_servers"] = ["localhost"]
+
+        config["spam_checker"] = {
+            "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+            "config": {
+                "default": True,
+                "domain_mapping": {},
+                "can_only_join_rooms_with_invite": True,
+                "can_only_create_one_to_one_rooms": True,
+                "can_only_invite_during_room_creation": True,
+                "can_invite_by_third_party_id": False,
+            },
+        }
+
+        hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+        self.admin_access_token = self.login("admin_user", "pass")
+
+        self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+        self.normal_access_token = self.login("normal_user", "pass")
+
+        self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+    def test_admin_can_create_room(self):
+        channel = self._create_room(self.admin_access_token)
+        assert channel.result["code"] == b"200", channel.result
+
+    def test_normal_user_cannot_create_empty_room(self):
+        channel = self._create_room(self.normal_access_token)
+        assert channel.result["code"] == b"403", channel.result
+
+    def test_normal_user_cannot_create_room_with_multiple_invites(self):
+        channel = self._create_room(
+            self.normal_access_token,
+            content={"invite": [self.other_user_id, self.admin_user_id]},
+        )
+        assert channel.result["code"] == b"403", channel.result
+
+        # Test that it correctly counts both normal and third party invites
+        channel = self._create_room(
+            self.normal_access_token,
+            content={
+                "invite": [self.other_user_id],
+                "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+            },
+        )
+        assert channel.result["code"] == b"403", channel.result
+
+        # Test that it correctly rejects third party invites
+        channel = self._create_room(
+            self.normal_access_token,
+            content={
+                "invite": [],
+                "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+            },
+        )
+        assert channel.result["code"] == b"403", channel.result
+
+    def test_normal_user_can_room_with_single_invites(self):
+        channel = self._create_room(
+            self.normal_access_token, content={"invite": [self.other_user_id]}
+        )
+        assert channel.result["code"] == b"200", channel.result
+
+    def test_cannot_join_public_room(self):
+        channel = self._create_room(self.admin_access_token)
+        assert channel.result["code"] == b"200", channel.result
+
+        room_id = channel.json_body["room_id"]
+
+        self.helper.join(
+            room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+        )
+
+    def test_can_join_invited_room(self):
+        channel = self._create_room(self.admin_access_token)
+        assert channel.result["code"] == b"200", channel.result
+
+        room_id = channel.json_body["room_id"]
+
+        self.helper.invite(
+            room_id,
+            src=self.admin_user_id,
+            targ=self.normal_user_id,
+            tok=self.admin_access_token,
+        )
+
+        self.helper.join(
+            room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+        )
+
+    def test_cannot_invite(self):
+        channel = self._create_room(self.admin_access_token)
+        assert channel.result["code"] == b"200", channel.result
+
+        room_id = channel.json_body["room_id"]
+
+        self.helper.invite(
+            room_id,
+            src=self.admin_user_id,
+            targ=self.normal_user_id,
+            tok=self.admin_access_token,
+        )
+
+        self.helper.join(
+            room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+        )
+
+        self.helper.invite(
+            room_id,
+            src=self.normal_user_id,
+            targ=self.other_user_id,
+            tok=self.normal_access_token,
+            expect_code=403,
+        )
+
+    def test_cannot_3pid_invite(self):
+        """Test that unbound 3pid invites get rejected.
+        """
+        channel = self._create_room(self.admin_access_token)
+        assert channel.result["code"] == b"200", channel.result
+
+        room_id = channel.json_body["room_id"]
+
+        self.helper.invite(
+            room_id,
+            src=self.admin_user_id,
+            targ=self.normal_user_id,
+            tok=self.admin_access_token,
+        )
+
+        self.helper.join(
+            room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+        )
+
+        self.helper.invite(
+            room_id,
+            src=self.normal_user_id,
+            targ=self.other_user_id,
+            tok=self.normal_access_token,
+            expect_code=403,
+        )
+
+        request, channel = self.make_request(
+            "POST",
+            "rooms/%s/invite" % (room_id),
+            {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+            access_token=self.normal_access_token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 403, channel.result["body"])
+
+    def _create_room(self, token, content={}):
+        path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+        request, channel = make_request(
+            self.hs.get_reactor(),
+            "POST",
+            path,
+            content=json.dumps(content).encode("utf8"),
+        )
+        render(request, self.resource, self.hs.get_reactor())
+
+        return channel
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..0155ffd04e 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
     def test_get_users_paginate(self):
         yield self.store.register_user(self.user.to_string(), "pass")
         yield self.store.create_profile(self.user.localpart)
-        yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+        yield self.store.set_profile_displayname(
+            self.user.localpart, self.displayname, 1
+        )
 
         users, total = yield self.store.get_users_paginate(
             0, 10, name="bc", guests=False
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..7458a37e54 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,9 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(self.u_frank.localpart)
-
-        yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+        yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
 
         self.assertEquals(
             "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
@@ -43,10 +41,8 @@ class ProfileStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(self.u_frank.localpart)
-
         yield self.store.set_profile_avatar_url(
-            self.u_frank.localpart, "http://my.site/here"
+            self.u_frank.localpart, "http://my.site/here", 1
         )
 
         self.assertEquals(
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from six import string_types
 
 from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    GroupID,
+    RoomAlias,
+    UserID,
+    map_username_to_mxid_localpart,
+    strip_invalid_mxid_characters,
+)
 
 from tests import unittest
 
@@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase):
         self.assertEqual(
             map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
         )
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+    def test_return_type(self):
+        unstripped = strip_invalid_mxid_characters("test")
+        stripped = strip_invalid_mxid_characters("test@")
+
+        self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+        self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+    def test_strip(self):
+        stripped = strip_invalid_mxid_characters("test@")
+        self.assertEqual(stripped, "test", stripped)