diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index be20a89682..641093d349 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -30,6 +30,16 @@ class FrontendProxyTests(HomeserverTestCase):
def default_config(self):
c = super().default_config()
c["worker_app"] = "synapse.app.frontend_proxy"
+
+ c["worker_listeners"] = [
+ {
+ "type": "http",
+ "port": 8080,
+ "bind_addresses": ["0.0.0.0"],
+ "resources": [{"names": ["client"]}],
+ }
+ ]
+
return c
def test_listen_http_with_presence_enabled(self):
@@ -39,14 +49,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Presence is on
self.hs.config.use_presence = True
- config = {
- "port": 8080,
- "bind_addresses": ["0.0.0.0"],
- "resources": [{"names": ["client"]}],
- }
-
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
@@ -67,14 +71,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Presence is off
self.hs.config.use_presence = False
- config = {
- "port": 8080,
- "bind_addresses": ["0.0.0.0"],
- "resources": [{"names": ["client"]}],
- }
-
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 7364f9f1ec..0f016c32eb 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -18,6 +18,7 @@ from parameterized import parameterized
from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
+from synapse.config.server import parse_listener_def
from tests.unittest import HomeserverTestCase
@@ -35,6 +36,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# have to tell the FederationHandler not to try to access stuff that is only
# in the primary store.
conf["worker_app"] = "yes"
+
return conf
@parameterized.expand(
@@ -53,12 +55,13 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
"""
config = {
"port": 8080,
+ "type": "http",
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": names}],
}
# Listen with the config
- self.hs._listen_http(config)
+ self.hs._listen_http(parse_listener_def(config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -101,12 +104,13 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
"""
config = {
"port": 8080,
+ "type": "http",
"bind_addresses": ["0.0.0.0"],
"resources": [{"names": names}],
}
# Listen with the config
- self.hs._listener_http(config, config)
+ self.hs._listener_http(self.hs.get_config(), parse_listener_def(config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ff12539041..1a9bd5f37d 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -21,6 +21,7 @@ from signedjson.types import BaseKey, SigningKey
from twisted.internet import defer
+from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
@@ -536,7 +537,10 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
return {
"user_id": user_id,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"curve25519:" + device_id: "curve25519+key",
key_id(sk): encode_pubkey(sk),
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index e1e144b2e7..6c1dc72bd1 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
import synapse.handlers.e2e_keys
import synapse.storage
from synapse.api import errors
+from synapse.api.constants import RoomEncryptionAlgorithms
from tests import unittest, utils
@@ -222,7 +223,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_1 = {
"user_id": local_user,
"device_id": "abc",
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"ed25519:abc": "base64+ed25519+key",
"curve25519:abc": "base64+curve25519+key",
@@ -232,7 +236,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_2 = {
"user_id": local_user,
"device_id": "def",
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"ed25519:def": "base64+ed25519+key",
"curve25519:def": "base64+curve25519+key",
@@ -315,7 +322,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key = {
"user_id": local_user,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
"signatures": {local_user: {"ed25519:xyz": "something"}},
}
@@ -392,7 +402,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"device_id": device_id,
"algorithms": [
"m.olm.curve25519-aes-sha2",
- "m.megolm.v1.aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
"keys": {
"curve25519:xyz": "curve25519+key",
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 70f172eb02..822ea42dde 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -96,6 +96,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
version_etag = res["etag"]
+ self.assertIsInstance(version_etag, str)
del res["etag"]
self.assertDictEqual(
res,
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..0ab0356109
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.is_server_url = "https://testis"
+ self.rewritten_is_url = "https://int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_url: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.side_effect = defer.succeed({})
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_blacklisting_http_client.get_json.side_effect = defer.succeed({})
+ mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_blacklisting_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_handlers().identity_handler
+ post_json_get_json = handler.blacklisting_http_client.post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ client_secret=creds["client_secret"],
+ sid=creds["sid"],
+ mxid=self.user_id,
+ id_server=self.is_server_name,
+ use_v2=False,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,),
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ headers={},
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c6e..a1f4bde347 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -63,14 +63,12 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
-
self.handler = hs.get_profile_handler()
self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
displayname = yield self.handler.get_displayname(self.frank)
@@ -109,7 +107,7 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
@@ -152,8 +150,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield self.store.set_profile_displayname("caroline", "Caroline", 1)
response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
@@ -164,7 +161,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
avatar_url = yield self.handler.get_avatar_url(self.frank)
@@ -206,7 +203,7 @@ class ProfileTestCase(unittest.TestCase):
# Setting displayname for the first time is allowed
yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
self.assertEquals(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index ca32f993a3..a7f52067d0 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -20,8 +20,16 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha.register import (
+ _map_email_to_displayname,
+ register_servlets,
+)
from synapse.types import RoomAlias, UserID, create_requester
+from tests.server import FakeChannel
+from tests.unittest import override_config
+
from .. import unittest
@@ -33,6 +41,10 @@ class RegistrationHandlers(object):
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
+ servlets = [
+ register_servlets,
+ ]
+
def make_homeserver(self, reactor, clock):
hs_config = self.default_config()
@@ -266,6 +278,98 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
+ @override_config(
+ {
+ "bind_new_user_emails_to_sydent": "https://is.example.com",
+ "registrations_require_3pid": ["email"],
+ "account_threepid_delegates": {},
+ "email": {
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "http://localhost",
+ }
+ )
+ def test_user_email_bound_via_sydent_internal_api(self):
+ """Tests that emails are bound after registration if this option is set"""
+ # Register user with an email address
+ email = "alice@example.com"
+
+ # Mock Synapse's threepid validator
+ get_threepid_validation_session = Mock(
+ return_value=defer.succeed(
+ {"medium": "email", "address": email, "validated_at": 0}
+ )
+ )
+ self.store.get_threepid_validation_session = get_threepid_validation_session
+ delete_threepid_session = Mock(return_value=defer.succeed(None))
+ self.store.delete_threepid_session = delete_threepid_session
+
+ # Mock Synapse's http json post method to check for the internal bind call
+ post_json_get_json = Mock(return_value=defer.succeed(None))
+ self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
+
+ # Retrieve a UIA session ID
+ channel = self.uia_register(
+ 401, {"username": "alice", "password": "nobodywillguessthis"}
+ )
+ session_id = channel.json_body["session"]
+
+ # Register our email address using the fake validation session above
+ channel = self.uia_register(
+ 200,
+ {
+ "username": "alice",
+ "password": "nobodywillguessthis",
+ "auth": {
+ "session": session_id,
+ "type": "m.login.email.identity",
+ "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"},
+ },
+ },
+ )
+ self.assertEqual(channel.json_body["user_id"], "@alice:test")
+
+ # Check that a bind attempt was made to our fake identity server
+ post_json_get_json.assert_called_with(
+ "https://is.example.com/_matrix/identity/internal/bind",
+ {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"},
+ )
+
+ def uia_register(self, expected_response: int, body: dict) -> FakeChannel:
+ """Make a register request."""
+ request, channel = self.make_request(
+ "POST", "register", body
+ ) # type: SynapseRequest, FakeChannel
+ self.render(request)
+
+ self.assertEqual(request.code, expected_response)
+ return channel
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..07092f026a 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -21,8 +21,14 @@ from tests import unittest
# The expected number of state events in a fresh public room.
EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
# The expected number of state events in a fresh private room.
-EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+#
+# Note: we increase this by 2 on the dinsic branch as we send
+# a "im.vector.room.access_rules" state event into new private rooms,
+# and an encryption state event as all private rooms are encrypted
+# by default
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
class StatsRoomTests(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c15bce5bef..ddee8d9e3a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -17,12 +17,13 @@ from mock import Mock
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
+from tests.unittest import override_config
class UserDirectoryTestCase(unittest.HomeserverTestCase):
@@ -147,6 +148,94 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
+ @override_config({"encryption_enabled_by_default_for_room_type": "all"})
+ def test_encrypted_by_default_config_option_all(self):
+ """Tests that invite-only and non-invite-only rooms have encryption enabled by
+ default when the config option encryption_enabled_by_default_for_room_type is "all".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
+ def test_encrypted_by_default_config_option_invite(self):
+ """Tests that only new, invite-only rooms have encryption enabled by default when
+ the config option encryption_enabled_by_default_for_room_type is "invite".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "off"})
+ def test_encrypted_by_default_config_option_off(self):
+ """Tests that neither new invite-only nor non-invite-only rooms have encryption
+ enabled by default when the config option
+ encryption_enabled_by_default_for_room_type is "off".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
def test_spam_checker(self):
"""
A user which fails to the spam checks will not appear in search results.
@@ -460,3 +549,136 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ account.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+
+ # Set accounts to expire after a week
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ return config
+
+ def prepare(self, reactor, clock, hs):
+ super(UserInfoTestCase, self).prepare(reactor, clock, hs)
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def test_user_info(self):
+ """Test /users/info for local users from the Client-Server API"""
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request info about each user from user_three
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/client/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ access_token=user_three_token,
+ shorthand=False,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def test_user_info_federation(self):
+ """Test that /users/info can be called from the Federation API, and
+ and that we can query remote users from the Client-Server API
+ """
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request information about our local users from the perspective of a remote server
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/federation/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def setup_test_users(self):
+ """Create an admin user and three test users, each with a different state"""
+
+ # Create an admin user to expire other users with
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_token = self.login("admin", "adminpassword")
+
+ # Create three users
+ user_one = self.register_user("alice", "pass")
+ user_one_token = self.login("alice", "pass")
+ user_two = self.register_user("bob", "pass")
+ user_three = self.register_user("carl", "pass")
+ user_three_token = self.login("carl", "pass")
+
+ # Deactivate user_one
+ self.deactivate(user_one, user_one_token)
+
+ # Expire user_two
+ self.expire(user_two, admin_token)
+
+ # Do nothing to user_three
+
+ return user_one, user_two, user_three, user_three_token
+
+ def expire(self, user_id_to_expire, admin_tok):
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ request_data = {
+ "user_id": user_id_to_expire,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def deactivate(self, user_id, tok):
+ request_data = {
+ "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+ "erase": False,
+ }
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 562397cdda..c2752d57ae 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -86,13 +86,15 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_resolver = WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=self.tls_factory),
+ b"test-agent",
well_known_cache=self.well_known_cache,
had_well_known_cache=self.had_well_known_cache,
)
self.agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=self.tls_factory,
+ tls_client_options_factory=FederationPolicyForHTTPS(config),
+ user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@@ -186,6 +188,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
request = well_known_server.requests[0]
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+ )
self._send_well_known_response(request, content, headers=response_headers)
return well_known_server
@@ -231,6 +236,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]
)
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+ )
content = request.content.read()
self.assertEqual(content, b"")
@@ -719,10 +727,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=tls_factory,
+ user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=tls_factory),
+ b"test-agent",
well_known_cache=self.well_known_cache,
had_well_known_cache=self.had_well_known_cache,
),
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 9ae6a87d7b..af35d23aea 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -21,7 +21,7 @@ from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def setUp(self):
+ def _get_evaluator(self, content):
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -29,37 +29,58 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"sender": "@user:test",
"state_key": "",
"room_id": "@room:test",
- "content": {"body": "foo bar baz"},
+ "content": content,
},
RoomVersions.V1,
)
room_member_count = 0
sender_power_level = 0
power_levels = {}
- self.evaluator = PushRuleEvaluatorForEvent(
+ return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)
def test_display_name(self):
"""Check for a matching display name in the body of the event."""
+ evaluator = self._get_evaluator({"body": "foo bar baz"})
+
condition = {
"kind": "contains_display_name",
}
# Blank names are skipped.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+ self.assertFalse(evaluator.matches(condition, "@user:test", ""))
# Check a display name that doesn't match.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
# Check a display name which matches.
- self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# A display name that matches, but not a full word does not result in a match.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
# A display name should not be interpreted as a regular expression.
- self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+ self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
# A display name with spaces should work fine.
- self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
+
+ def test_no_body(self):
+ """Not having a body shouldn't break the evaluator."""
+ evaluator = self._get_evaluator({})
+
+ condition = {
+ "kind": "contains_display_name",
+ }
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ def test_invalid_body(self):
+ """A non-string body should not break the evaluator."""
+ condition = {
+ "kind": "contains_display_name",
+ }
+
+ for body in (1, True, {"foo": "bar"}):
+ evaluator = self._get_evaluator({"body": body})
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..cd8680e812 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 0},
+ {"highlight_count": 0, "notify_count": 0, "unread_count": 0},
)
self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 1},
+ {"highlight_count": 0, "notify_count": 1, "unread_count": 1},
)
self.persist(
@@ -188,7 +188,20 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "notify_count": 2},
+ {"highlight_count": 1, "notify_count": 2, "unread_count": 2},
+ )
+
+ self.persist(
+ type="m.room.message",
+ msgtype="m.text",
+ body="world",
+ push_actions=[(USER_ID_2, ["org.matrix.msc2625.mark_unread"])],
+ )
+ self.replicate()
+ self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2, event1.event_id],
+ {"highlight_count": 1, "notify_count": 2, "unread_count": 3},
)
def test_get_rooms_for_user_with_stream_ordering(self):
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..097e1653b4 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# also one state event
state_event = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -174,11 +170,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -327,11 +318,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
prev_events = [e.event_id]
pl_events.append(e)
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +364,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
+ def test_backwards_stream_id(self):
+ """
+ Test that RDATA that comes after the current position should be discarded.
+ """
+ # disconnect, so that we can stack up some changes
+ self.disconnect()
+
+ # Generate an events. We inject them using inject_event so that they are
+ # not send out over replication until we call self.replicate().
+ event = self._inject_test_event()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # We should have received the expected single row (as well as various
+ # cache invalidation updates which we ignore).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ # There should be a single received row.
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows[0]
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, event.event_id)
+
+ # Reset the data.
+ self.test_handler.received_rdata_rows = []
+
+ # Save the current token for later.
+ worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+ prev_token = worker_events_stream.current_token("master")
+
+ # Manually send an old RDATA command, which should get dropped. This
+ # re-uses the row from above, but with an earlier stream token.
+ self.hs.get_tcp_replication().send_command(
+ RdataCommand("events", "master", 1, row)
+ )
+
+ # No updates have been received (because it was discard as old).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+ self.assertEqual(len(received_rows), 0)
+
+ # Ensure the stream has not gone backwards.
+ current_token = worker_events_stream.current_token("master")
+ self.assertGreaterEqual(current_token, prev_token)
+
event_count = 0
def _inject_test_event(
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self):
typing = self.hs.get_typing_handler()
- room_id = "!bar:blue"
-
self.reconnect()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.test_handler.on_rdata.reset_mock()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([], row.user_ids)
+
+ def test_reset(self):
+ """
+ Test what happens when a typing stream resets.
+
+ This is emulated by jumping the stream ahead, then reconnecting (which
+ sends the proper position and RDATA).
+ """
+ typing = self.hs.get_typing_handler()
+
+ self.reconnect()
+
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+ self.reactor.advance(0)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([USER_ID], row.user_ids)
+
+ # Push the stream forward a bunch so it can be reset.
+ for i in range(100):
+ typing._push_update(
+ member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+ )
+ self.reactor.advance(0)
+
+ # Disconnect.
+ self.disconnect()
+
+ # Reset the typing handler
+ self.hs.get_replication_streams()["typing"].last_token = 0
+ self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+ typing._latest_room_serial = 0
+ typing._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", typing._latest_room_serial
+ )
+ typing._reset()
+
+ # Reconnect.
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ # Reset the test code.
+ self.test_handler.on_rdata.reset_mock()
+ self.test_handler.on_rdata.assert_not_called()
+
+ # Push additional data.
+ typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+ self.reactor.advance(0)
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0]
+ self.assertEqual(ROOM_ID_2, row.room_id)
self.assertEqual([], row.user_ids)
+
+ # The token should have been reset.
+ self.assertEqual(token, 1)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..4224b0a92e 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,24 +39,111 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.http_client = mock_http_client
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
+ # Replace the blacklisting SimpleHttpClient with our mock
+ self.hs.get_room_member_handler().simple_http_client = Mock(
+ spec=["get_json", "post_json_get_json"]
+ )
+ self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed(
+ (200, "{}")
+ )
+
params = {
"id_server": "testis",
"medium": "email",
@@ -58,7 +152,44 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ get_json = self.hs.get_handlers().identity_handler.http_client.get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 95475bb651..9e549d8a91 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
"default_policy": {
@@ -203,6 +204,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
}
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..7da0ef4e18
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,727 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import random
+import string
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+ ACCESS_RULES_TYPE,
+)
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default value."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default value."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right value."""
+ room_id = self.create_room(rule=ACCESS_RULE_UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, ACCESS_RULE_UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=ACCESS_RULE_DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=ACCESS_RULE_RESTRICTED, expected_code=400)
+
+ def test_public_room(self):
+ """Tests that it's not possible to have a room with the public join rule and an
+ access rule that's not restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED
+ )
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, ACCESS_RULE_DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # Changing join rule to public in an unrestricted room should fail.
+ self.change_join_rule_in_room(
+ self.unrestricted_room, JoinRules.PUBLIC, expected_code=403
+ )
+ # Changing join rule to public in an direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule that isn't
+ # restricted should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ # Creating a room with the public join rule in its initial state and an access
+ # rule that isn't restricted should fail.
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
+ )
+ self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user which HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+ Also tests that a user from a HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Disable the 3pid invite ratelimiter
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to a room in which there's a 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+ """
+ # We can invite
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL and doesn't set
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that doesn't redefines the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can change the rule from restricted to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=ACCESS_RULE_DIRECT, expected_code=403
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=403,
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ expected_code=200,
+ ):
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ json.dumps(content),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4886bbb401..5ccda8b2bd 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,9 +19,9 @@
"""Tests REST events for /rooms paths."""
import json
+from urllib import parse as urlparse
from mock import Mock
-from six.moves.urllib import parse as urlparse
from twisted.internet import defer
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7deaf5b24a..ceca4041e1 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
import json
import os
+from mock import Mock
+
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -87,14 +91,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
- def test_POST_bad_username(self):
- request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
-
- self.assertEquals(channel.result["code"], b"400", channel.result)
- self.assertEquals(channel.json_body["error"], "Invalid username")
-
def test_POST_user_valid(self):
user_id = "@kermit:test"
device_id = "frogfone"
@@ -303,6 +299,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(channel.json_body.get("sid"))
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -312,6 +349,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -437,6 +475,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
+
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+ # Expire the user
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.reactor.advance(60 * 60 * 1000)
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -587,7 +774,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(request.code, 200, channel.result)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index c7e5859970..fd641a7c2f 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -15,8 +15,7 @@
import itertools
import json
-
-import six
+import urllib
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
@@ -134,7 +133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
- channel.json_body.get("next_batch"), six.string_types, channel.json_body
+ channel.json_body.get("next_batch"), str, channel.json_body
)
def test_repeated_paginate_relations(self):
@@ -278,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
prev_token = None
found_event_ids = []
- encoded_key = six.moves.urllib.parse.quote_plus("๐".encode("utf-8"))
+ encoded_key = urllib.parse.quote_plus("๐".encode("utf-8"))
for _ in range(20):
from_token = ""
if prev_token:
@@ -670,7 +669,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
query = ""
if key:
- query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+ query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
original_id = parent_id if parent_id else self.parent_id
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1ca648ef2b..aefe648bdb 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -20,9 +20,9 @@ import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Optional
+from urllib import parse
from mock import Mock
-from six.moves.urllib import parse
import attr
import PIL.Image as Image
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..1accc70dc9
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request, render
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+ # User cannot invite external user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_can_room_with_single_invites(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ def test_cannot_3pid_invite(self):
+ """Test that unbound 3pid invites get rejected.
+ """
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+ def _create_room(self, token, content={}):
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ request, channel = make_request(
+ self.hs.get_reactor(),
+ "POST",
+ path,
+ content=json.dumps(content).encode("utf8"),
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ return channel
diff --git a/tests/server.py b/tests/server.py
index 1644710aa0..a5e57c52fa 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,8 +2,6 @@ import json
import logging
from io import BytesIO
-from six import text_type
-
import attr
from zope.interface import implementer
@@ -174,7 +172,7 @@ def make_request(
if not path.startswith(b"/"):
path = b"/" + path
- if isinstance(content, text_type):
+ if isinstance(content, str):
content = content.encode("utf8")
site = FakeSite()
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index a44960203e..cdc347bc53 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -15,8 +15,6 @@
import itertools
-from six.moves import zip
-
import attr
from synapse.api.constants import EventTypes, JoinRules, Membership
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..303dc8571c 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -22,6 +22,10 @@ import tests.utils
USER_ID = "@user:example.com"
+MARK_UNREAD = [
+ "org.matrix.msc2625.mark_unread",
+ {"set_tweak": "highlight", "value": False},
+]
PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
HIGHLIGHT = [
"notify",
@@ -55,13 +59,17 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
user_id = "@user1235:example.com"
@defer.inlineCallbacks
- def _assert_counts(noitf_count, highlight_count):
+ def _assert_counts(unread_count, notif_count, highlight_count):
counts = yield self.store.db.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count},
+ {
+ "unread_count": unread_count,
+ "notify_count": notif_count,
+ "highlight_count": highlight_count,
+ },
)
@defer.inlineCallbacks
@@ -96,23 +104,23 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
stream,
)
- yield _assert_counts(0, 0)
+ yield _assert_counts(0, 0, 0)
yield _inject_actions(1, PlAIN_NOTIF)
- yield _assert_counts(1, 0)
+ yield _assert_counts(1, 1, 0)
yield _rotate(2)
- yield _assert_counts(1, 0)
+ yield _assert_counts(1, 1, 0)
yield _inject_actions(3, PlAIN_NOTIF)
- yield _assert_counts(2, 0)
+ yield _assert_counts(2, 2, 0)
yield _rotate(4)
- yield _assert_counts(2, 0)
+ yield _assert_counts(2, 2, 0)
yield _inject_actions(5, PlAIN_NOTIF)
yield _mark_read(3, 3)
- yield _assert_counts(1, 0)
+ yield _assert_counts(1, 1, 0)
yield _mark_read(5, 5)
- yield _assert_counts(0, 0)
+ yield _assert_counts(0, 0, 0)
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
@@ -121,17 +129,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
- yield _assert_counts(1, 0)
+ yield _assert_counts(1, 1, 0)
yield _mark_read(7, 7)
- yield _assert_counts(0, 0)
+ yield _assert_counts(0, 0, 0)
- yield _inject_actions(8, HIGHLIGHT)
- yield _assert_counts(1, 1)
+ yield _inject_actions(8, MARK_UNREAD)
+ yield _assert_counts(1, 0, 0)
yield _rotate(9)
- yield _assert_counts(1, 1)
- yield _rotate(10)
- yield _assert_counts(1, 1)
+ yield _assert_counts(1, 0, 0)
+
+ yield _inject_actions(10, HIGHLIGHT)
+ yield _assert_counts(2, 1, 1)
+ yield _rotate(11)
+ yield _assert_counts(2, 1, 1)
+ yield _rotate(12)
+ yield _assert_counts(2, 1, 1)
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..0155ffd04e 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield self.store.create_profile(self.user.localpart)
- yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ yield self.store.set_profile_displayname(
+ self.user.localpart, self.displayname, 1
+ )
users, total = yield self.store.get_users_paginate(
0, 10, name="bc", guests=False
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..7458a37e54 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,9 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
@@ -43,10 +41,8 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
self.assertEquals(
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c662195eec..89dcc58b99 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -30,7 +30,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
room_creator = self.homeserver.get_room_creation_handler()
room_deferred = ensureDeferred(
room_creator.create_room(
- our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
diff --git a/tests/test_server.py b/tests/test_server.py
index e9a43b1e45..3f6f468e5b 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -14,8 +14,7 @@
import logging
import re
-
-from six import StringIO
+from io import StringIO
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
@@ -24,6 +23,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.config.server import parse_listener_def
from synapse.http.server import (
DirectServeResource,
JsonResource,
@@ -189,7 +189,13 @@ class OptionsResourceTests(unittest.TestCase):
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource.
- site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ site = SynapseSite(
+ "test",
+ "site_tag",
+ parse_listener_def({"type": "http", "port": 0}),
+ self.resource,
+ "1.0",
+ )
request.site = site
resource = site.getResourceFor(request)
@@ -348,7 +354,9 @@ class SiteTestCase(unittest.HomeserverTestCase):
# time out the request while it's 'processing'
base_resource = Resource()
base_resource.putChild(b"", HangingResource())
- site = SynapseSite("test", "site_tag", {}, base_resource, "1.0")
+ site = SynapseSite(
+ "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0"
+ )
server = site.buildProtocol(None)
client = AccumulatingProtocol()
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 5c2817cf28..b89798336c 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -14,7 +14,6 @@
import json
-import six
from mock import Mock
from twisted.test.proto_helpers import MemoryReactorClock
@@ -60,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertTrue(channel.json_body is not None)
- self.assertIsInstance(channel.json_body["session"], six.text_type)
+ self.assertIsInstance(channel.json_body["session"], str)
self.assertIsInstance(channel.json_body["flows"], list)
for flow in channel.json_body["flows"]:
@@ -125,6 +124,6 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertTrue(channel.json_body is not None)
- self.assertIsInstance(channel.json_body["user_id"], six.text_type)
- self.assertIsInstance(channel.json_body["access_token"], six.text_type)
- self.assertIsInstance(channel.json_body["device_id"], six.text_type)
+ self.assertIsInstance(channel.json_body["user_id"], str)
+ self.assertIsInstance(channel.json_body["access_token"], str)
+ self.assertIsInstance(channel.json_body["device_id"], str)
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
@@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("tรชst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
diff --git a/tests/unittest.py b/tests/unittest.py
index 6b6f224e9c..3175a3fa02 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -229,7 +229,7 @@ class HomeserverTestCase(TestCase):
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag="test",
- config={},
+ config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
)
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index e90e08d1c0..8d6627ec33 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -15,9 +15,9 @@
import threading
+from io import StringIO
from mock import NonCallableMock
-from six import StringIO
from twisted.internet import defer, reactor
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index ca3858b184..0e52811948 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six.moves import range
-
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
diff --git a/tests/utils.py b/tests/utils.py
index 59c020a051..4f0b67df9f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,9 +21,9 @@ import time
import uuid
import warnings
from inspect import getcallargs
+from urllib import parse as urlparse
from mock import Mock, patch
-from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor
@@ -168,6 +168,9 @@ def default_config(name, parse=False):
# background, which upsets the test runner.
"update_user_directory": False,
"caches": {"global_factor": 1},
+ "listeners": [{"port": 0, "type": "http"}],
+ # Enable encryption by default in private rooms
+ "encryption_enabled_by_default_for_room_type": "invite",
}
if parse:
|