diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6121efcfa9..cc0b10e7f6 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
@@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
@@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(macaroon.serialize())
+ )
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
- user_info = yield self.auth.get_user_by_access_token(serialized)
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(serialized)
+ )
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
- self.store.add_access_token_to_user = Mock()
+ self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
+ self.store.get_device = Mock(return_value=defer.succeed(None))
- token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
- USER_ID, "DEVICE", valid_until_ms=None
+ token = yield defer.ensureDeferred(
+ self.hs.handlers.auth_handler.get_access_token_for_user_id(
+ USER_ID, "DEVICE", valid_until_ms=None
+ )
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
@@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)
@@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm:
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
@@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1
# Ensure no error thrown
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.hs.config.limit_usage_by_mau = True
@@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
@@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
- yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_reserved_threepid(self):
@@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.hs.config.mau_limits_reserved_threepids = [threepid]
- yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(threepid=unknown_threepid)
+ )
- yield self.auth.check_auth_blocking(threepid=threepid)
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase):
user = "@user:server"
self.hs.config.server_notices_mxid = user
self.hs.config.hs_disabled_message = "Reason for being disabled"
- yield self.auth.check_auth_blocking(user)
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index d3feafa1b7..be20a89682 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -27,8 +27,8 @@ class FrontendProxyTests(HomeserverTestCase):
return hs
- def default_config(self, name="test"):
- c = super().default_config(name)
+ def default_config(self):
+ c = super().default_config()
c["worker_app"] = "synapse.app.frontend_proxy"
return c
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 89fcc3889a..7364f9f1ec 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -29,8 +29,8 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
)
return hs
- def default_config(self, name="test"):
- conf = super().default_config(name)
+ def default_config(self):
+ conf = super().default_config()
# we're using FederationReaderServer, which uses a SlavedStore, so we
# have to tell the FederationHandler not to try to access stuff that is only
# in the primary store.
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
index 151d3006ac..f675bde68e 100644
--- a/tests/config/test_database.py
+++ b/tests/config/test_database.py
@@ -21,9 +21,9 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase):
- def test_database_configured_correctly_no_database_conf_param(self):
+ def test_database_configured_correctly(self):
conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", None)
+ DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
)
expected_database_conf = {
@@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase):
}
self.assertEqual(conf["database"], expected_database_conf)
-
- def test_database_configured_correctly_database_conf_param(self):
-
- database_conf = {
- "name": "my super fast datastore",
- "args": {
- "user": "matrix",
- "password": "synapse_database_password",
- "host": "synapse_database_host",
- "database": "matrix",
- },
- }
-
- conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", database_conf)
- )
-
- self.assertEqual(conf["database"], database_conf)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 34d5895f18..70c8e72303 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -34,6 +34,7 @@ from synapse.crypto.keyring import (
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.storage.keys import FetchKeyResult
@@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
def check_context(self, _, expected):
- self.assertEquals(
- getattr(LoggingContext.current_context(), "request", None), expected
- )
+ self.assertEquals(getattr(current_context(), "request", None), expected)
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
@@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
- self.assertEquals(LoggingContext.current_context().request, "11")
+ self.assertEquals(current_context().request, "11")
with PreserveLoggingContext():
yield persp_deferred
return persp_resp
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 24fa8dbb45..94980733c4 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -33,8 +33,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
login.register_servlets,
]
- def default_config(self, name="test"):
- config = super().default_config(name=name)
+ def default_config(self):
+ config = super().default_config()
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index d456267b87..33105576af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,19 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from mock import Mock
+from signedjson import key, sign
+from signedjson.types import BaseKey, SigningKey
+
from twisted.internet import defer
-from synapse.types import ReadReceipt
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.types import JsonDict, ReadReceipt
from tests.unittest import HomeserverTestCase, override_config
-class FederationSenderTestCases(HomeserverTestCase):
+class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- return super(FederationSenderTestCases, self).setup_test_homeserver(
+ return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@@ -147,3 +153,392 @@ class FederationSenderTestCases(HomeserverTestCase):
}
],
)
+
+
+class FederationSenderDevicesTestCases(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def default_config(self):
+ c = super().default_config()
+ c["send_federation"] = True
+ return c
+
+ def prepare(self, reactor, clock, hs):
+ # stub out get_current_hosts_in_room
+ mock_state_handler = hs.get_state_handler()
+ mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
+
+ # stub out get_users_who_share_room_with_user so that it claims that
+ # `@user2:host2` is in the room
+ def get_users_who_share_room_with_user(user_id):
+ return defer.succeed({"@user2:host2"})
+
+ hs.get_datastore().get_users_who_share_room_with_user = (
+ get_users_who_share_room_with_user
+ )
+
+ # whenever send_transaction is called, record the edu data
+ self.edus = []
+ self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.record_transaction
+ )
+
+ def record_transaction(self, txn, json_cb):
+ data = json_cb()
+ self.edus.extend(data["edus"])
+ return defer.succeed({})
+
+ def test_send_device_updates(self):
+ """Basic case: each device update should result in an EDU"""
+ # create a device
+ u1 = self.register_user("user", "pass")
+ self.login(u1, "pass", device_id="D1")
+
+ # expect one edu
+ self.assertEqual(len(self.edus), 1)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+ # a second call should produce no new device EDUs
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+ self.assertEqual(self.edus, [])
+
+ # a second device
+ self.login("user", "pass", device_id="D2")
+
+ self.assertEqual(len(self.edus), 1)
+ self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ def test_upload_signatures(self):
+ """Uploading signatures on some devices should produce updates for that user"""
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # register two devices
+ u1 = self.register_user("user", "pass")
+ self.login(u1, "pass", device_id="D1")
+ self.login(u1, "pass", device_id="D2")
+
+ # expect two edus
+ self.assertEqual(len(self.edus), 2)
+ stream_id = None
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ # upload signing keys for each device
+ device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
+ device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")
+
+ # expect two more edus
+ self.assertEqual(len(self.edus), 2)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ # upload master key and self-signing key
+ master_signing_key = generate_self_id_key()
+ master_key = {
+ "user_id": u1,
+ "usage": ["master"],
+ "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)},
+ }
+
+ # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
+ selfsigning_signing_key = generate_self_id_key()
+ selfsigning_key = {
+ "user_id": u1,
+ "usage": ["self_signing"],
+ "keys": {
+ key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key)
+ },
+ }
+ sign.sign_json(selfsigning_key, u1, master_signing_key)
+
+ cross_signing_keys = {
+ "master_key": master_key,
+ "self_signing_key": selfsigning_key,
+ }
+
+ self.get_success(
+ e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
+ )
+
+ # expect signing key update edu
+ self.assertEqual(len(self.edus), 1)
+ self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
+
+ # sign the devices
+ d1_json = build_device_dict(u1, "D1", device1_signing_key)
+ sign.sign_json(d1_json, u1, selfsigning_signing_key)
+ d2_json = build_device_dict(u1, "D2", device2_signing_key)
+ sign.sign_json(d2_json, u1, selfsigning_signing_key)
+
+ ret = self.get_success(
+ e2e_handler.upload_signatures_for_device_keys(
+ u1, {u1: {"D1": d1_json, "D2": d2_json}},
+ )
+ )
+ self.assertEqual(ret["failures"], {})
+
+ # expect two edus, in one or two transactions. We don't know in which order the
+ # devices will be updated.
+ self.assertEqual(len(self.edus), 2)
+ stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ if stream_id is not None:
+ self.assertEqual(c["prev_id"], [stream_id])
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2"}, devices)
+
+ def test_delete_devices(self):
+ """If devices are deleted, that should result in EDUs too"""
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # expect three edus
+ self.assertEqual(len(self.edus), 3)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ # expect three edus, in an unknown order
+ self.assertEqual(len(self.edus), 3)
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ self.assertGreaterEqual(
+ c.items(),
+ {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(),
+ )
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2", "D3"}, devices)
+
+ def test_unreachable_server(self):
+ """If the destination server is unreachable, all the updates should get sent on
+ recovery
+ """
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # for each device, there should be a single update
+ self.assertEqual(len(self.edus), 3)
+ stream_id = None
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
+ if stream_id is not None:
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2", "D3"}, devices)
+
+ def test_prune_outbound_device_pokes1(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server has never been reachable.
+ """
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # there should be a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
+ def test_prune_outbound_device_pokes2(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server was reachable, but then goes
+ offline.
+ """
+
+ # create first device
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+
+ # expect the update EDU
+ self.assertEqual(len(self.edus), 1)
+ self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+ # now the server goes offline
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 3)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # ... and we should get a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
+ def check_device_update_edu(
+ self,
+ edu: JsonDict,
+ user_id: str,
+ device_id: str,
+ prev_stream_id: Optional[int],
+ ) -> int:
+ """Check that the given EDU is an update for the given device
+ Returns the stream_id.
+ """
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ content = edu["content"]
+
+ expected = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": [prev_stream_id] if prev_stream_id is not None else [],
+ }
+
+ self.assertLessEqual(expected.items(), content.items())
+ if prev_stream_id is not None:
+ self.assertGreaterEqual(content["stream_id"], prev_stream_id)
+ return content["stream_id"]
+
+ def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
+ """Check that the txn has an EDU with a signing key update.
+ """
+ edus = txn["edus"]
+ self.assertEqual(len(edus), 1)
+
+ def generate_and_upload_device_signing_key(
+ self, user_id: str, device_id: str
+ ) -> SigningKey:
+ """Generate a signing keypair for the given device, and upload it"""
+ sk = key.generate_signing_key(device_id)
+
+ device_dict = build_device_dict(user_id, device_id, sk)
+
+ self.get_success(
+ self.hs.get_e2e_keys_handler().upload_keys_for_user(
+ user_id, device_id, {"device_keys": device_dict},
+ )
+ )
+ return sk
+
+
+def generate_self_id_key() -> SigningKey:
+ """generate a signing key whose version is its public key
+
+ ... as used by the cross-signing-keys.
+ """
+ k = key.generate_signing_key("x")
+ k.version = encode_pubkey(k)
+ return k
+
+
+def key_id(k: BaseKey) -> str:
+ return "%s:%s" % (k.alg, k.version)
+
+
+def encode_pubkey(sk: SigningKey) -> str:
+ """Encode the public key corresponding to the given signing key as base64"""
+ return key.encode_verify_key_base64(key.get_verify_key(sk))
+
+
+def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
+ """Build a dict representing the given device"""
+ return {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "keys": {
+ "curve25519:" + device_id: "curve25519+key",
+ key_id(sk): encode_pubkey(sk),
+ },
+ }
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b03103d96f..52c4ac8b11 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -82,16 +82,16 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@defer.inlineCallbacks
@@ -99,8 +99,10 @@ class AuthTestCase(unittest.TestCase):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
self.assertEqual("a_user", user_id)
@@ -109,20 +111,26 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
@@ -133,16 +141,20 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
@@ -154,16 +166,20 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
@@ -172,8 +188,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
@@ -181,8 +199,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
@@ -193,15 +213,19 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
def _get_macaroon(self):
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5e40adba52..00bb776271 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -102,6 +102,68 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+class TestCreateAlias(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_handlers().directory_handler
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ # Create a test room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.test_alias = "#test:test"
+ self.room_alias = RoomAlias.from_string(self.test_alias)
+
+ # Create a test user.
+ self.test_user = self.register_user("user", "pass", admin=False)
+ self.test_user_tok = self.login("user", "pass")
+ self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
+
+ def test_create_alias_joined_room(self):
+ """A user can create an alias for a room they're in."""
+ self.get_success(
+ self.handler.create_association(
+ create_requester(self.test_user), self.room_alias, self.room_id,
+ )
+ )
+
+ def test_create_alias_other_room(self):
+ """A user cannot create an alias for a room they're NOT in."""
+ other_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.get_failure(
+ self.handler.create_association(
+ create_requester(self.test_user), self.room_alias, other_room_id,
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
+ def test_create_alias_admin(self):
+ """An admin can create an alias for a room they're NOT in."""
+ other_room_id = self.helper.create_room_as(
+ self.test_user, tok=self.test_user_tok
+ )
+
+ self.get_success(
+ self.handler.create_association(
+ create_requester(self.admin_user), self.room_alias, other_room_id,
+ )
+ )
+
+
class TestDeleteAlias(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d60c124eec..be665262c6 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock
from twisted.internet import defer
import synapse.types
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
@@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_profile_handler()
+ self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
@@ -90,6 +91,33 @@ class ProfileTestCase(unittest.TestCase):
"Frank Jr.",
)
+ # Set displayname again
+ yield self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank"
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ )
+
+ @defer.inlineCallbacks
+ def test_set_my_name_if_disabled(self):
+ self.hs.config.enable_set_displayname = False
+
+ # Setting displayname for the first time is allowed
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+
+ self.assertEquals(
+ (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ )
+
+ # Setting displayname a second time is forbidden
+ d = self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+ )
+
+ yield self.assertFailure(d, SynapseError)
+
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
@@ -147,3 +175,38 @@ class ProfileTestCase(unittest.TestCase):
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
"http://my.server/pic.gif",
)
+
+ # Set avatar again
+ yield self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/me.png",
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ "http://my.server/me.png",
+ )
+
+ @defer.inlineCallbacks
+ def test_set_my_avatar_if_disabled(self):
+ self.hs.config.enable_set_avatar_url = False
+
+ # Setting avatar for the first time is allowed
+ yield self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ "http://my.server/me.png",
+ )
+
+ # Setting avatar a second time is forbidden
+ d = self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/pic.gif",
+ )
+
+ yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e2915eb7b1..f1dc51d6c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
def make_homeserver(self, reactor, clock):
- hs_config = self.default_config("test")
+ hs_config = self.default_config()
# some of the tests rely on us having a user consent version
hs_config["user_consent"] = {
@@ -294,7 +294,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
create_profile_with_displayname=user.localpart,
)
else:
- yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+ yield defer.ensureDeferred(
+ self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+ )
yield self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index fdc1d918ff..562397cdda 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import (
WellKnownResolver,
_cache_period_from_headers,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
@@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
- _check_logcontext(LoggingContext.sentinel)
+ _check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d
@@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
def _check_logcontext(context):
- current = LoggingContext.current_context()
+ current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index df034ab237..babc201643 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertNoResult(resolve_d)
# should have reset to the sentinel context
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
result = yield resolve_d
# should have restored our context
- self.assertIs(LoggingContext.current_context(), ctx)
+ self.assertIs(current_context(), ctx)
return result
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 2b01f40a42..fff4f0cbf4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
def check_logcontext(context):
- current = LoggingContext.current_context()
+ current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
@@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
- check_logcontext(LoggingContext.sentinel)
+ check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
new file mode 100644
index 0000000000..9ae6a87d7b
--- /dev/null
+++ b/tests/push/test_push_rule_evaluator.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+
+from tests import unittest
+
+
+class PushRuleEvaluatorTestCase(unittest.TestCase):
+ def setUp(self):
+ event = FrozenEvent(
+ {
+ "event_id": "$event_id",
+ "type": "m.room.history_visibility",
+ "sender": "@user:test",
+ "state_key": "",
+ "room_id": "@room:test",
+ "content": {"body": "foo bar baz"},
+ },
+ RoomVersions.V1,
+ )
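+ # the evaluator also needs some room context; trivial values suffice for these tests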
+ room_member_count = 0
+ sender_power_level = 0
+ power_levels = {}
+ self.evaluator = PushRuleEvaluatorForEvent(
+ event, room_member_count, sender_power_level, power_levels
+ )
+
+ def test_display_name(self):
+ """Check for a matching display name in the body of the event."""
+ condition = {
+ "kind": "contains_display_name",
+ }
+
+ # Blank names are skipped.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+
+ # Check a display name that doesn't match.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+
+ # Check a display name which matches.
+ self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+
+ # A display name that matches, but not as a full word, does not result in a match.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+
+ # A display name should not be interpreted as a regular expression.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+
+ # A display name with spaces should work fine.
+ self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 2a1e7c7166..395c7d0306 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -16,9 +16,10 @@
from mock import Mock, NonCallableMock
from synapse.replication.tcp.client import (
- ReplicationClientFactory,
- ReplicationClientHandler,
+ DirectTcpReplicationClientFactory,
+ ReplicationDataHandler,
)
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.storage.database import make_conn
@@ -51,15 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
+ self.streamer = hs.get_replication_streamer()
- handler_factory = Mock()
- self.replication_handler = ReplicationClientHandler(self.slaved_store)
- self.replication_handler.factory = handler_factory
+ # We now do some gut wrenching so that we have a client that is based
+ # off of the slave store rather than the main store.
+ self.replication_handler = ReplicationCommandHandler(self.hs)
+ self.replication_handler._replication_data_handler = ReplicationDataHandler(
+ self.slaved_store
+ )
- client_factory = ReplicationClientFactory(
+ client_factory = DirectTcpReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
+ client_factory.handler = self.replication_handler
server = server_factory.buildProtocol(None)
client = client_factory.buildProtocol(None)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index e96ad4ca4e..32238fe79a 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from mock import Mock
-from synapse.replication.tcp.commands import ReplicateCommand
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -25,23 +26,46 @@ from tests.server import FakeTransport
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ def make_homeserver(self, reactor, clock):
+ self.test_handler = Mock(wraps=TestReplicationDataHandler())
+ return self.setup_test_homeserver(replication_data_handler=self.test_handler)
+
def prepare(self, reactor, clock, hs):
# build a replication server
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
- server = server_factory.buildProtocol(None)
-
- # build a replication client, with a dummy handler
- handler_factory = Mock()
- self.test_handler = TestReplicationClientHandler()
- self.test_handler.factory = handler_factory
+ server_factory = ReplicationStreamProtocolFactory(hs)
+ self.streamer = hs.get_replication_streamer()
+ self.server = server_factory.buildProtocol(None)
+
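+ # build a replication client, with a command handler that feeds rows to our test handler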
+ repl_handler = ReplicationCommandHandler(hs)
+ repl_handler.handler = self.test_handler
self.client = ClientReplicationStreamProtocol(
- "client", "test", clock, self.test_handler
+ hs, "client", "test", clock, repl_handler,
)
- # wire them together
- self.client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(self.client, reactor))
+ self._client_transport = None
+ self._server_transport = None
+
+ def reconnect(self):
+ if self._client_transport:
+ self.client.close()
+
+ if self._server_transport:
+ self.server.close()
+
+ self._client_transport = FakeTransport(self.server, self.reactor)
+ self.client.makeConnection(self._client_transport)
+
+ self._server_transport = FakeTransport(self.client, self.reactor)
+ self.server.makeConnection(self._server_transport)
+
+ def disconnect(self):
+ if self._client_transport:
+ self._client_transport = None
+ self.client.close()
+
+ if self._server_transport:
+ self._server_transport = None
+ self.server.close()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -50,29 +74,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
- def replicate_stream(self, stream, token="NOW"):
- """Make the client end a REPLICATE command to set up a subscription to a stream"""
- self.client.send_command(ReplicateCommand(stream, token))
-
-class TestReplicationClientHandler(object):
- """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
+class TestReplicationDataHandler:
+ """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self):
- self.received_rdata_rows = []
+ self.streams = set()
+ self._received_rdata_rows = []
def get_streams_to_replicate(self):
- return {}
-
- def get_currently_syncing_users(self):
- return []
-
- def update_connection(self, connection):
- pass
-
- def finished_connecting(self):
- pass
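+ # resume each stream we are interested in from the latest token we have seen for it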
+ positions = {s: 0 for s in self.streams}
+ for stream, token, _ in self._received_rdata_rows:
+ if stream in self.streams:
+ positions[stream] = max(token, positions.get(stream, 0))
+ return positions
async def on_rdata(self, stream_name, token, rows):
for r in rows:
- self.received_rdata_rows.append((stream_name, token, r))
+ self._received_rdata_rows.append((stream_name, token, r))
+
+ async def on_position(self, stream_name, token):
+ pass
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index d5a99f6caa..a0206f7363 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -12,35 +12,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.tcp.streams._base import ReceiptsStreamRow
+from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
+ self.reconnect()
+
# make the client subscribe to the receipts stream
- self.replicate_stream("receipts", "NOW")
+ self.test_handler.streams.add("receipts")
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+ "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
# there should be one RDATA command
- rdata_rows = self.test_handler.received_rdata_rows
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- self.assertEqual(rdata_rows[0][0], "receipts")
- row = rdata_rows[0][2] # type: ReceiptsStreamRow
- self.assertEqual(ROOM_ID, row.room_id)
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
- self.assertEqual(EVENT_ID, row.event_id)
+ self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ self.get_success(
+ self.hs.get_datastore().insert_receipt(
+ "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ )
+ )
+ self.replicate()
+
+ # Nothing should have happened as we are disconnected
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now have caught up and get the missing data
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
+ self.assertEqual(token, 3)
+ self.assertEqual(1, len(rdata_rows))
+
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room2:blue", row.room_id)
+ self.assertEqual("m.read", row.receipt_type)
+ self.assertEqual(USER_ID, row.user_id)
+ self.assertEqual("$event2:foo", row.event_id)
+ self.assertEqual({"a": 2}, row.data)
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
new file mode 100644
index 0000000000..3cbcb513cc
--- /dev/null
+++ b/tests/replication/tcp/test_commands.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.replication.tcp.commands import (
+ RdataCommand,
+ ReplicateCommand,
+ parse_command_from_line,
+)
+
+from tests.unittest import TestCase
+
+
+class ParseCommandTestCase(TestCase):
+ def test_parse_one_word_command(self):
+ line = "REPLICATE"
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, ReplicateCommand)
+
+ def test_parse_rdata(self):
+ line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "events")
+ self.assertEqual(cmd.token, 6287863)
+
+ def test_parse_rdata_batch(self):
+ line = 'RDATA presence batch ["@foo:example.com", "online"]'
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "presence")
+ self.assertIsNone(cmd.token)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0342aed416..977615ebef 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -17,7 +17,6 @@ import json
import os
import urllib.parse
from binascii import unhexlify
-from typing import List, Optional
from mock import Mock
@@ -27,7 +26,7 @@ import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
-from synapse.rest.client.v1 import directory, events, login, room
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
@@ -51,129 +50,6 @@ class VersionTestCase(unittest.HomeserverTestCase):
)
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
class DeleteGroupTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -273,86 +149,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
return channel.json_body["groups"]
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
-
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Test /quarantine_media admin API.
"""
@@ -691,389 +487,3 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
% server_and_media_id_2
),
)
-
-
-class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- directory.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- # Create user
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_list_rooms(self):
- """Test that we can list rooms"""
- # Create 3 test rooms
- total_rooms = 3
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- # Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
-
- # Check that response json body contains a "rooms" key
- self.assertTrue(
- "rooms" in channel.json_body,
- msg="Response body does not " "contain a 'rooms' key",
- )
-
- # Check that 3 rooms were returned
- self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
-
- # Check their room_ids match
- returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
- self.assertEqual(room_ids, returned_room_ids)
-
- # Check that all fields are available
- for r in channel.json_body["rooms"]:
- self.assertIn("name", r)
- self.assertIn("canonical_alias", r)
- self.assertIn("joined_members", r)
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # Should be 0 as we aren't paginating
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that the prev_batch parameter is not present
- self.assertNotIn("prev_batch", channel.json_body)
-
- # We shouldn't receive a next token here as there's no further rooms to show
- self.assertNotIn("next_batch", channel.json_body)
-
- def test_list_rooms_pagination(self):
- """Test that we can get a full list of rooms through pagination"""
- # Create 5 test rooms
- total_rooms = 5
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Set the name of the rooms so we get a consistent returned ordering
- for idx, room_id in enumerate(room_ids):
- self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- returned_room_ids = []
- start = 0
- limit = 2
-
- run_count = 0
- should_repeat = True
- while should_repeat:
- run_count += 1
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
- start,
- limit,
- "alphabetical",
- )
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- self.assertTrue("rooms" in channel.json_body)
- for r in channel.json_body["rooms"]:
- returned_room_ids.append(r["room_id"])
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # We're only getting 2 rooms each page, so should be 2 * last run_count
- self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
-
- if run_count > 1:
- # Check the value of prev_batch is correct
- self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
-
- if "next_batch" not in channel.json_body:
- # We have reached the end of the list
- should_repeat = False
- else:
- # Make another query with an updated start value
- start = channel.json_body["next_batch"]
-
- # We should've queried the endpoint 3 times
- self.assertEqual(
- run_count,
- 3,
- msg="Should've queried 3 times for 5 rooms with limit 2 per query",
- )
-
- # Check that we received all of the room ids
- self.assertEqual(room_ids, returned_room_ids)
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- def test_correct_room_attributes(self):
- """Test the correct attributes for a room are returned"""
- # Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- test_alias = "#test:test"
- test_room_name = "something"
-
- # Have another user join the room
- user_2 = self.register_user("user4", "pass")
- user_tok_2 = self.login("user4", "pass")
- self.helper.join(room_id, user_2, tok=user_tok_2)
-
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=self.admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=self.admin_user_tok,
- )
-
- # Set a name for the room
- self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that only one room was returned
- self.assertEqual(len(rooms), 1)
-
- # And that the value of the total_rooms key was correct
- self.assertEqual(channel.json_body["total_rooms"], 1)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that all provided attributes are set
- r = rooms[0]
- self.assertEqual(room_id, r["room_id"])
- self.assertEqual(test_room_name, r["name"])
- self.assertEqual(test_alias, r["canonical_alias"])
-
- def test_room_list_sort_order(self):
- """Test room list sort ordering. alphabetical versus number of members,
- reversing the order, etc.
- """
- # Create 3 test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
- )
-
- # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
- user_1 = self.register_user("bob1", "pass")
- user_1_tok = self.login("bob1", "pass")
- self.helper.join(room_id_2, user_1, tok=user_1_tok)
-
- user_2 = self.register_user("bob2", "pass")
- user_2_tok = self.login("bob2", "pass")
- self.helper.join(room_id_3, user_2, tok=user_2_tok)
-
- user_3 = self.register_user("bob3", "pass")
- user_3_tok = self.login("bob3", "pass")
- self.helper.join(room_id_3, user_3, tok=user_3_tok)
-
- def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
- ):
- """Request the list of rooms in a certain order. Assert that order is what
- we expect
-
- Args:
- order_type: The type of ordering to give the server
- expected_room_list: The list of room_ids in the order we expect to get
- back from the server
- """
- # Request the list of rooms in the given order
- url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
- if reverse:
- url += "&dir=b"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check for the correct total_rooms value
- self.assertEqual(channel.json_body["total_rooms"], 3)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that rooms were returned in alphabetical order
- returned_order = [r["room_id"] for r in rooms]
- self.assertListEqual(expected_room_list, returned_order) # order is checked
-
- # Test different sort orders, with forward and reverse directions
- _order_test("alphabetical", [room_id_1, room_id_2, room_id_3])
- _order_test("alphabetical", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("size", [room_id_3, room_id_2, room_id_1])
- _order_test("size", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- def test_search_term(self):
- """Test that searching for a room works correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- def _search_test(
- expected_room_id: Optional[str],
- search_term: str,
- expected_http_code: int = 200,
- ):
- """Search for a room and check that the returned room's id is a match
-
- Args:
- expected_room_id: The room_id expected to be returned by the API. Set
- to None to expect zero results for the search
- search_term: The term to search for room names with
- expected_http_code: The expected http code for the request
- """
- url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
-
- if expected_http_code != 200:
- return
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that the expected number of rooms were returned
- expected_room_count = 1 if expected_room_id else 0
- self.assertEqual(len(rooms), expected_room_count)
- self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- if expected_room_id:
- # Check that the first returned room id is correct
- r = rooms[0]
- self.assertEqual(expected_room_id, r["room_id"])
-
- # Perform search tests
- _search_test(room_id_1, "something")
- _search_test(room_id_1, "thing")
-
- _search_test(room_id_2, "else")
- _search_test(room_id_2, "se")
-
- _search_test(None, "foo")
- _search_test(None, "bar")
- _search_test(None, "", expected_http_code=400)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
new file mode 100644
index 0000000000..249c93722f
--- /dev/null
+++ b/tests/rest/admin/test_room.py
@@ -0,0 +1,966 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
+ )
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert there is now no longer anyone in the room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not " "contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+ self.assertIn("joined_local_members", r)
+ self.assertIn("version", r)
+ self.assertIn("creator", r)
+ self.assertIn("encryption", r)
+ self.assertIn("federatable", r)
+ self.assertIn("public", r)
+ self.assertIn("join_rules", r)
+ self.assertIn("guest_access", r)
+ self.assertIn("history_visibility", r)
+ self.assertIn("state_events", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there's no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
+
+ run_count = 0
+ should_repeat = True
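+ # Page through the room list until the response no longer includes a next_batch token.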
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "name",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're only getting 2 rooms each page, so should be 2 * last run_count
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering. alphabetical name versus number of members,
+ reversing the order, etc.
+ """
+
+ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (
+ urllib.parse.quote(test_alias),
+ )
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in the expected order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set the rooms' canonical aliases
+ _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+ # Set the number of members in reverse order: room 1 -> 1 member, 2 -> 2, 3 -> 3
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("name", [room_id_1, room_id_2, room_id_3])
+ _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+ _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+ _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+ _order_test(
+ "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("version", [room_id_1, room_id_2, room_id_3])
+ _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("creator", [room_id_1, room_id_2, room_id_3])
+ _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("public", [room_id_1, room_id_2, room_id_3])
+ # Sort order differs between SQLite and PostgreSQL
+ # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+ _order_test(
+ "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+ _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
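+ # Register an admin and two regular users, and create a public room to join users into.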
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.second_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If a required parameter is missing, an error 400 is returned.
+ """
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_local_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ body = json.dumps({"user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_remote_user(self):
+ """
+ Check that only local users can be joined to rooms.
+ """
+ body = json.dumps({"user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "This endpoint can only be used with local users",
+ channel.json_body["error"],
+ )
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/servers return an error 404.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/!unknown:test"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("No known servers", channel.json_body["error"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names return an error 400.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/invalidroom"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom was not legal room ID or room alias",
+ channel.json_body["error"],
+ )
+
+ def test_join_public_room(self):
+ """
+ Test joining a local user to a public room with "JoinRules.PUBLIC"
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+ # Check that the user is now a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_not_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE"
+ when the server admin is not a member of the room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_join_private_room_if_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when the server admin is a member of the room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ self.helper.invite(
+ room=private_room_id,
+ src=self.creator,
+ targ=self.admin_user,
+ tok=self.creator_tok,
+ )
+ self.helper.join(
+ room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+ )
+
+ # Check that the server admin is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ # Join user to room.
+
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Check that the user is now a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_owner(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when the server admin is the owner of the room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Check that the user is now a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
new file mode 100644
index 0000000000..913ea3c98e
--- /dev/null
+++ b/tests/rest/client/test_power_levels.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
+
+from tests.unittest import HomeserverTestCase
+
+
+class PowerLevelsTestCase(HomeserverTestCase):
+ """Tests that power levels are enforced in various situations"""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ # register a room admin, moderator and regular user
+ self.admin_user_id = self.register_user("admin", "pass")
+ self.admin_access_token = self.login("admin", "pass")
+ self.mod_user_id = self.register_user("mod", "pass")
+ self.mod_access_token = self.login("mod", "pass")
+ self.user_user_id = self.register_user("user", "pass")
+ self.user_access_token = self.login("user", "pass")
+
+ # Create a room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # Invite the other users
+ self.helper.invite(
+ room=self.room_id,
+ src=self.admin_user_id,
+ tok=self.admin_access_token,
+ targ=self.mod_user_id,
+ )
+ self.helper.invite(
+ room=self.room_id,
+ src=self.admin_user_id,
+ tok=self.admin_access_token,
+ targ=self.user_user_id,
+ )
+
+ # Make the other users join the room
+ self.helper.join(
+ room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token
+ )
+ self.helper.join(
+ room=self.room_id, user=self.user_user_id, tok=self.user_access_token
+ )
+
+ # Mod the mod
+ room_power_levels = self.helper.get_state(
+ self.room_id, "m.room.power_levels", tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with mod at PL50
+ room_power_levels["users"].update({self.mod_user_id: 50})
+
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ )
+
+ def test_non_admins_cannot_enable_room_encryption(self):
+ # have the mod try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_non_admins_cannot_send_server_acl(self):
+ # have the mod try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_non_admins_cannot_tombstone_room(self):
+ # Create another room that will serve as our "upgraded room"
+ self.upgraded_room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # have the mod try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_admins_can_enable_room_encryption(self):
+ # have the admin try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
+
+ def test_admins_can_send_server_acl(self):
+ # have the admin try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
+
+ def test_admins_can_tombstone_room(self):
+ # Create another room that will serve as our "upgraded room"
+ self.upgraded_room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # have the admin try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index a3d7e3c046..171632e195 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -2,7 +2,7 @@ from mock import Mock, call
from twisted.internet import defer, reactor
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.util import Clock
@@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test():
with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
- self.assertIs(LoggingContext.current_context(), c1)
+ self.assertIs(current_context(), c1)
self.assertEqual(res, "yay")
# run the test twice in parallel
d = defer.gatherResults([test(), test()])
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
yield d
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
@defer.inlineCallbacks
def test_does_not_cache_exceptions(self):
@@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_does_not_cache_failures(self):
@@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self):
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index ffb2de1505..b54b06482b 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -50,7 +50,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
return hs
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor, clock, hs):
# register an account
self.user_id = self.register_user("sid1", "pass")
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index da2c9bfa1e..1856c7ffd5 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -257,7 +257,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.code, 200, channel.result)
-class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
+class CASTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
@@ -274,6 +274,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
"service_url": "https://matrix.goodserver.com:8448",
}
+ cas_user_id = "username"
+ self.user_id = "@%s:test" % cas_user_id
+
async def get_raw(uri, args):
"""Return an example response payload from a call to the `/proxyValidate`
endpoint of a CAS server, copied from
@@ -282,10 +285,11 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
This needs to be returned by an async function (as opposed to set as the
mock's return value) because the corresponding Synapse code awaits on it.
"""
- return """
+ return (
+ """
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationSuccess>
- <cas:user>username</cas:user>
+ <cas:user>%s</cas:user>
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
<cas:proxies>
<cas:proxy>https://proxy2/pgtUrl</cas:proxy>
@@ -294,6 +298,8 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
</cas:authenticationSuccess>
</cas:serviceResponse>
"""
+ % cas_user_id
+ )
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
@@ -304,6 +310,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.deactivate_account_handler = hs.get_deactivate_account_handler()
+
def test_cas_redirect_confirm(self):
"""Tests that the SSO login flow serves a confirmation page before redirecting a
user to the redirect URL.
@@ -350,7 +359,14 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
def test_cas_redirect_whitelisted(self):
"""Tests that the SSO login flow serves a redirect to a whitelisted url
"""
- redirect_url = "https://legit-site.com/"
+ self._test_redirect("https://legit-site.com/")
+
+ @override_config({"public_baseurl": "https://example.com"})
+ def test_cas_redirect_login_fallback(self):
+ self._test_redirect("https://example.com/_matrix/static/client/login")
+
+ def _test_redirect(self, redirect_url):
+ """Tests that the SSO login flow serves a redirect for the given redirect URL."""
cas_ticket_url = (
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
% (urllib.parse.quote(redirect_url))
@@ -363,3 +379,30 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
+
+ @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
+ def test_deactivated_user(self):
+ """Logging in as a deactivated account should error."""
+ redirect_url = "https://legit-site.com/"
+
+ # First login (to create the user).
+ self._test_redirect(redirect_url)
+
+ # Deactivate the account.
+ self.get_success(
+ self.deactivate_account_handler.deactivate_account(self.user_id, False)
+ )
+
+ # Request the CAS ticket.
+ cas_ticket_url = (
+ "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
+ % (urllib.parse.quote(redirect_url))
+ )
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ # Because the user is deactivated they are served an error template.
+ self.assertEqual(channel.code, 403)
+ self.assertIn(b"SSO account deactivated", channel.result["body"])
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 873d5ef99c..371637618d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -18,6 +18,7 @@
import json
import time
+from typing import Any, Dict, Optional
import attr
@@ -142,7 +143,34 @@ class RestHelper(object):
return channel.json_body
- def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""):
+ def _read_write_state(
+ self,
+ room_id: str,
+ event_type: str,
+ body: Optional[Dict[str, Any]],
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ method: str = "GET",
+ ) -> Dict:
+ """Read or write some state from a given room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ body: Body that is sent when making the request. The content of the state event.
+ If None, the request to the server will have an empty body
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+ method: "GET" or "PUT" for reading or writing state, respectively
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
room_id,
event_type,
@@ -151,9 +179,13 @@ class RestHelper(object):
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8")
- )
+ # Set request body if provided
+ content = b""
+ if body is not None:
+ content = json.dumps(body).encode("utf8")
+
+ request, channel = make_request(self.hs.get_reactor(), method, path, content)
+
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
@@ -163,6 +195,62 @@ class RestHelper(object):
return channel.json_body
+ def get_state(
+ self,
+ room_id: str,
+ event_type: str,
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ ):
+ """Gets some state from a room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
+ return self._read_write_state(
+ room_id, event_type, None, tok, expect_code, state_key, method="GET"
+ )
+
+ def send_state(
+ self,
+ room_id: str,
+ event_type: str,
+ body: Dict[str, Any],
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ ):
+ """Set some state in a room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ body: Body that is sent when making the request. The content of the state event.
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
+ return self._read_write_state(
+ room_id, event_type, body, tok, expect_code, state_key, method="PUT"
+ )
+
def upload_media(
self,
resource: Resource,
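The new helpers above are meant to be called from a HomeserverTestCase via its `self.helper` attribute; a rough usage sketch (the room ID, access token and event types here are placeholders, not values from this patch):

# inside a HomeserverTestCase test method, with `room_id` and `tok` set up
# earlier in the test (placeholder names)
self.helper.send_state(room_id, "m.room.topic", {"topic": "hello"}, tok=tok)

body = self.helper.get_state(room_id, "m.room.topic", tok=tok)
assert body == {"topic": "hello"}

# reads of state that does not exist can assert on the failure code instead
self.helper.get_state(room_id, "m.room.name", tok=tok, expect_code=404)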
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index c3facc00eb..45a9d445f8 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,6 +24,7 @@ import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
+from synapse.api.errors import Codes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
@@ -325,3 +326,304 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(request.code, 200)
+
+
+class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ account.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Email config.
+ self.email_attempts = []
+
+ def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ self.email_attempts.append(msg)
+
+ config["email"] = {
+ "enable_notifs": False,
+ "template_dir": os.path.abspath(
+ pkg_resources.resource_filename("synapse", "res/templates")
+ ),
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ }
+ config["public_baseurl"] = "https://example.com"
+
+ self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.user_id = self.register_user("kermit", "test")
+ self.user_id_tok = self.login("kermit", "test")
+ self.email = "test@example.com"
+ self.url_3pid = b"account/3pid"
+
+ def test_add_email(self):
+ """Test adding an email to profile
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+
+ def test_add_email_if_disabled(self):
+ """Test adding email to profile when doing so is disallowed
+ """
+ self.hs.config.enable_3pid_changes = False
+
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_delete_email(self):
+ """Test deleting an email from profile
+ """
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address=self.email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/delete",
+ {"medium": "email", "address": self.email},
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_delete_email_if_disabled(self):
+ """Test deleting an email from profile when disallowed
+ """
+ self.hs.config.enable_3pid_changes = False
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address=self.email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/delete",
+ {"medium": "email", "address": self.email},
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+
+ def test_cant_add_email_without_clicking_link(self):
+ """Test that we do actually need to click the link in the email
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+
+ # Attempt to add email without clicking the link
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_no_valid_token(self):
+ """Test that we do actually need to request a token and can't just
+ make a session up.
+ """
+ client_secret = "foobar"
+ session_id = "weasle"
+
+ # Attempt to add email without even requesting an email
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def _request_token(self, email, client_secret):
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ return channel.json_body["sid"]
+
+ def _validate_token(self, link):
+ # Remove the host
+ path = link.replace("https://example.com", "")
+
+ request, channel = self.make_request("GET", path, shorthand=False)
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def _get_link_from_email(self):
+ assert self.email_attempts, "No emails have been sent"
+
+ raw_msg = self.email_attempts[-1].decode("UTF-8")
+ mail = Parser().parsestr(raw_msg)
+
+ text = None
+ for part in mail.walk():
+ if part.get_content_type() == "text/plain":
+ text = part.get_payload(decode=True).decode("UTF-8")
+ break
+
+ if not text:
+ self.fail("Could not find text portion of email to parse")
+
+ match = re.search(r"https://example.com\S+", text)
+ assert match, "Could not find link in email"
+
+ return match.group(0)
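Stripped of the test harness, the flow these new tests exercise looks roughly like the sketch below. This is a plain-HTTP illustration only: the `requests` client, base URL and helper name are assumptions, while the endpoint paths and payloads are the ones used by the tests above.

import requests  # illustration only; the tests use the in-memory test channel

BASE = "https://example.com/_matrix/client"  # assumed homeserver base URL


def add_threepid_email(email, client_secret, access_token, auth_dict):
    # Step 1: ask the homeserver to start an email validation session.
    resp = requests.post(
        BASE + "/r0/account/3pid/email/requestToken",
        json={"client_secret": client_secret, "email": email, "send_attempt": 1},
    )
    sid = resp.json()["sid"]

    # Step 2: the user must click the validation link sent by email;
    # without it, step 3 fails with errcode M_THREEPID_AUTH_FAILED.

    # Step 3: bind the validated email to the account (requires UI auth).
    resp = requests.post(
        BASE + "/unstable/account/3pid/add",
        headers={"Authorization": "Bearer " + access_token},
        json={"client_secret": client_secret, "sid": sid, "auth": auth_dict},
    )
    return resp.status_code  # 200 on success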
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index b6df1396ad..624bf5ada2 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
)
self.render(request)
- # Now we should have fufilled a complete auth flow, including
+ # Now we should have fulfilled a complete auth flow, including
# the recaptcha fallback step, we can then send a
# request to the register API with the session in the authdict.
request, channel = self.make_request(
@@ -115,3 +115,69 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ def test_cannot_change_operation(self):
+ """
+ The initially requested operation cannot be modified during the user-interactive authentication session.

+ """
+
+ # Make the initial request to register. (Later on a different password
+ # will be used.)
+ request, channel = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.render(request)
+
+ # Returns a 401 as per the spec
+ self.assertEqual(request.code, 401)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ request, channel = self.make_request(
+ "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ request, channel = self.make_request(
+ "POST",
+ "auth/m.login.recaptcha/fallback/web?session="
+ + session
+ + "&g-recaptcha-response=a",
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ # The recaptcha handler is called with the response given
+ attempts = self.recaptcha_checker.recaptcha_attempts
+ self.assertEqual(len(attempts), 1)
+ self.assertEqual(attempts[0][0]["response"], "a")
+
+ # also complete the dummy auth
+ request, channel = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ )
+ self.render(request)
+
+ # Now we should have fulfilled a complete auth flow, including
+ # the recaptcha fallback step. Make the initial request again, but
+ # with a different password. This causes the request to fail since the
+ # operation was modified during the UI auth session.
+ request, channel = self.make_request(
+ "POST",
+ "register",
+ {
+ "username": "user",
+ "type": "m.login.password",
+ "password": "foo", # Note this doesn't match the original request.
+ "auth": {"session": session},
+ },
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 403)
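The point of the new test is that a UI-auth session is bound to the parameters of the request that created it. The two /register bodies it sends look roughly like this (values copied from the test; the session ID comes from the 401 response):

initial_body = {
    "username": "user",
    "type": "m.login.password",
    "password": "bar",
}

# Reusing the session from the 401 response but changing "password" makes the
# server reject the request with a 403.
completing_body = {
    "username": "user",
    "type": "m.login.password",
    "password": "foo",
    "auth": {"session": "<session from the 401 response>"},
}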
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
new file mode 100644
index 0000000000..c57072f50c
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, password_policy, register
+
+from tests import unittest
+
+
+class PasswordPolicyTestCase(unittest.HomeserverTestCase):
+ """Tests the password policy feature and its compliance with MSC2000.
+
+ When validating a password, Synapse does the necessary checks in this order:
+
+ 1. Password is long enough
+ 2. Password contains digit(s)
+ 3. Password contains symbol(s)
+ 4. Password contains uppercase letter(s)
+ 5. Password contains lowercase letter(s)
+
+ For each test below that checks whether a password triggers the right error code,
+ that test provides a password good enough to pass the previous tests, but not the
+ one it is currently testing (nor any test that comes afterward).
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ password_policy.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.register_url = "/_matrix/client/r0/register"
+ self.policy = {
+ "enabled": True,
+ "minimum_length": 10,
+ "require_digit": True,
+ "require_symbol": True,
+ "require_lowercase": True,
+ "require_uppercase": True,
+ }
+
+ config = self.default_config()
+ config["password_config"] = {
+ "policy": self.policy,
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_get_policy(self):
+ """Tests if the /password_policy endpoint returns the configured policy."""
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/password_policy"
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.minimum_length": 10,
+ "m.require_digit": True,
+ "m.require_symbol": True,
+ "m.require_lowercase": True,
+ "m.require_uppercase": True,
+ },
+ channel.result,
+ )
+
+ def test_password_too_short(self):
+ request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
+ )
+
+ def test_password_no_digit(self):
+ request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
+ )
+
+ def test_password_no_symbol(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
+ )
+
+ def test_password_no_uppercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
+ )
+
+ def test_password_no_lowercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
+ )
+
+ def test_password_compliant(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ # Getting a 401 here means the password has passed validation and the server has
+ # responded with a list of registration flows.
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_password_change(self):
+ """This doesn't test every possible use case, only that hitting /account/password
+ triggers the password validation code.
+ """
+ compliant_password = "C0mpl!antpassword"
+ not_compliant_password = "notcompliantpassword"
+
+ user_id = self.register_user("kermit", compliant_password)
+ tok = self.login("kermit", compliant_password)
+
+ request_data = json.dumps(
+ {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
+ )
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/account/password",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
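For reference, a minimal standalone validator that applies the checks in the order the class docstring above describes. This mirrors the policy semantics only, not Synapse's actual implementation, and the literal error-code strings are an assumption about the values of the Codes constants used in the tests.

import re

POLICY = {
    "minimum_length": 10,
    "require_digit": True,
    "require_symbol": True,
    "require_uppercase": True,
    "require_lowercase": True,
}


def check_password(password: str, policy: dict = POLICY) -> str:
    # checks are applied in the documented order; the first failure wins
    if len(password) < policy["minimum_length"]:
        return "M_PASSWORD_TOO_SHORT"
    if policy["require_digit"] and not re.search(r"[0-9]", password):
        return "M_PASSWORD_NO_DIGIT"
    if policy["require_symbol"] and not re.search(r"[^a-zA-Z0-9]", password):
        return "M_PASSWORD_NO_SYMBOL"
    if policy["require_uppercase"] and not re.search(r"[A-Z]", password):
        return "M_PASSWORD_NO_UPPERCASE"
    if policy["require_lowercase"] and not re.search(r"[a-z]", password):
        return "M_PASSWORD_NO_LOWERCASE"
    return "OK"


assert check_password("L0ngerpassword!") == "OK"
assert check_password("l0ngerpassword") == "M_PASSWORD_NO_SYMBOL"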
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d0c997e385..b6ed06e02d 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -36,8 +36,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
url = b"/_matrix/client/r0/register"
- def default_config(self, name="test"):
- config = super().default_config(name)
+ def default_config(self):
+ config = super().default_config()
config["allow_guest_access"] = True
return config
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6776a56cad..99eb477149 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -143,8 +143,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible.
"""
- def default_config(self, *args, **kwargs):
- config = super().default_config(*args, **kwargs)
+ def default_config(self):
+ config = super().default_config()
# replace the signing key with our own
self.hs_signing_key = signedjson.key.generate_signing_key("kssk")
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 852b8ab11c..2826211f32 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -74,6 +74,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
config["url_preview_url_blacklist"] = []
+ config["url_preview_accept_language"] = [
+ "en-UK",
+ "en-US;q=0.9",
+ "fr;q=0.8",
+ "*;q=0.7",
+ ]
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
@@ -507,3 +513,52 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
+
+ def test_accept_language_config_option(self):
+ """
+ Accept-Language header is sent to the remote server
+ """
+ self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+
+ # Build and make a request to the server
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # Extract Synapse's tcp client
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+
+ # Build a fake remote server to reply with
+ server = AccumulatingProtocol()
+
+ # Connect the two together
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ # Tell Synapse that it has received some data from the remote server
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ # Move the reactor along until we get a response on our original channel
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Check that the server received the Accept-Language header as part
+ # of the request from Synapse
+ self.assertIn(
+ (
+ b"Accept-Language: en-UK\r\n"
+ b"Accept-Language: en-US;q=0.9\r\n"
+ b"Accept-Language: fr;q=0.8\r\n"
+ b"Accept-Language: *;q=0.7"
+ ),
+ server.data,
+ )
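A minimal sketch of how the configured url_preview_accept_language list maps onto the header bytes asserted above; the one-header-per-value convention is inferred from the expected wire data, not from Synapse internals.

accept_language = ["en-UK", "en-US;q=0.9", "fr;q=0.8", "*;q=0.7"]
header_block = b"\r\n".join(
    b"Accept-Language: " + value.encode("ascii") for value in accept_language
)
assert header_block == (
    b"Accept-Language: en-UK\r\n"
    b"Accept-Language: en-US;q=0.9\r\n"
    b"Accept-Language: fr;q=0.8\r\n"
    b"Accept-Language: *;q=0.7"
)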
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index eb540e34f6..93eb053b8c 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -19,6 +19,9 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
@@ -28,7 +31,7 @@ from tests import unittest
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs_config = self.default_config("test")
+ hs_config = self.default_config()
hs_config["server_notices"] = {
"system_mxid_localpart": "server",
"system_mxid_display_name": "test display name",
@@ -67,7 +70,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# self.server_notices_mxid_avatar_url = None
# self.server_notices_room_name = "Server Notices"
- self._rlsn._server_notices_manager.get_notice_room_for_user = Mock(
+ self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
returnValue=""
)
self._rlsn._store.add_tag_to_room = Mock()
@@ -215,6 +218,26 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def default_config(self):
+ c = super().default_config()
+ c["server_notices"] = {
+ "system_mxid_localpart": "server",
+ "system_mxid_display_name": None,
+ "system_mxid_avatar_url": None,
+ "room_name": "Test Server Notice Room",
+ }
+ c["limit_usage_by_mau"] = True
+ c["max_mau_value"] = 5
+ c["admin_contact"] = "mailto:user@test.com"
+ return c
+
def prepare(self, reactor, clock, hs):
self.store = self.hs.get_datastore()
self.server_notices_sender = self.hs.get_server_notices_sender()
@@ -228,18 +251,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
if not isinstance(self._rlsn, ResourceLimitsServerNotices):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.hs_disabled = False
- self.hs.config.max_mau_value = 5
- self.hs.config.server_notices_mxid = "@server:test"
- self.hs.config.server_notices_mxid_display_name = None
- self.hs.config.server_notices_mxid_avatar_url = None
- self.hs.config.server_notices_room_name = "Test Server Notice Room"
-
self.user_id = "@user_id:test"
- self.hs.config.admin_contact = "mailto:user@test.com"
-
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
@@ -253,7 +266,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
# Now lets get the last load of messages in the service notice room and
# check that there is only one server notice
room_id = self.get_success(
- self.server_notices_manager.get_notice_room_for_user(self.user_id)
+ self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
)
token = self.get_success(self.event_source.get_current_token())
@@ -273,3 +286,86 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
count += 1
self.assertEqual(count, 1)
+
+ def test_no_invite_without_notice(self):
+ """Tests that a user doesn't get invited to a server notices room without a
+ server notice being sent.
+
+ The scenario for this test is a single user on a server where the MAU limit
+ hasn't been reached (since it's the only user and the limit is 5), so users
+ shouldn't receive a server notice.
+ """
+ self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ self.render(request)
+
+ invites = channel.json_body["rooms"]["invite"]
+ self.assertEqual(len(invites), 0, invites)
+
+ def test_invite_with_notice(self):
+ """Tests that, if the MAU limit is hit, the server notices user invites each user
+ to a room in which it has sent a notice.
+ """
+ user_id, tok, room_id = self._trigger_notice_and_join()
+
+ # Sync again to retrieve the events in the room, so we can check whether this
+ # room has a notice in it.
+ request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ self.render(request)
+
+ # Scan the events in the room to search for a message from the server notices
+ # user.
+ events = channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+ notice_in_room = False
+ for event in events:
+ if (
+ event["type"] == EventTypes.Message
+ and event["sender"] == self.hs.config.server_notices_mxid
+ ):
+ notice_in_room = True
+
+ self.assertTrue(notice_in_room, "No server notice in room")
+
+ def _trigger_notice_and_join(self):
+ """Creates enough active users to hit the MAU limit and trigger a system notice
+ about it, then joins the system notices room with one of the users created.
+
+ Returns:
+ user_id (str): The ID of the user that joined the room.
+ tok (str): The access token of the user that joined the room.
+ room_id (str): The ID of the room that's been joined.
+ """
+ user_id = None
+ tok = None
+ invites = []
+
+ # Register as many users as the MAU limit allows.
+ for i in range(self.hs.config.max_mau_value):
+ localpart = "user%d" % i
+ user_id = self.register_user(localpart, "password")
+ tok = self.login(localpart, "password")
+
+ # Sync with the user's token to mark the user as active.
+ request, channel = self.make_request(
+ "GET", "/sync?timeout=0", access_token=tok,
+ )
+ self.render(request)
+
+ # This also retrieves the list of invites for this user. We only care about
+ # it when processing the last user, who should have received an invite to a
+ # room with a server notice about the MAU limit being reached.
+ # We could sync with any other user instead, which would also return an
+ # invite to the system notices room, but which user we pick doesn't matter,
+ # so we use the last one because it saves us an extra sync.
+ invites = channel.json_body["rooms"]["invite"]
+
+ # Make sure we have an invite to process.
+ self.assertEqual(len(invites), 1, invites)
+
+ # Join the room.
+ room_id = list(invites.keys())[0]
+ self.helper.join(room=room_id, user=user_id, tok=tok)
+
+ return user_id, tok, room_id
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index ae14fb407d..940b166129 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -11,7 +11,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
# the base test class should have run the real bg updates for us
- self.assertTrue(self.updates.has_completed_background_updates())
+ self.assertTrue(
+ self.get_success(self.updates.has_completed_background_updates())
+ )
self.update_handler = Mock()
self.updates.register_background_update_handler(
@@ -25,12 +27,20 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the target runtime for each bg update
target_background_update_duration_ms = 50000
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
# first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield self.hs.get_datastore().db.runInteraction(
+ yield store.db.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
@@ -39,10 +49,6 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
return count
self.update_handler.side_effect = update
-
- self.get_success(
- self.updates.start_background_update("test_update", {"my_key": 1})
- )
self.update_handler.reset_mock()
res = self.get_success(
self.updates.do_next_background_update(
@@ -50,7 +56,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
),
by=0.1,
)
- self.assertIsNotNone(res)
+ self.assertFalse(res)
# on the first call, we should get run with the default background update size
self.update_handler.assert_called_once_with(
@@ -73,7 +79,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
result = self.get_success(
self.updates.do_next_background_update(target_background_update_duration_ms)
)
- self.assertIsNotNone(result)
+ self.assertFalse(result)
self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
@@ -81,5 +87,5 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
result = self.get_success(
self.updates.do_next_background_update(target_background_update_duration_ms)
)
- self.assertIsNone(result)
+ self.assertTrue(result)
self.assertFalse(self.update_handler.called)
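The assertion changes reflect a new return convention: do_next_background_update() now reports whether all updates have completed (True) rather than returning the update that ran. A sketch of the driving loop this enables, matching the one added to tests/unittest.py further down in this patch:

async def run_all_background_updates(updates, batch_duration_ms=100):
    # `updates` is a BackgroundUpdater; keep going until it reports that
    # every registered background update has completed.
    while not await updates.has_completed_background_updates():
        await updates.do_next_background_update(batch_duration_ms)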
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
new file mode 100644
index 0000000000..5a77c84962
--- /dev/null
+++ b/tests/storage/test_database.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.database import make_tuple_comparison_clause
+from synapse.storage.engines import BaseDatabaseEngine
+
+from tests import unittest
+
+
+def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
+ # returns a DatabaseEngine, circumventing the abc mechanism
+ # any kwargs are set as attributes on the class before instantiating it
+ t = type(
+ "TestBaseDatabaseEngine",
+ (BaseDatabaseEngine,),
+ dict(BaseDatabaseEngine.__dict__),
+ )
+ # defeat the abc mechanism
+ t.__abstractmethods__ = set()
+ for k, v in kwargs.items():
+ setattr(t, k, v)
+ return t(None, None)
+
+
+class TupleComparisonClauseTestCase(unittest.TestCase):
+ def test_native_tuple_comparison(self):
+ db_engine = _stub_db_engine(supports_tuple_comparison=True)
+ clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)])
+ self.assertEqual(clause, "(a,b) > (?,?)")
+ self.assertEqual(args, [1, 2])
+
+ def test_emulated_tuple_comparison(self):
+ db_engine = _stub_db_engine(supports_tuple_comparison=False)
+ clause, args = make_tuple_comparison_clause(
+ db_engine, [("a", 1), ("b", 2), ("c", 3)]
+ )
+ self.assertEqual(
+ clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))"
+ )
+ self.assertEqual(args, [1, 1, 2, 2, 3])
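As a sanity check on the expected emulated clause, the nested AND/OR expression is logically equivalent to a lexicographic tuple comparison; a few lines of plain Python (no database involved) verify that over a small grid of values.

from itertools import product


def emulated(a, b, c):
    # (a >= 1 AND (a > 1 OR (b >= 2 AND (b > 2 OR c > 3))))
    return a >= 1 and (a > 1 or (b >= 2 and (b > 2 or c > 3)))


for row in product(range(5), repeat=3):
    assert emulated(*row) == (row > (1, 2, 3)), row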
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6f8d990959..c2539b353a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- @defer.inlineCallbacks
- def test_get_device_updates_by_remote_limited(self):
- # Test breaking the update limit in 1, 101, and 1 device_id segments
-
- # first add one device
- device_ids1 = ["device_id0"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids1, ["someotherhost"]
- )
-
- # then add 101
- device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids2, ["someotherhost"]
- )
-
- # then one more
- device_ids3 = ["newdevice"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids3, ["someotherhost"]
- )
-
- #
- # now read them back.
- #
-
- # first we should get a single update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", -1, limit=100
- )
- self._check_devices_in_updates(device_ids1, device_updates)
-
- # Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self.assertEqual(len(device_updates), 0)
-
- # The 101 devices should've been cleared, so we should now just get one device
- # update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self._check_devices_in_updates(device_ids3, device_updates)
-
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 5ec5d2b358..5c2817cf28 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -28,8 +28,8 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets]
- def default_config(self, name="test"):
- config = super().default_config(name)
+ def default_config(self):
+ config = super().default_config()
config.update(
{
"public_baseurl": "https://example.org/",
@@ -53,7 +53,8 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
# Do a UI auth request
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/unittest.py b/tests/unittest.py
index 8816a4d152..27af5228fe 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -38,7 +38,12 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ current_context,
+ set_current_context,
+)
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -97,10 +102,10 @@ class TestCase(unittest.TestCase):
def setUp(orig):
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
- if LoggingContext.current_context() is not LoggingContext.sentinel:
+ if current_context():
self.fail(
"Test starting with non-sentinel logging context %s"
- % (LoggingContext.current_context(),)
+ % (current_context(),)
)
old_level = logging.getLogger().level
@@ -122,7 +127,7 @@ class TestCase(unittest.TestCase):
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
gc.collect()
- LoggingContext.set_current_context(LoggingContext.sentinel)
+ set_current_context(SENTINEL_CONTEXT)
return ret
@@ -311,14 +316,11 @@ class HomeserverTestCase(TestCase):
return resource
- def default_config(self, name="test"):
+ def default_config(self):
"""
Get a default HomeServer config dict.
-
- Args:
- name (str): The homeserver name/domain.
"""
- config = default_config(name)
+ config = default_config("test")
# apply any additional config which was specified via the override_config
# decorator.
@@ -418,15 +420,17 @@ class HomeserverTestCase(TestCase):
config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj
+ async def run_bg_updates():
+ with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+ while not await stor.db.updates.has_completed_background_updates():
+ await stor.db.updates.do_next_background_update(1)
+
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
# Run the database background updates, when running against "master".
if hs.__class__.__name__ == "TestHomeServer":
- while not self.get_success(
- stor.db.updates.has_completed_background_updates()
- ):
- self.get_success(stor.db.updates.do_next_background_update(1))
+ self.get_success(run_bg_updates())
return hs
@@ -493,6 +497,7 @@ class HomeserverTestCase(TestCase):
"password": password,
"admin": admin,
"mac": want_mac,
+ "inhibit_login": True,
}
)
request, channel = self.make_request(
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 39e360fe24..4d2b9e0d64 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -22,8 +22,10 @@ from twisted.internet import defer, reactor
from synapse.api.errors import SynapseError
from synapse.logging.context import (
+ SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
@@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase):
with LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
return r
def check_result(r):
@@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
d2.addCallback(check_result)
# let the lookup complete
@@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase):
try:
d = obj.fn(1)
self.assertEqual(
- LoggingContext.current_context(), LoggingContext.sentinel
+ current_context(), SENTINEL_CONTEXT,
)
yield d
self.fail("No exception thrown")
except SynapseError:
pass
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
@@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
return d1
@@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
- assert LoggingContext.current_context().request == "c1"
+ assert current_context().request == "c1"
# we want this to behave like an asynchronous function
yield run_on_reactor()
- assert LoggingContext.current_context().request == "c1"
+ assert current_context().request == "c1"
return self.mock(args1, arg2)
with LoggingContext() as c1:
@@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
r = yield d1
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
index f60918069a..17fd86d02d 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_utils.py
@@ -16,7 +16,12 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ PreserveLoggingContext,
+ current_context,
+)
from synapse.util.async_helpers import timeout_deferred
from tests.unittest import TestCase
@@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase):
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
- LoggingContext.current_context(),
+ current_context(),
context_one,
"errback %s run in unexpected logcontext %s"
- % (deferred_name, LoggingContext.current_context()),
+ % (deferred_name, current_context()),
)
return res
@@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase):
original_deferred.addErrback(errback, "orig")
timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
self.assertNoResult(timing_out_d)
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
timing_out_d.addErrback(errback, "timingout")
self.clock.pump((1.0,))
@@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase):
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
- self.assertIs(LoggingContext.current_context(), context_one)
+ self.assertIs(current_context(), context_one)
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 0ec8ef90ce..852ef23185 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -19,7 +19,7 @@ from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import LoggingContext, current_context
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase):
def func(i, sleep=False):
with LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")):
- self.assertEqual(LoggingContext.current_context(), lc)
+ self.assertEqual(current_context(), lc)
if sleep:
yield Clock(reactor).sleep(0)
- self.assertEqual(LoggingContext.current_context(), lc)
+ self.assertEqual(current_context(), lc)
func(0, sleep=True)
for i in range(1, 100):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 281b32c4b8..95301c013c 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -2,8 +2,10 @@ import twisted.python.failure
from twisted.internet import defer, reactor
from synapse.logging.context import (
+ SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
nested_logging_context,
run_in_background,
@@ -15,7 +17,7 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
- self.assertEquals(LoggingContext.current_context().request, value)
+ self.assertEquals(current_context().request, value)
def test_with_context(self):
with LoggingContext() as context_one:
@@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def _test_run_in_background(self, function):
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
callback_completed = [False]
@@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase):
# make sure that the context was reset before it got thrown back
# into the reactor
try:
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
d2.callback(None)
except BaseException:
d2.errback(twisted.python.failure.Failure())
@@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase):
async def testfunc():
self._check_test_key("one")
d = Clock(reactor).sleep(0)
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
await d
self._check_test_key("one")
@@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None)
return d
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
@@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self):
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
@@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None)
await d
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 72a9de5370..6857933540 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -28,18 +28,26 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.entity_has_changed("user@foo.com", 6)
cache.entity_has_changed("bar@baz.net", 7)
+ # also test multiple things changing on the same stream ID
+ cache.entity_has_changed("user2@foo.com", 8)
+ cache.entity_has_changed("bar2@baz.net", 8)
+
# If it's been changed after that stream position, return True
self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("user2@foo.com", 4))
# If it's been changed at that stream position, return False
self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+ self.assertFalse(cache.has_entity_changed("user2@foo.com", 8))
# If there are no changes after that stream position, return False
self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+ self.assertFalse(cache.has_entity_changed("user2@foo.com", 9))
# If the entity does not exist, return False.
- self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+ self.assertFalse(cache.has_entity_changed("not@here.website", 9))
# If we request before the stream cache's earliest known position,
# return True, whether it's a known entity or not.
@@ -47,7 +55,7 @@ class StreamChangeCacheTests(unittest.TestCase):
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
@patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
- def test_has_entity_changed_pops_off_start(self):
+ def test_entity_has_changed_pops_off_start(self):
"""
StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size.
@@ -64,11 +72,20 @@ class StreamChangeCacheTests(unittest.TestCase):
# The oldest item has been popped off
self.assertTrue("user@foo.com" not in cache._entity_to_key)
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"],
+ )
+ self.assertIsNone(cache.get_all_entities_changed(1))
+
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
self.assertEqual(
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"],
+ )
+ self.assertIsNone(cache.get_all_entities_changed(1))
def test_get_all_entities_changed(self):
"""
@@ -80,18 +97,52 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.entity_has_changed("user@foo.com", 2)
cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("anotheruser@foo.com", 3)
cache.entity_has_changed("user@elsewhere.org", 4)
- self.assertEqual(
- cache.get_all_entities_changed(1),
- ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
- )
- self.assertEqual(
- cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
- )
+ r = cache.get_all_entities_changed(1)
+
+ # either of these are valid
+ ok1 = [
+ "user@foo.com",
+ "bar@baz.net",
+ "anotheruser@foo.com",
+ "user@elsewhere.org",
+ ]
+ ok2 = [
+ "user@foo.com",
+ "anotheruser@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ ]
+ self.assertTrue(r == ok1 or r == ok2)
+
+ r = cache.get_all_entities_changed(2)
+ self.assertTrue(r == ok1[1:] or r == ok2[1:])
+
self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
self.assertEqual(cache.get_all_entities_changed(0), None)
+ # ... later, things get more updates
+ cache.entity_has_changed("user@foo.com", 5)
+ cache.entity_has_changed("bar@baz.net", 5)
+ cache.entity_has_changed("anotheruser@foo.com", 6)
+
+ ok1 = [
+ "user@elsewhere.org",
+ "user@foo.com",
+ "bar@baz.net",
+ "anotheruser@foo.com",
+ ]
+ ok2 = [
+ "user@elsewhere.org",
+ "bar@baz.net",
+ "user@foo.com",
+ "anotheruser@foo.com",
+ ]
+ r = cache.get_all_entities_changed(3)
+ self.assertTrue(r == ok1 or r == ok2)
+
def test_has_any_entity_changed(self):
"""
StreamChangeCache.has_any_entity_changed will return True if any
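The ok1/ok2 juggling above exists because entities recorded at the same stream position have no defined relative order. An equivalent, order-insensitive way to state the first assertion (constructor arguments assumed to match the rest of this test file):

from synapse.util.caches.stream_change_cache import StreamChangeCache

cache = StreamChangeCache("#test", 1)
cache.entity_has_changed("user@foo.com", 2)
cache.entity_has_changed("bar@baz.net", 3)
cache.entity_has_changed("anotheruser@foo.com", 3)
cache.entity_has_changed("user@elsewhere.org", 4)

r = cache.get_all_entities_changed(1)
assert r[0] == "user@foo.com"
# the two entities that changed at position 3 may come back in either order
assert set(r[1:3]) == {"bar@baz.net", "anotheruser@foo.com"}
assert r[3] == "user@elsewhere.org"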
diff --git a/tests/utils.py b/tests/utils.py
index 513f358f4f..2079e0143d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
@@ -332,10 +332,15 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().validate_hash = (
- lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
- )
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
if fed:
@@ -493,10 +498,10 @@ class MockClock(object):
return self.time() * 1000
def call_later(self, delay, callback, *args, **kwargs):
- current_context = LoggingContext.current_context()
+ ctx = current_context()
def wrapped_callback():
- LoggingContext.thread_local.current_context = current_context
+ set_current_context(ctx)
callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
|