diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 795703967d..c4f0bbd3dd 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
getattr(LoggingContext.current_context(), "request", None), expected
)
- def test_wait_for_previous_lookups(self):
- kr = keyring.Keyring(self.hs)
-
- lookup_1_deferred = defer.Deferred()
- lookup_2_deferred = defer.Deferred()
-
- # we run the lookup in a logcontext so that the patched inlineCallbacks can check
- # it is doing the right thing with logcontexts.
- wait_1_deferred = run_in_context(
- kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
- )
-
- # there were no previous lookups, so the deferred should be ready
- self.successResultOf(wait_1_deferred)
-
- # set off another wait. It should block because the first lookup
- # hasn't yet completed.
- wait_2_deferred = run_in_context(
- kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
- )
-
- self.assertFalse(wait_2_deferred.called)
-
- # let the first lookup complete (in the sentinel context)
- lookup_1_deferred.callback(None)
-
- # now the second wait should complete.
- self.successResultOf(wait_2_deferred)
-
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
@@ -136,7 +107,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertEquals(LoggingContext.current_context().request, "11")
with PreserveLoggingContext():
yield persp_deferred
- defer.returnValue(persp_resp)
+ return persp_resp
self.http_client.post_json.side_effect = get_perspectives
@@ -583,7 +554,7 @@ def run_in_context(f, *args, **kwargs):
# logs.
ctx.request = "testctx"
rv = yield f(*args, **kwargs)
- defer.returnValue(rv)
+ return rv
def _verify_json_for_server(kr, *args):
@@ -594,6 +565,6 @@ def _verify_json_for_server(kr, *args):
@defer.inlineCallbacks
def v():
rv1 = yield kr.verify_json_for_server(*args)
- defer.returnValue(rv1)
+ return rv1
return run_in_context(v)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
new file mode 100644
index 0000000000..b1ae15a2bd
--- /dev/null
+++ b/tests/handlers/test_federation.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class FederationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(http_client=None)
+ self.handler = hs.get_handlers().federation_handler
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_exchange_revoked_invite(self):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # Send a 3PID invite event with an empty body so it's considered as a revoked one.
+ invite_token = "sometoken"
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=tok,
+ )
+
+ d = self.handler.on_exchange_third_party_invite_request(
+ origin="example.com",
+ room_id=room_id,
+ event_dict={
+ "type": EventTypes.Member,
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": "@someone:example.org",
+ "content": {
+ "membership": "invite",
+ "third_party_invite": {
+ "display_name": "alice",
+ "signed": {
+ "mxid": "@alice:localhost",
+ "token": invite_token,
+ "signatures": {
+ "magic.forest": {
+ "ed25519:3": (
+ "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIs"
+ "NffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
+ )
+ }
+ },
+ },
+ },
+ },
+ },
+ )
+
+ failure = self.get_failure(d, AuthError).value
+
+ self.assertEqual(failure.code, 403, failure)
+ self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
+ self.assertEqual(failure.msg, "You are not invited to this room.")
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..32c31b2f66
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.rewritten_is_url = "int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_name: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_handlers().identity_handler
+ post_json_get_json = self.hs.get_simple_http_client().post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ {
+ "id_server": self.is_server_name,
+ "client_secret": creds["client_secret"],
+ "sid": creds["sid"],
+ },
+ self.user_id,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "https://%s/_matrix/identity/api/v1/3pid/bind" % self.rewritten_is_url,
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d60c124eec..2311040201 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -67,13 +67,11 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
-
self.handler = hs.get_profile_handler()
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
displayname = yield self.handler.get_displayname(self.frank)
@@ -116,8 +114,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield self.store.set_profile_displayname("caroline", "Caroline", 1)
response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
@@ -128,7 +125,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
avatar_url = yield self.handler.get_avatar_url(self.frank)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 90d0129374..408f8583f1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.rest.client.v2_alpha.register import _map_email_to_displayname
from synapse.types import RoomAlias, UserID, create_requester
from .. import unittest
@@ -231,6 +232,26 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
"""Creates a new user if the user does not exist,
@@ -283,4 +304,4 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user, requester, displayname, by_admin=True
)
- defer.returnValue((user_id, token))
+ return (user_id, token)
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba6464..2096ba3c91 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
+
+
+def get_test_https_policy():
+ """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+ Returns:
+ IPolicyForHTTPS
+ """
+ ca_file = get_test_ca_cert_file()
+ with open(ca_file) as stream:
+ content = stream.read()
+ cert = Certificate.loadPEM(content)
+ trust_root = trustRootFromCertificates([cert])
+ return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index a49f9b3224..cf3f52fd9d 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -115,19 +115,24 @@ class MatrixFederationAgentTests(TestCase):
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
+ # grab a hold of the TLS connection, in case it gets torn down
+ server_tls_connection = server_tls_protocol._tlsConnection
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_tls_protocol.wrappedProtocol
+
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
- server_name = server_tls_protocol._tlsConnection.get_servername()
+ server_name = server_tls_connection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
- # fish the test server back out of the server-side TLS protocol.
- return server_tls_protocol.wrappedProtocol
+ return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
@@ -145,7 +150,7 @@ class MatrixFederationAgentTests(TestCase):
try:
fetch_res = yield fetch_d
- defer.returnValue(fetch_res)
+ return fetch_res
except Exception as e:
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
raise
@@ -936,7 +941,7 @@ class MatrixFederationAgentTests(TestCase):
except Exception as e:
logger.warning("Error fetching well-known: %s", e)
raise
- defer.returnValue(result)
+ return result
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 65b51dc981..3b885ef64b 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -61,7 +61,7 @@ class SrvResolverTestCase(unittest.TestCase):
# should have restored our context
self.assertIs(LoggingContext.current_context(), ctx)
- defer.returnValue(result)
+ return result
test_d = do_lookup()
self.assertNoResult(test_d)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b9d6d7ad1c..2b01f40a42 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -68,7 +68,7 @@ class FederationClientTests(HomeserverTestCase):
try:
fetch_res = yield fetch_d
- defer.returnValue(fetch_res)
+ return fetch_res
finally:
check_logcontext(context)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 0000000000..22abf76515
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def _make_connection(
+ self, client_factory, server_factory, ssl=False, expected_sni=None
+ ):
+ """Builds a test server, and completes the outgoing client connection
+
+ Args:
+ client_factory (interfaces.IProtocolFactory): the the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ server_factory (interfaces.IProtocolFactory): a factory to build the
+ server-side protocol
+
+ ssl (bool): If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+
+ expected_sni (bytes|None): the expected SNI value
+
+ Returns:
+ IProtocol: the server Protocol returned by server_factory
+ """
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory)
+
+ server_protocol = server_factory.buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol,
+ # and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
+ )
+
+ if ssl:
+ http_protocol = server_protocol.wrappedProtocol
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
+
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
+
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ return http_protocol
+
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=True,
+ expected_sni=b"test.com",
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory (interfaces.IProtocolFactory): protocol factory to wrap
+ sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [b"DNS:test.com"]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+
+def _get_test_protocol_factory():
+ """Get a protocol Factory which will build an HTTPChannel
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ return server_factory
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62da..af2327fb66 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+ hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
return hs
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..f81f81602e 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,19 +39,93 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -58,7 +139,44 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
new file mode 100644
index 0000000000..4303f95206
--- /dev/null
+++ b/tests/rest/client/test_retention.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.api.constants import EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.visibility import filter_events_for_client
+
+from tests import unittest
+
+one_hour_ms = 3600000
+one_day_ms = one_hour_ms * 24
+
+
+class RetentionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["default_room_version"] = "1"
+ config["retention"] = {
+ "enabled": True,
+ "default_policy": {
+ "min_lifetime": one_day_ms,
+ "max_lifetime": one_day_ms * 3,
+ },
+ "allowed_lifetime_min": one_day_ms,
+ "allowed_lifetime_max": one_day_ms * 3,
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_retention_state_event(self):
+ """Tests that the server configuration can limit the values a user can set to the
+ room's retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 4},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_hour_ms},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": lifetime},
+ tok=self.token,
+ )
+
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_without_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by the server's configuration's default retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention_event_purged(room_id, one_day_ms * 2)
+
+ def test_visibility(self):
+ """Tests that synapse.visibility.filter_events_for_client correctly filters out
+ outdated events
+ """
+ store = self.hs.get_datastore()
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ events = []
+
+ # Send a first event, which should be filtered out at the end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ # Get the event from the store so that we end up with a FrozenEvent that we can
+ # give to filter_events_for_client. We need to do this now because the event won't
+ # be in the database anymore after it has expired.
+ events.append(self.get_success(store.get_event(resp.get("event_id"))))
+
+ # Advance the time by 2 days. We're using the default retention policy, therefore
+ # after this the first event will still be valid.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Send another event, which shouldn't get filtered out.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ events.append(self.get_success(store.get_event(valid_event_id)))
+
+ # Advance the time by anothe 2 days. After this, the first event should be
+ # outdated but not the second one.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Run filter_events_for_client with our list of FrozenEvents.
+ filtered_events = self.get_success(
+ filter_events_for_client(store, self.user_id, events)
+ )
+
+ # We should only get one event back.
+ self.assertEqual(len(filtered_events), 1, filtered_events)
+ # That event should be the second, not outdated event.
+ self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
+
+ def _test_retention_event_purged(self, room_id, increment):
+ # Get the create event to, later, check that we can still access it.
+ message_handler = self.hs.get_message_handler()
+ create_event = self.get_success(
+ message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ )
+
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ expired_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, expired_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time.
+ self.reactor.advance(increment / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ # Advance the time again. Now our first event should have expired but our second
+ # one should still be kept.
+ self.reactor.advance(increment / 1000)
+
+ # Check that the event has been purged from the database.
+ self.get_event(room_id, expired_event_id, expected_code=404)
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ valid_event = self.get_event(room_id, valid_event_id)
+ self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
+
+ # Check that we can still access state events that were sent before the event that
+ # has been purged.
+ self.get_event(room_id, create_event.event_id)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
+
+
+class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["default_room_version"] = "1"
+ config["retention"] = {"enabled": True}
+
+ mock_federation_client = Mock(spec=["backfill"])
+
+ self.hs = self.setup_test_homeserver(
+ config=config, federation_client=mock_federation_client
+ )
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_no_default_policy(self):
+ """Tests that an event doesn't get expired if there is neither a default retention
+ policy nor a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention(room_id)
+
+ def test_state_policy(self):
+ """Tests that an event gets correctly expired if there is no default retention
+ policy but there's a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the maximum lifetime to 35 days so that the first event gets expired but not
+ # the second one.
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 35},
+ tok=self.token,
+ )
+
+ self._test_retention(room_id, expected_code_for_first_event=404)
+
+ def _test_retention(self, room_id, expected_code_for_first_event=200):
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ first_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, first_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time by a month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ second_event_id = resp.get("event_id")
+
+ # Advance the time by another month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Check if the event has been purged from the database.
+ first_event = self.get_event(
+ room_id, first_event_id, expected_code=expected_code_for_first_event
+ )
+
+ if expected_code_for_first_event == 200:
+ self.assertEqual(
+ first_event.get("content", {}).get("body"), "1", first_event
+ )
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ second_event = self.get_event(room_id, second_event_id)
+ self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..d44f5c2c8c
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,721 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import random
+import string
+
+from mock import Mock
+
+from twisted.internet import defer
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+ ACCESS_RULES_TYPE,
+)
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(
+ spec=["get_json", "post_json_get_json"],
+ )
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default value."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default value."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right value."""
+ room_id = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=ACCESS_RULE_DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=ACCESS_RULE_RESTRICTED, expected_code=400)
+
+ def test_public_room(self):
+ """Tests that it's not possible to have a room with the public join rule and an
+ access rule that's not restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, ACCESS_RULE_DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # Changing join rule to public in an unrestricted room should fail.
+ self.change_join_rule_in_room(
+ self.unrestricted_room, JoinRules.PUBLIC, expected_code=403
+ )
+ # Changing join rule to public in an direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule that isn't
+ # restricted should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ # Creating a room with the public join rule in its initial state and an access
+ # rule that isn't restricted should fail.
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user which HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+ Also tests that a user from a HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Disable the 3pid invite ratelimiter
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to a room in which there's a 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+ """
+ # We can invite
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL and doesn't set
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that doesn't redefines the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can change the rule from restricted to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=403,
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ expected_code=200,
+ ):
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ json.dumps(content),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index a8adc9a61d..a3d7e3c046 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
yield Clock(reactor).sleep(0)
- defer.returnValue("yay")
+ return "yay"
@defer.inlineCallbacks
def test():
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 140d8b3772..02b4b8f5eb 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -229,6 +229,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_known_users"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9915367144..cdded88b7f 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -128,8 +128,12 @@ class RestHelper(object):
return channel.json_body
- def send_state(self, room_id, event_type, body, tok, expect_code=200):
- path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
+ def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
if tok:
path = path + "?access_token=%s" % tok
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 920de41de4..9fed900f4a 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -23,8 +23,8 @@ from email.parser import Parser
import pkg_resources
import synapse.rest.admin
-from synapse.api.constants import LoginType
-from synapse.rest.client.v1 import login
+from synapse.api.constants import LoginType, Membership
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from tests import unittest
@@ -244,6 +244,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
account.register_servlets,
+ room.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -279,3 +280,56 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request("GET", "account/whoami")
self.render(request)
self.assertEqual(request.code, 401)
+
+ @unittest.INFO
+ def test_pending_invites(self):
+ """Tests that deactivating a user rejects every pending invite for them."""
+ store = self.hs.get_datastore()
+
+ inviter_id = self.register_user("inviter", "test")
+ inviter_tok = self.login("inviter", "test")
+
+ invitee_id = self.register_user("invitee", "test")
+ invitee_tok = self.login("invitee", "test")
+
+ # Make @inviter:test invite @invitee:test in a new room.
+ room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
+ self.helper.invite(
+ room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
+ )
+
+ # Make sure the invite is here.
+ pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ self.assertEqual(len(pending_invites), 1, pending_invites)
+ self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
+
+ # Deactivate @invitee:test.
+ self.deactivate(invitee_id, invitee_tok)
+
+ # Check that the invite isn't there anymore.
+ pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ self.assertEqual(len(pending_invites), 0, pending_invites)
+
+ # Check that the membership of @invitee:test in the room is now "leave".
+ memberships = self.get_success(
+ store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+ )
+ self.assertEqual(len(memberships), 1, memberships)
+ self.assertEqual(memberships[0].room_id, room_id, memberships)
+
+ def deactivate(self, user_id, tok):
+ request_data = json.dumps(
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "test",
+ },
+ "erase": False,
+ }
+ )
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
new file mode 100644
index 0000000000..37f970c6b0
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, password_policy, register
+
+from tests import unittest
+
+
+class PasswordPolicyTestCase(unittest.HomeserverTestCase):
+ """Tests the password policy feature and its compliance with MSC2000.
+
+ When validating a password, Synapse does the necessary checks in this order:
+
+ 1. Password is long enough
+ 2. Password contains digit(s)
+ 3. Password contains symbol(s)
+ 4. Password contains uppercase letter(s)
+ 5. Password contains lowercase letter(s)
+
+ Therefore, each test in this test case that tests whether a password triggers the
+ right error code to be returned provides a password good enough to pass the previous
+ steps but not the one it's testing (nor any step that comes after).
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ password_policy.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.register_url = "/_matrix/client/r0/register"
+ self.policy = {
+ "enabled": True,
+ "minimum_length": 10,
+ "require_digit": True,
+ "require_symbol": True,
+ "require_lowercase": True,
+ "require_uppercase": True,
+ }
+
+ config = self.default_config()
+ config["password_config"] = {"policy": self.policy}
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_get_policy(self):
+ """Tests if the /password_policy endpoint returns the configured policy."""
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/password_policy"
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.minimum_length": 10,
+ "m.require_digit": True,
+ "m.require_symbol": True,
+ "m.require_lowercase": True,
+ "m.require_uppercase": True,
+ },
+ channel.result,
+ )
+
+ def test_password_too_short(self):
+ request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result
+ )
+
+ def test_password_no_digit(self):
+ request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result
+ )
+
+ def test_password_no_symbol(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result
+ )
+
+ def test_password_no_uppercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result
+ )
+
+ def test_password_no_lowercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result
+ )
+
+ def test_password_compliant(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ # Getting a 401 here means the password has passed validation and the server has
+ # responded with a list of registration flows.
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_password_change(self):
+ """This doesn't test every possible use case, only that hitting /account/password
+ triggers the password validation code.
+ """
+ compliant_password = "C0mpl!antpassword"
+ not_compliant_password = "notcompliantpassword"
+
+ user_id = self.register_user("kermit", compliant_password)
+ tok = self.login("kermit", compliant_password)
+
+ request_data = json.dumps(
+ {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
+ )
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/account/password",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 89a3f95c0a..9b7eb1e856 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
import json
import os
+from mock import Mock
+
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -200,6 +204,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -208,6 +253,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
login.register_servlets,
sync.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -300,6 +346,138 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Create a user to expire
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+
+ self.pump(1000)
+ self.reactor.advance(1000)
+ self.pump()
+
+ # Expire the user
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.pump(60 * 60 * 1000)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -323,6 +501,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"renew_at": 172800000, # Time in ms for 2 days
"renew_by_email_enabled": True,
"renew_email_subject": "Renew your account",
+ "account_renewed_html_path": "account_renewed.html",
+ "invalid_token_html_path": "invalid_token.html",
}
# Email config.
@@ -373,6 +553,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+ # Check that we're getting HTML back.
+ content_type = None
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type = header[1]
+ self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+
+ # Check that the HTML we're getting is the one we expect on a successful renewal.
+ expected_html = self.hs.config.account_validity.account_renewed_html_content
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
# Move 3 days forward. If the renewal failed, every authed request with
# our access token should be denied from now, otherwise they should
# succeed.
@@ -381,6 +574,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+ def test_renewal_invalid_token(self):
+ # Hit the renewal endpoint with an invalid token and check that it behaves as
+ # expected, i.e. that it responds with 404 Not Found and the correct HTML.
+ url = "/_matrix/client/unstable/account_validity/renew?token=123"
+ request, channel = self.make_request(b"GET", url)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"404", channel.result)
+
+ # Check that we're getting HTML back.
+ content_type = None
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type = header[1]
+ self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+
+ # Check that the HTML we're getting is the one we expect when using an
+ # invalid/unknown token.
+ expected_html = self.hs.config.account_validity.invalid_token_html_content
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
def test_manual_email_send(self):
self.email_attempts = []
@@ -414,7 +629,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(request.code, 200, channel.result)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..1accc70dc9
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request, render
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+ # User cannot invite external user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_can_room_with_single_invites(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ def test_cannot_3pid_invite(self):
+ """Test that unbound 3pid invites get rejected.
+ """
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+ def _create_room(self, token, content={}):
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ request, channel = make_request(
+ self.hs.get_reactor(),
+ "POST",
+ path,
+ content=json.dumps(content).encode("utf8"),
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ return channel
diff --git a/tests/server.py b/tests/server.py
index e573c4e4c5..d10c0603e9 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -387,11 +387,24 @@ class FakeTransport(object):
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
- self.disconnected = True
+
+ # if we still have data to write, delay until that is done
+ if self.buffer:
+ logger.info(
+ "FakeTransport: Delaying disconnect until buffer is flushed"
+ )
+ else:
+ self.disconnected = True
def abortConnection(self):
logger.info("FakeTransport: abortConnection()")
- self.loseConnection()
+
+ if not self.disconnecting:
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(None)
+
+ self.disconnected = True
def pauseProducing(self):
if not self.producer:
@@ -422,6 +435,9 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _produce)
def write(self, byt):
+ if self.disconnecting:
+ raise Exception("Writing to disconnecting FakeTransport")
+
self.buffer = self.buffer + byt
# always actually do the write asynchronously. Some protocols (notably the
@@ -465,3 +481,7 @@ class FakeTransport(object):
self.buffer = self.buffer[len(to_write) :]
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+
+ if not self.buffer and self.disconnecting:
+ logger.info("FakeTransport: Buffer now empty, completing disconnect")
+ self.disconnected = True
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index fbb9302694..9fabe3fbc0 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -43,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
"test_update",
progress,
)
- defer.returnValue(count)
+ return count
self.update_handler.side_effect = update
@@ -60,7 +60,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
- defer.returnValue(count)
+ return count
self.update_handler.side_effect = update
self.update_handler.reset_mock()
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index e07ff01201..95f309fbbc 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@
# limitations under the License.
import signedjson.key
+import unpaddedbase64
from twisted.internet.defer import Deferred
@@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest
-KEY_1 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+
+def decode_verify_key_base64(key_id: str, key_base64: str):
+ key_bytes = unpaddedbase64.decode_base64(key_base64)
+ return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
+
+
+KEY_1 = decode_verify_key_base64(
+ "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
)
-KEY_2 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+KEY_2 = decode_verify_key_base64(
+ "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 45824bd3b2..13e9f8ec09 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -34,9 +34,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
@@ -44,10 +42,8 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
self.assertEquals(
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 732a778fab..8488b6edc8 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,23 +17,21 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import create_room, setup_test_homeserver
+from tests.utils import create_room
-class RedactionTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(
- self.addCleanup, resource_for_federation=Mock(), http_client=None
+class RedactionTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ resource_for_federation=Mock(), http_client=None
)
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -42,11 +41,12 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
- yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+ self.get_success(
+ create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+ )
self.depth = 1
- @defer.inlineCallbacks
def inject_room_member(
self, room, user, membership, replaces_state=None, extra_content={}
):
@@ -63,15 +63,14 @@ class RedactionTestCase(unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.store.persist_event(event, context)
+ self.get_success(self.store.persist_event(event, context))
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
@@ -86,15 +85,14 @@ class RedactionTestCase(unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.store.persist_event(event, context)
+ self.get_success(self.store.persist_event(event, context))
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -108,20 +106,21 @@ class RedactionTestCase(unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.store.persist_event(event, context)
+ self.get_success(self.store.persist_event(event, context))
- @defer.inlineCallbacks
def test_redact(self):
- yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
- msg_event = yield self.inject_message(self.room1, self.u_alice, "t")
+ msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
# Check event has not been redacted:
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes(
{
@@ -136,11 +135,11 @@ class RedactionTestCase(unittest.TestCase):
# Redact event
reason = "Because I said so"
- yield self.inject_redaction(
- self.room1, msg_event.event_id, self.u_alice, reason
+ self.get_success(
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertEqual(msg_event.event_id, event.event_id)
@@ -164,15 +163,18 @@ class RedactionTestCase(unittest.TestCase):
event.unsigned["redacted_because"],
)
- @defer.inlineCallbacks
def test_redact_join(self):
- yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
- msg_event = yield self.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
+ msg_event = self.get_success(
+ self.inject_room_member(
+ self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
+ )
)
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertObjectHasAttributes(
{
@@ -187,13 +189,13 @@ class RedactionTestCase(unittest.TestCase):
# Redact event
reason = "Because I said so"
- yield self.inject_redaction(
- self.room1, msg_event.event_id, self.u_alice, reason
+ self.get_success(
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
# Check redaction
- event = yield self.store.get_event(msg_event.event_id)
+ event = self.get_success(self.store.get_event(msg_event.event_id))
self.assertTrue("redacted_because" in event.unsigned)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 73ed943f5a..c6e8196b91 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -67,7 +67,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
yield self.store.persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def test_one_member(self):
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 212a7ae765..5c2cf3c2db 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -65,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
yield self.store.persist_event(event, context)
- defer.returnValue(event)
+ return event
def assertStateMapEqual(self, s1, s2):
for t in s1:
diff --git a/tests/test_types.py b/tests/test_types.py
index 9ab5f829b0..7cb1f8acb4 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
from tests.utils import TestHomeServer
@@ -106,3 +113,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 118c3bd238..e0605dac2f 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -139,7 +139,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def inject_room_member(self, user_id, membership="join", extra_content={}):
@@ -161,7 +161,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
)
yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def inject_message(self, user_id, content=None):
@@ -182,7 +182,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
)
yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def test_large_room(self):
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 7807328e2f..5713870f48 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
+from synapse.util.caches.descriptors import cached
from tests import unittest
@@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
- # lookup should return the deferreds
- self.assertIs(cache.get("key1"), d1)
- self.assertIs(cache.get("key2"), d2)
+ # lookup should return observable deferreds
+ self.assertFalse(cache.get("key1").has_called())
+ self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
+
+ # for now at least, the cache will return real results rather than an
+ # observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
@@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
+ def test_cache_with_sync_exception(self):
+ """If the wrapped function throws synchronously, things should continue to work
+ """
+
+ class Cls(object):
+ @cached()
+ def fn(self, arg1):
+ raise SynapseError(100, "mai spoon iz too big!!1")
+
+ obj = Cls()
+
+ # this should fail immediately
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should result in a second exception
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@@ -159,7 +185,7 @@ class DescriptorTestCase(unittest.TestCase):
def inner_fn():
with PreserveLoggingContext():
yield complete_lookup
- defer.returnValue(1)
+ return 1
return inner_fn()
@@ -169,7 +195,7 @@ class DescriptorTestCase(unittest.TestCase):
c1.name = "c1"
r = yield obj.fn(1)
self.assertEqual(LoggingContext.current_context(), c1)
- defer.returnValue(r)
+ return r
def check_result(r):
self.assertEqual(r, 1)
@@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(LoggingContext.current_context(), c1)
+ # the cache should now be empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
obj = Cls()
# set off a deferred which will do a cache lookup
@@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
+ def test_cache_iterable(self):
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached(iterable=True)
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+
+ obj.mock.return_value = ["spam", "eggs"]
+ r = obj.fn(1, 2)
+ self.assertEqual(r, ["spam", "eggs"])
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = ["chips"]
+ r = obj.fn(1, 3)
+ self.assertEqual(r, ["chips"])
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached
+ self.assertEqual(len(obj.fn.cache.cache), 3)
+
+ r = obj.fn(1, 2)
+ self.assertEqual(r, ["spam", "eggs"])
+ r = obj.fn(1, 3)
+ self.assertEqual(r, ["chips"])
+ obj.mock.assert_not_called()
+
+ def test_cache_iterable_with_sync_exception(self):
+ """If the wrapped function throws synchronously, things should continue to work
+ """
+
+ class Cls(object):
+ @descriptors.cached(iterable=True)
+ def fn(self, arg1):
+ raise SynapseError(100, "mai spoon iz too big!!1")
+
+ obj = Cls()
+
+ # this should fail immediately
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should result in a second exception
+ d = obj.fn(1)
+ self.failureResultOf(d, SynapseError)
+
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
@@ -286,7 +370,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
# we want this to behave like an asynchronous function
yield run_on_reactor()
assert LoggingContext.current_context().request == "c1"
- defer.returnValue(self.mock(args1, arg2))
+ return self.mock(args1, arg2)
with LoggingContext() as c1:
c1.request = "c1"
@@ -334,7 +418,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
yield run_on_reactor()
- defer.returnValue(self.mock(args1, arg2))
+ return self.mock(args1, arg2)
obj = Cls()
invalidate0 = mock.Mock()
diff --git a/tests/utils.py b/tests/utils.py
index 99a3deae21..6350646263 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -361,7 +361,7 @@ def setup_test_homeserver(
if fed:
register_federation_servlets(hs, fed)
- defer.returnValue(hs)
+ return hs
def register_federation_servlets(hs, resource):
@@ -465,9 +465,9 @@ class MockHttpResource(HttpServer):
args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield func(mock_request, *args)
- defer.returnValue((code, response))
+ return (code, response)
except CodeMessageException as e:
- defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
+ return (e.code, cs_error(e.msg, code=e.errcode))
raise KeyError("No event can handle %s" % path)
|