diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
new file mode 100644
index 0000000000..38dc0e5aa2
--- /dev/null
+++ b/tests/federation/transport/test_knocking.py
@@ -0,0 +1,288 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Federation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+from typing import Dict, List
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import builder
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.server import HomeServer
+from synapse.types import RoomAlias
+
+from tests.test_utils import event_injection
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+# An identifier to use while MSC2304 is not in a stable release of the spec
+KNOCK_UNSTABLE_IDENTIFIER = "xyz.amorgan.knock"
+
+
+class KnockingStrippedStateEventHelperMixin(TestCase):
+ def send_example_state_events_to_room(
+ self, hs: "HomeServer", room_id: str, sender: str,
+ ) -> OrderedDict:
+ """Adds some state to a room. State events are those that should be sent to a knocking
+ user after they knock on the room, as well as some state that *shouldn't* be sent
+ to the knocking user.
+
+ Args:
+ hs: The homeserver of the sender.
+ room_id: The ID of the room to send state into.
+ sender: The ID of the user to send state as. Must be in the room.
+
+ Returns:
+ The OrderedDict of event types and content that a user is expected to see
+ after knocking on a room.
+ """
+ # To set a canonical alias, we'll need to point an alias at the room first.
+ canonical_alias = "#fancy_alias:test"
+ self.get_success(
+ self.store.create_room_alias_association(
+ RoomAlias.from_string(canonical_alias), room_id, ["test"]
+ )
+ )
+
+ # Send some state that we *don't* expect to be given to knocking users
+ self.get_success(
+ event_injection.inject_event(
+ hs,
+ room_version=RoomVersions.V7.identifier,
+ room_id=room_id,
+ sender=sender,
+ type="com.example.secret",
+ state_key="",
+ content={"secret": "password"},
+ )
+ )
+
+ # We use an OrderedDict here to ensure that the knock membership appears last.
+ # Note that order only matters when sending stripped state to clients, not federated
+ # homeservers.
+ room_state = OrderedDict(
+ [
+ # We need to set the room's join rules to allow knocking
+ (
+ EventTypes.JoinRules,
+ {"content": {"join_rule": JoinRules.KNOCK}, "state_key": ""},
+ ),
+ # Below are state events that are to be stripped and sent to clients
+ (
+ EventTypes.Name,
+ {"content": {"name": "A cool room"}, "state_key": ""},
+ ),
+ (
+ EventTypes.RoomAvatar,
+ {
+ "content": {
+ "info": {
+ "h": 398,
+ "mimetype": "image/jpeg",
+ "size": 31037,
+ "w": 394,
+ },
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ },
+ "state_key": "",
+ },
+ ),
+ (
+ EventTypes.RoomEncryption,
+ {"content": {"algorithm": "m.megolm.v1.aes-sha2"}, "state_key": ""},
+ ),
+ (
+ EventTypes.CanonicalAlias,
+ {
+ "content": {"alias": canonical_alias, "alt_aliases": []},
+ "state_key": "",
+ },
+ ),
+ ]
+ )
+
+ for event_type, event_dict in room_state.items():
+ event_content = event_dict["content"]
+ state_key = event_dict["state_key"]
+
+ self.get_success(
+ event_injection.inject_event(
+ hs,
+ room_version=RoomVersions.V7.identifier,
+ room_id=room_id,
+ sender=sender,
+ type=event_type,
+ state_key=state_key,
+ content=event_content,
+ )
+ )
+
+ return room_state
+
+ def check_knock_room_state_against_room_state(
+ self, knock_room_state: List[Dict], expected_room_state: Dict,
+ ) -> None:
+ """Test a list of stripped room state events received over federation against a
+ dict of expected state events.
+
+ Args:
+ knock_room_state: The list of room state that was received over federation.
+ expected_room_state: A dict containing the room state we expect to see in
+ `knock_room_state`.
+ """
+ for event in knock_room_state:
+ event_type = event["type"]
+
+ # Check that this event type is one of those that we expected.
+ # Note: This will also check that no excess state was included
+ self.assertIn(event_type, expected_room_state)
+
+ # Check the state content matches
+ self.assertEquals(
+ expected_room_state[event_type]["content"], event["content"]
+ )
+
+ # Check the state key is correct
+ self.assertEqual(
+ expected_room_state[event_type]["state_key"], event["state_key"]
+ )
+
+ # Ensure the event has been stripped
+ self.assertNotIn("signatures", event)
+
+ # Pop once we've found and processed a state event
+ expected_room_state.pop(event_type)
+
+ # Check that all expected state events were accounted for
+ self.assertEqual(len(expected_room_state), 0)
+
+
+class FederationKnockingTestCase(
+ FederatingHomeserverTestCase, KnockingStrippedStateEventHelperMixin
+):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+
+ # We're not going to be properly signing events as our remote homeserver is fake,
+ # therefore disable event signature checks.
+ # Note that these checks are not relevant to this test case.
+
+ # Have this homeserver auto-approve all event signature checking.
+ def approve_all_signature_checking(_, ev):
+ return [succeed(ev[0])]
+
+ homeserver.get_federation_server()._check_sigs_and_hashes = (
+ approve_all_signature_checking
+ )
+
+ # Have this homeserver skip event auth checks. This is necessary due to
+ # event auth checks ensuring that events were signed the sender's homeserver.
+ async def do_auth(origin, event, context, auth_events):
+ return context
+
+ homeserver.get_federation_handler().do_auth = do_auth
+
+ return super().prepare(reactor, clock, homeserver)
+
+ @override_config({"experimental_features": {"msc2403_enabled": True}})
+ def test_room_state_returned_when_knocking(self):
+ """
+ Tests that specific, stripped state events from a room are returned after
+ a remote homeserver successfully knocks on a local room.
+ """
+ user_id = self.register_user("u1", "you the one")
+ user_token = self.login("u1", "you the one")
+
+ fake_knocking_user_id = "@user:other.example.com"
+
+ # Create a room with a room version that includes knocking
+ room_id = self.helper.create_room_as(
+ "u1",
+ is_public=False,
+ room_version=RoomVersions.V7.identifier,
+ tok=user_token,
+ )
+
+ # Update the join rules and add additional state to the room to check for later
+ expected_room_state = self.send_example_state_events_to_room(
+ self.hs, room_id, user_id
+ )
+
+ _, channel = self.make_request(
+ "GET",
+ "/_matrix/federation/unstable/%s/make_knock/%s/%s?ver=%s"
+ % (
+ KNOCK_UNSTABLE_IDENTIFIER,
+ room_id,
+ fake_knocking_user_id,
+ # Inform the remote that we support the room version of the room we're
+ # knocking on
+ RoomVersions.V7.identifier,
+ ),
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Note: We don't expect the knock membership event to be sent over federation as
+ # part of the stripped room state, as the knocking homeserver already has that
+ # event. It is only done for clients during /sync
+
+ # Extract the generated knock event json
+ knock_event = channel.json_body["event"]
+
+ # Check that the event has things we expect in it
+ self.assertEquals(knock_event["room_id"], room_id)
+ self.assertEquals(knock_event["sender"], fake_knocking_user_id)
+ self.assertEquals(knock_event["state_key"], fake_knocking_user_id)
+ self.assertEquals(knock_event["type"], EventTypes.Member)
+ self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK)
+
+ # Turn the event json dict into a proper event.
+ # We won't sign it properly, but that's OK as we stub out event auth in `prepare`
+ signed_knock_event = builder.create_local_event_from_event_dict(
+ self.clock,
+ self.hs.hostname,
+ self.hs.signing_key,
+ room_version=RoomVersions.V7,
+ event_dict=knock_event,
+ )
+
+ # Convert our proper event back to json dict format
+ signed_knock_event_json = signed_knock_event.get_pdu_json(
+ self.clock.time_msec()
+ )
+
+ # Send the signed knock event into the room
+ _, channel = self.make_request(
+ "PUT",
+ "/_matrix/federation/unstable/%s/send_knock/%s/%s"
+ % (KNOCK_UNSTABLE_IDENTIFIER, room_id, signed_knock_event.event_id),
+ signed_knock_event_json,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check that we got the stripped room state in return
+ room_state_events = channel.json_body["knock_state_events"]
+
+ # Validate the stripped room state events
+ self.check_knock_room_state_against_room_state(
+ room_state_events, expected_room_state
+ )
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index a39f898608..ebc6a0866a 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,6 +42,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.mock_registry.register_query_handler = register_query_handler
hs = self.setup_test_homeserver(
+ federation_http_client=None,
+ resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..b7d340bcb8
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.is_server_url = "https://testis"
+ self.rewritten_is_url = "https://int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_url: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.side_effect = defer.succeed({})
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_blacklisting_http_client.get_json.side_effect = defer.succeed({})
+ mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().blacklisting_http_client = (
+ mock_blacklisting_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_identity_handler()
+ post_json_get_json = handler.blacklisting_http_client.post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ client_secret=creds["client_secret"],
+ sid=creds["sid"],
+ mxid=self.user_id,
+ id_server=self.is_server_name,
+ use_v2=False,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,),
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ headers={},
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 919547556b..50955ade97 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -63,7 +63,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
displayname = yield defer.ensureDeferred(
@@ -111,7 +111,7 @@ class ProfileTestCase(unittest.TestCase):
# Setting displayname for the first time is allowed
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
self.assertEquals(
@@ -164,7 +164,7 @@ class ProfileTestCase(unittest.TestCase):
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield defer.ensureDeferred(
- self.store.set_profile_displayname("caroline", "Caroline")
+ self.store.set_profile_displayname("caroline", "Caroline", 1)
)
response = yield defer.ensureDeferred(
@@ -179,7 +179,7 @@ class ProfileTestCase(unittest.TestCase):
def test_get_my_avatar(self):
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -230,7 +230,7 @@ class ProfileTestCase(unittest.TestCase):
# Setting displayname for the first time is allowed
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..9c39b140a1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,9 +18,15 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha.register import (
+ _map_email_to_displayname,
+ register_servlets,
+)
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
+from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@@ -31,6 +37,10 @@ from .. import unittest
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
+ servlets = [
+ register_servlets,
+ ]
+
def make_homeserver(self, reactor, clock):
hs_config = self.default_config()
@@ -517,6 +527,105 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
+ @override_config(
+ {
+ "bind_new_user_emails_to_sydent": "https://is.example.com",
+ "registrations_require_3pid": ["email"],
+ "account_threepid_delegates": {},
+ "email": {
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "http://localhost",
+ }
+ )
+ def test_user_email_bound_via_sydent_internal_api(self):
+ """Tests that emails are bound after registration if this option is set"""
+ # Register user with an email address
+ email = "alice@example.com"
+
+ # Mock Synapse's threepid validator
+ get_threepid_validation_session = Mock(
+ return_value=make_awaitable(
+ {"medium": "email", "address": email, "validated_at": 0}
+ )
+ )
+ self.store.get_threepid_validation_session = get_threepid_validation_session
+ delete_threepid_session = Mock(return_value=make_awaitable(None))
+ self.store.delete_threepid_session = delete_threepid_session
+
+ # Mock Synapse's http json post method to check for the internal bind call
+ post_json_get_json = Mock(return_value=make_awaitable(None))
+ self.hs.get_identity_handler().http_client.post_json_get_json = (
+ post_json_get_json
+ )
+
+ # Retrieve a UIA session ID
+ channel = self.uia_register(
+ 401, {"username": "alice", "password": "nobodywillguessthis"}
+ )
+ session_id = channel.json_body["session"]
+
+ # Register our email address using the fake validation session above
+ channel = self.uia_register(
+ 200,
+ {
+ "username": "alice",
+ "password": "nobodywillguessthis",
+ "auth": {
+ "session": session_id,
+ "type": "m.login.email.identity",
+ "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"},
+ },
+ },
+ )
+ self.assertEqual(channel.json_body["user_id"], "@alice:test")
+
+ # Check that a bind attempt was made to our fake identity server
+ post_json_get_json.assert_called_with(
+ "https://is.example.com/_matrix/identity/internal/bind",
+ {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"},
+ )
+
+ # Check that we stored a mapping of this bind
+ bound_threepids = self.get_success(
+ self.store.user_get_bound_threepids("@alice:test")
+ )
+ self.assertListEqual(bound_threepids, [{"medium": "email", "address": email}])
+
+ def uia_register(self, expected_response: int, body: dict) -> FakeChannel:
+ """Make a register request."""
+ request, channel = self.make_request(
+ "POST", "register", body
+ ) # type: SynapseRequest, FakeChannel
+
+ self.assertEqual(request.code, expected_response)
+ return channel
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 312c0a0d41..0229f58315 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -21,8 +21,14 @@ from tests import unittest
# The expected number of state events in a fresh public room.
EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
# The expected number of state events in a fresh private room.
-EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+#
+# Note: we increase this by 2 on the dinsic branch as we send
+# a "im.vector.room.access_rules" state event into new private rooms,
+# and an encryption state event as all private rooms are encrypted
+# by default
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
class StatsRoomTests(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 1260721dbf..dc433e1487 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -18,8 +18,9 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
@@ -46,6 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
+ self.event_builder_factory = self.hs.get_event_builder_factory()
+ self.event_creation_handler = self.hs.get_event_creation_handler()
def test_handle_local_profile_change_with_support_user(self):
support_user_id = "@support:test"
@@ -503,6 +506,97 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, u4, 10))
self.assertEqual(len(s["results"]), 1)
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "prefer_local_users": True,
+ }
+ }
+ )
+ def test_prefer_local_users(self):
+ """Tests that local users are shown higher in search results when
+ user_directory.prefer_local_users is True.
+ """
+ # Create a room and few users to test the directory with
+ searching_user = self.register_user("searcher", "password")
+ searching_user_tok = self.login("searcher", "password")
+
+ room_id = self.helper.create_room_as(
+ searching_user,
+ room_version=RoomVersions.V1.identifier,
+ tok=searching_user_tok,
+ )
+
+ # Create a few local users and join them to the room
+ local_user_1 = self.register_user("user_xxxxx", "password")
+ local_user_2 = self.register_user("user_bbbbb", "password")
+ local_user_3 = self.register_user("user_zzzzz", "password")
+
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_3)
+
+ # Create a few "remote" users and join them to the room
+ remote_user_1 = "@user_aaaaa:remote_server"
+ remote_user_2 = "@user_yyyyy:remote_server"
+ remote_user_3 = "@user_ccccc:remote_server"
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3)
+
+ local_users = [local_user_1, local_user_2, local_user_3]
+ remote_users = [remote_user_1, remote_user_2, remote_user_3]
+
+ # Populate the user directory via background update
+ self._add_background_updates()
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # The local searching user searches for the term "user", which other users have
+ # in their user id
+ results = self.get_success(
+ self.handler.search_users(searching_user, "user", 20)
+ )["results"]
+ received_user_id_ordering = [result["user_id"] for result in results]
+
+ # Typically we'd expect Synapse to return users in lexicographical order,
+ # assuming they have similar User IDs/display names, and profile information.
+
+ # Check that the order of returned results using our module is as we expect,
+ # i.e our local users show up first, despite all users having lexographically mixed
+ # user IDs.
+ [self.assertIn(user, local_users) for user in received_user_id_ordering[:3]]
+ [self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]]
+
+ def _add_user_to_room(
+ self, room_id: str, room_version: RoomVersion, user_id: str,
+ ):
+ # Add a user to the room.
+ builder = self.event_builder_factory.for_room_version(
+ room_version,
+ {
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "room_id": room_id,
+ "content": {"membership": "join"},
+ },
+ )
+
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.hs.get_storage().persistence.persist_event(event, context)
+ )
+
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
user_id = "@test:test"
@@ -547,3 +641,132 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ account.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+
+ # Set accounts to expire after a week
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ return config
+
+ def prepare(self, reactor, clock, hs):
+ super(UserInfoTestCase, self).prepare(reactor, clock, hs)
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def test_user_info(self):
+ """Test /users/info for local users from the Client-Server API"""
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request info about each user from user_three
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/client/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ access_token=user_three_token,
+ shorthand=False,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def test_user_info_federation(self):
+ """Test that /users/info can be called from the Federation API, and
+ and that we can query remote users from the Client-Server API
+ """
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request information about our local users from the perspective of a remote server
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/federation/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ )
+ self.assertEquals(200, channel.code)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def setup_test_users(self):
+ """Create an admin user and three test users, each with a different state"""
+
+ # Create an admin user to expire other users with
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_token = self.login("admin", "adminpassword")
+
+ # Create three users
+ user_one = self.register_user("alice", "pass")
+ user_one_token = self.login("alice", "pass")
+ user_two = self.register_user("bob", "pass")
+ user_three = self.register_user("carl", "pass")
+ user_three_token = self.login("carl", "pass")
+
+ # Deactivate user_one
+ self.deactivate(user_one, user_one_token)
+
+ # Expire user_two
+ self.expire(user_two, admin_token)
+
+ # Do nothing to user_three
+
+ return user_one, user_two, user_three, user_three_token
+
+ def expire(self, user_id_to_expire, admin_tok):
+ url = "/_synapse/admin/v1/account_validity/validity"
+ request_data = {
+ "user_id": user_id_to_expire,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def deactivate(self, user_id, tok):
+ request_data = {
+ "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+ "erase": False,
+ }
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.assertEqual(request.code, 200)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 626acdcaa3..aac8c7d623 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -102,7 +102,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=self.tls_factory,
+ tls_client_options_factory=FederationPolicyForHTTPS(config),
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22abf76515..25d07704e1 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -12,15 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
+import os
+from typing import Optional
+from unittest.mock import patch
import treq
+from netaddr import IPSet
from twisted.internet import interfaces # noqa: F401
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel
+from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.proxyagent import ProxyAgent
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
@@ -98,18 +104,115 @@ class MatrixFederationAgentTests(TestCase):
return http_protocol
+ def _test_request_direct_connection(self, agent, scheme, hostname, path):
+ """Runs a test case for a direct connection not going through a proxy.
+
+ Args:
+ agent (ProxyAgent): the proxy agent being tested
+
+ scheme (bytes): expected to be either "http" or "https"
+
+ hostname (bytes): the hostname to connect to in the test
+
+ path (bytes): the path to connect to in the test
+ """
+ is_https = scheme == b"https"
+
+ self.reactor.lookups[hostname.decode()] = "1.2.3.4"
+ d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443 if is_https else 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=is_https,
+ expected_sni=hostname if is_https else None,
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/" + path)
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
def test_http_request(self):
agent = ProxyAgent(self.reactor)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+ def test_http_request_use_proxy_empty_environment(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
+ def test_http_request_via_uppercase_no_proxy(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- self.reactor.lookups["test.com"] = "1.2.3.4"
+ @patch.dict(
+ os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
+ )
+ def test_http_request_via_no_proxy(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+ @patch.dict(
+ os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
+ )
+ def test_https_request_via_no_proxy(self):
+ agent = ProxyAgent(
+ self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ )
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
+ def test_http_request_via_no_proxy_star(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+ @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
+ def test_https_request_via_no_proxy_star(self):
+ agent = ProxyAgent(
+ self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ )
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
d = agent.request(b"GET", b"http://test.com")
# there should be a pending TCP connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
- self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 80)
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
# make a test server, and wire up the client
http_server = self._make_connection(
@@ -124,7 +227,7 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
- self.assertEqual(request.path, b"/")
+ self.assertEqual(request.path, b"http://test.com")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
request.write(b"result")
request.finish()
@@ -135,30 +238,99 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
- def test_https_request(self):
- agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+ @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
+ def test_https_request_via_proxy(self):
+ """Tests that TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials=None)
- self.reactor.lookups["test.com"] = "1.2.3.4"
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+ )
+ def test_https_request_via_proxy_with_auth(self):
+ """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+ def _do_https_request_via_proxy(
+ self, auth_credentials: Optional[str] = None,
+ ):
+ agent = ProxyAgent(
+ self.reactor, contextFactory=get_test_https_policy(), use_proxy=True,
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
d = agent.request(b"GET", b"https://test.com/abc")
# there should be a pending TCP connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
- self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 443)
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
- # make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- _get_test_protocol_factory(),
- ssl=True,
- expected_sni=b"test.com",
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
)
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # Check whether auth credentials have been supplied to the proxy
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+
+ if auth_credentials is not None:
+ # Compute the correct header value for Proxy-Authorization
+ encoded_credentials = base64.b64encode(b"bob:pinkponies")
+ expected_header_value = b"Basic " + encoded_credentials
+
+ # Validate the header's value
+ self.assertIn(expected_header_value, proxy_auth_header_values)
+ else:
+ # Check that the Proxy-Authorization header has not been supplied to the proxy
+ self.assertIsNone(proxy_auth_header_values)
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
# now there should be a pending request
self.assertEqual(len(http_server.requests), 1)
@@ -166,6 +338,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"/abc")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+
+ # Check that the destination server DID NOT receive proxy credentials
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+ self.assertIsNone(proxy_auth_header_values)
+
request.write(b"result")
request.finish()
@@ -175,8 +354,16 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
- def test_http_request_via_proxy(self):
- agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
+ def test_http_request_via_proxy_with_blacklist(self):
+ # The blacklist includes the configured proxy IP.
+ agent = ProxyAgent(
+ BlacklistingReactorWrapper(
+ self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ ),
+ self.reactor,
+ use_proxy=True,
+ )
self.reactor.lookups["proxy.com"] = "1.2.3.5"
d = agent.request(b"GET", b"http://test.com")
@@ -212,11 +399,16 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
- def test_https_request_via_proxy(self):
+ @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
+ def test_https_request_via_uppercase_proxy_with_blacklist(self):
+ # The blacklist includes the configured proxy IP.
agent = ProxyAgent(
+ BlacklistingReactorWrapper(
+ self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ ),
self.reactor,
contextFactory=get_test_https_policy(),
- https_proxy=b"proxy.com",
+ use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 60f0820cff..a3b304d316 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -401,8 +401,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_mention(self):
"""
@@ -477,8 +477,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_atroom(self):
"""
@@ -560,8 +560,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
)
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_push_unread_count_group_by_room(self):
"""
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c0a9fc6925..3fdb78d63d 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,18 +39,93 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.assertEqual(channel.result["code"], b"403", channel.result)
- channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed({"mxid": "@f:test"})
+ mock_http_client.post_json_get_json.return_value = defer.succeed({})
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().http_client = mock_http_client
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
+ channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=self.tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
@@ -55,6 +137,40 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+
+ get_json = self.hs.get_identity_handler().http_client.get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ self.make_request("GET", url, access_token=self.tok)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ self.make_request("POST", url, request_data, access_token=self.tok)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 31dc832fd5..10b1fbac69 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
"default_policy": {
@@ -243,6 +244,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
}
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..89016b82ce
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1071 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import random
+import string
+from typing import Optional
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.rest import admin
+from synapse.rest.client.v1 import directory, login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULES_TYPE,
+ AccessRules,
+ RoomAccessRules,
+)
+from synapse.types import JsonDict, create_requester
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ "freeze_room_with_no_admin": "true",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().blacklisting_http_client = mock_http_client
+
+ self.third_party_event_rules = self.hs.get_third_party_event_rules()
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right."""
+ room_id = self.create_room(rule=AccessRules.UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=AccessRules.DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400)
+
+ def test_create_room_default_power_level_rules(self):
+ """Tests that a room created with no power level overrides instead uses the dinum
+ defaults
+ """
+ room_id = self.create_room(direct=True, rule=AccessRules.DIRECT)
+ power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok)
+
+ # Inviting another user should require PL50, even in private rooms
+ self.assertEqual(power_levels["invite"], 50)
+ # Sending arbitrary state events should require PL100
+ self.assertEqual(power_levels["state_default"], 100)
+
+ def test_create_room_fails_on_incorrect_power_level_rules(self):
+ """Tests that a room created with power levels lower than that required are rejected"""
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ modified_power_levels["invite"] = 0
+ modified_power_levels["state_default"] = 50
+
+ self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ expected_code=400,
+ )
+
+ def test_create_room_with_missing_power_levels_use_default_values(self):
+ """
+ Tests that a room created with custom power levels, but without defining invite or state_default
+ succeeds, but the missing values are replaced with the defaults.
+ """
+
+ # Attempt to create a room without defining "invite" or "state_default"
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ del modified_power_levels["invite"]
+ del modified_power_levels["state_default"]
+ room_id = self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ )
+
+ # This should succeed, but the defaults should be put in place instead
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ self.assertEqual(room_power_levels["invite"], 50)
+ self.assertEqual(room_power_levels["state_default"], 100)
+
+ # And now the same test, but using power_levels_content_override instead
+ # of initial_state (which takes a slightly different codepath)
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ del modified_power_levels["invite"]
+ del modified_power_levels["state_default"]
+ room_id = self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ power_levels_content_override=modified_power_levels,
+ )
+
+ # This should succeed, but the defaults should be put in place instead
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ self.assertEqual(room_power_levels["invite"], 50)
+ self.assertEqual(room_power_levels["state_default"], 100)
+
+ def test_existing_room_can_change_power_levels(self):
+ """Tests that a room created with default power levels can have their power levels
+ dropped after room creation
+ """
+ # Creates a room with the default power levels
+ room_id = self.create_room(
+ direct=True, rule=AccessRules.DIRECT, expected_code=200,
+ )
+
+ # Attempt to drop invite and state_default power levels after the fact
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ room_power_levels["invite"] = 0
+ room_power_levels["state_default"] = 50
+ self.helper.send_state(
+ room_id, "m.room.power_levels", room_power_levels, self.tok
+ )
+
+ def test_public_room(self):
+ """Tests that it's only possible to have a room listed in the public room list
+ if the access rule is restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED
+ )
+
+ # List preset_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # List init_state_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.DIRECT, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule of direct
+ # should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=AccessRules.DIRECT,
+ expected_code=400,
+ )
+
+ # Changing join rule to public in an direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+
+ Also tests that the room can be published to, and removed from, the public room
+ list.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user which HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ # We are allowed to publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # We are allowed to remove the room from the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "private"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+
+ Tests that a user from a HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Disable the 3pid invite ratelimiter
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to a room in which there's a 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0]
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ # We can invite
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL and doesn't set
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that doesn't redefines the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ # We can change the rule from restricted to unrestricted.
+ # Note that this changes self.restricted_room to an unrestricted room
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't publish a room to the public room list and then change its rule to
+ # unrestricted
+
+ # Create a restricted room
+ test_room_id = self.create_room(rule=AccessRules.RESTRICTED)
+
+ # Publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Attempt to switch the room to "unrestricted"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Attempt to switch the room to "direct"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def test_check_event_allowed(self):
+ """Tests that RoomAccessRules.check_event_allowed behaves accordingly.
+
+ It tests that:
+ * forbidden users cannot join restricted rooms.
+ * forbidden users can only join unrestricted rooms if they have an invite.
+ """
+ event_creator = self.hs.get_event_creation_handler()
+
+ # Test that forbidden users cannot join restricted rooms
+ requester = create_requester(self.user_id)
+ allowed_requester = create_requester("@user:allowed_domain")
+ forbidden_requester = create_requester("@user:forbidden_domain")
+
+ # Assert a join event from a forbidden user to a restricted room is rejected
+ self.get_failure(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ ),
+ SynapseError,
+ )
+
+ # A join event from an non-forbidden user to a restricted room is allowed
+ self.get_success(
+ event_creator.create_event(
+ allowed_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": allowed_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": allowed_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Test that forbidden users can only join unrestricted rooms if they have an invite
+
+ # A forbidden user without an invite should not be able to join an unrestricted room
+ self.get_failure(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ ),
+ SynapseError,
+ )
+
+ # However, if we then invite this user...
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=requester.user.to_string(),
+ targ=forbidden_requester.user.to_string(),
+ tok=self.tok,
+ )
+
+ # And create another join event, making sure that its context states it's coming
+ # in after the above invite was made...
+ # Then the forbidden user should be able to join!
+ self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ def test_freezing_a_room(self):
+ """Tests that the power levels in a room change to prevent new events from
+ non-admin users when the last admin of a room leaves.
+ """
+
+ def freeze_room_with_id_and_power_levels(
+ room_id: str, custom_power_levels_content: Optional[JsonDict] = None,
+ ):
+ # Invite a user to the room, they join with PL 0
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ if not custom_power_levels_content:
+ # Retrieve the room's current power levels event content
+ power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+ else:
+ power_levels = custom_power_levels_content
+
+ # Override the room's power levels with the given power levels content
+ self.helper.send_state(
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ body=custom_power_levels_content,
+ tok=self.tok,
+ )
+
+ # Ensure that the invitee leaving the room does not change the power levels
+ self.helper.leave(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Retrieve the new power levels of the room
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+
+ # Ensure they have not changed
+ self.assertDictEqual(power_levels, new_power_levels)
+
+ # Invite the user back again
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Now the admin leaves the room
+ self.helper.leave(
+ room=room_id, user=self.user_id, tok=self.tok,
+ )
+
+ # Check the power levels again
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok,
+ )
+
+ # Ensure that the new power levels prevent anyone but admins from sending
+ # certain events
+ self.assertEquals(new_power_levels["state_default"], 100)
+ self.assertEquals(new_power_levels["events_default"], 100)
+ self.assertEquals(new_power_levels["kick"], 100)
+ self.assertEquals(new_power_levels["invite"], 100)
+ self.assertEquals(new_power_levels["ban"], 100)
+ self.assertEquals(new_power_levels["redact"], 100)
+ self.assertDictEqual(new_power_levels["events"], {})
+ self.assertDictEqual(new_power_levels["users"], {self.user_id: 100})
+
+ # Ensure new users entering the room aren't going to immediately become admins
+ self.assertEquals(new_power_levels["users_default"], 0)
+
+ # Test that freezing a room with the default power level state event content works
+ room1 = self.create_room()
+ freeze_room_with_id_and_power_levels(room1)
+
+ # Test that freezing a room with a power level state event that is missing
+ # `state_default` and `event_default` keys behaves as expected
+ room2 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room2,
+ {
+ "ban": 50,
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "invite": 0,
+ "kick": 50,
+ "redact": 50,
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `event_default` keys
+ },
+ )
+
+ # Test that freezing a room with a power level state event that is *additionally*
+ # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected
+ room3 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room3,
+ {
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `event_default` keys
+ # Explicitly remove `ban`, `invite`, `kick` and `redact` keys
+ },
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ power_levels_content_override=None,
+ expected_code=200,
+ ):
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ if power_levels_content_override:
+ content["power_levels_content_override"] = power_levels_content_override
+
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 566776e97e..18932d7518 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -475,8 +475,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
- def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
+ def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(token, secret, self.jwt_algorithm)
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
params = json.dumps(
@@ -680,8 +684,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = "RS256"
return self.hs
- def jwt_encode(self, token, secret=jwt_privatekey):
- return jwt.encode(token, secret, "RS256").decode("ascii")
+ def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(token, secret, "RS256")
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
params = json.dumps(
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 27db4f551e..721f34a7f7 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -18,9 +18,15 @@
import datetime
import json
import os
+import os.path
+import tempfile
+
+from mock import Mock
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -85,13 +91,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
- def test_POST_bad_username(self):
- request_data = json.dumps({"username": 777, "password": "monkey"})
- channel = self.make_request(b"POST", self.url, request_data)
-
- self.assertEquals(channel.result["code"], b"400", channel.result)
- self.assertEquals(channel.json_body["error"], "Invalid username")
-
def test_POST_user_valid(self):
user_id = "@kermit:test"
device_id = "frogfone"
@@ -289,6 +288,96 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(channel.json_body.get("sid"))
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
+class AccountValidityTemplateDirectoryTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Create a custom template directory and a template inside to read
+ temp_dir = tempfile.mkdtemp()
+ self.account_renewed_fd, account_renewed_path = tempfile.mkstemp(dir=temp_dir)
+ self.invalid_token_fd, invalid_token_path = tempfile.mkstemp(dir=temp_dir)
+
+ self.account_renewed_template_contents = "Yay, your account has been renewed"
+ self.invalid_token_template_contents = "Boo, you used an invalid token. Booo"
+
+ # Add some content to the custom templates
+ with open(account_renewed_path, "w") as f:
+ f.write(self.account_renewed_template_contents)
+
+ with open(invalid_token_path, "w") as f:
+ f.write(self.invalid_token_template_contents)
+
+ # Write the config, specifying the custom template directory and name of the custom
+ # template files. They must be different than those that exist in the default
+ # template directory in order to properly test everything.
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ "template_dir": temp_dir,
+ "account_renewed_html_path": os.path.basename(account_renewed_path),
+ "invalid_token_html_path": os.path.basename(invalid_token_path),
+ }
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def test_template_contents(self):
+ """Tests that the contents of the custom templates as specified in the config are
+ correct.
+ """
+ self.assertEquals(
+ self.hs.config.account_validity.account_validity_account_renewed_template.render(),
+ self.account_renewed_template_contents,
+ )
+
+ self.assertEquals(
+ self.hs.config.account_validity.account_validity_invalid_token_template.render(),
+ self.invalid_token_template_contents,
+ )
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -298,6 +387,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -408,6 +498,152 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
+
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+ # Expire the user
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.reactor.advance(60 * 60 * 1000)
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -470,8 +706,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- # Move 6 days forward. This should trigger a renewal email to be sent.
- self.reactor.advance(datetime.timedelta(days=6).total_seconds())
+ # Move 5 days forward. This should trigger a renewal email to be sent.
+ self.reactor.advance(datetime.timedelta(days=5).total_seconds())
self.assertEqual(len(self.email_attempts), 1)
# Retrieving the URL from the email is too much pain for now, so we
@@ -482,14 +718,32 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect on a successful renewal.
- expected_html = self.hs.config.account_validity.account_renewed_html_content
+ expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ expected_html = self.hs.config.account_validity_account_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
+ # Move 1 day forward. Try to renew with the same token again.
+ url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+ request, channel = self.make_request(b"GET", url)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Check that we're getting HTML back.
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
+
+ # Check that the HTML we're getting is the one we expect when reusing a
+ # token. The account expiration date should not have changed.
+ expected_html = self.hs.config.account_validity_account_previously_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -509,15 +763,12 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect when using an
# invalid/unknown token.
- expected_html = self.hs.config.account_validity.invalid_token_html_content
+ expected_html = self.hs.config.account_validity_invalid_token_template.render()
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -625,7 +876,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
config["account_validity"] = {"enabled": False}
self.hs = self.setup_test_homeserver(config=config)
- self.hs.config.account_validity.period = self.validity_period
+
+ # We need to set these directly, instead of in the homeserver config dict above.
+ # This is due to account validity-related config options not being read by
+ # Synapse when account_validity.enabled is False.
+ self.hs.get_datastore()._account_validity_period = self.validity_period
+ self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
self.store = self.hs.get_datastore()
@@ -639,8 +895,6 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- self.hs.config.account_validity.startup_job_max_delta = self.max_delta
-
now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 512e36c236..d0414864ab 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,12 +16,22 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
+from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import read_marker, sync
+from synapse.rest.client.v2_alpha import knock, read_marker, sync
from tests import unittest
+from tests.federation.transport.test_knocking import (
+ KnockingStrippedStateEventHelperMixin,
+)
from tests.server import TimedOutException
+from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@@ -306,6 +316,89 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.make_request("GET", sync_url % (access_token, next_batch))
+class SyncKnockTestCase(
+ unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin
+):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ knock.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to create the room to knock on).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll knock on.
+ self.room_id = self.helper.create_room_as(
+ self.user_id,
+ is_public=False,
+ room_version=RoomVersions.V7.identifier,
+ tok=self.tok,
+ )
+
+ # Register the second user (used to knock on the room).
+ self.knocker = self.register_user("knocker", "monkey")
+ self.knocker_tok = self.login("knocker", "monkey")
+
+ # Perform an initial sync for the knocking user.
+ _, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
+
+ # Set up some room state to test with.
+ self.expected_room_state = self.send_example_state_events_to_room(
+ hs, self.room_id, self.user_id
+ )
+
+ @override_config({"experimental_features": {"msc2403_enabled": True}})
+ def test_knock_room_state(self):
+ """Tests that /sync returns state from a room after knocking on it."""
+ # Knock on a room
+ _, channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/xyz.amorgan.knock/%s" % (self.room_id,),
+ b"{}",
+ self.knocker_tok,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # We expect to see the knock event in the stripped room state later
+ self.expected_room_state[EventTypes.Member] = {
+ "content": {"membership": Membership.KNOCK, "displayname": "knocker"},
+ "state_key": "@knocker:test",
+ }
+
+ # Check that /sync includes stripped state from the room
+ _, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.knocker_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Extract the stripped room state events from /sync
+ knock_entry = channel.json_body["rooms"][Membership.KNOCK]
+ room_state_events = knock_entry[self.room_id]["knock_state"]["events"]
+
+ # Validate that the knock membership event came last
+ self.assertEqual(room_state_events[-1]["type"], EventTypes.Member)
+
+ # Validate the stripped room state events
+ self.check_knock_room_state_against_room_state(
+ room_state_events, self.expected_room_state
+ )
+
+
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -439,7 +532,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
)
self._check_unread_count(5)
- def _check_unread_count(self, expected_count: True):
+ def _check_unread_count(self, expected_count: int):
"""Syncs and compares the unread count with the expected value."""
channel = self.make_request(
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..0937f304cf
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+ # User cannot invite external user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_can_room_with_single_invites(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ def test_cannot_3pid_invite(self):
+ """Test that unbound 3pid invites get rejected.
+ """
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+ def _create_room(self, token, content={}):
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ request, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ path,
+ content=json.dumps(content).encode("utf8"),
+ )
+
+ return channel
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..fe37d2ed5a 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -39,7 +39,7 @@ class DataStoreTestCase(unittest.TestCase):
)
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
)
users, total = yield defer.ensureDeferred(
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..7a38022e71 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -36,7 +36,7 @@ class ProfileStoreTestCase(unittest.TestCase):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
)
self.assertEquals(
@@ -54,7 +54,7 @@ class ProfileStoreTestCase(unittest.TestCase):
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
)
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
@@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
diff --git a/tests/utils.py b/tests/utils.py
index 977eeaf6ee..a3b01575dd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -173,6 +173,8 @@ def default_config(name, parse=False):
"update_user_directory": False,
"caches": {"global_factor": 1},
"listeners": [{"port": 0, "type": "http"}],
+ # Enable encryption by default in private rooms
+ "encryption_enabled_by_default_for_room_type": "invite",
}
if parse:
|