diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index dbdd427cac..d580e729c5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,39 +1,97 @@
-from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from tests import unittest
class TestRatelimiter(unittest.TestCase):
- def test_allowed(self):
- limiter = Ratelimiter()
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
- )
+ def test_allowed_via_can_do_action(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
- )
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
- )
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
- def test_pruning(self):
- limiter = Ratelimiter()
+ def test_allowed_via_ratelimit(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # Shouldn't raise
+ limiter.ratelimit(key="test_id", _time_now_s=0)
+
+ # Should raise
+ with self.assertRaises(LimitExceededError) as context:
+ limiter.ratelimit(key="test_id", _time_now_s=5)
+ self.assertEqual(context.exception.retry_after_ms, 5000)
+
+ # Shouldn't raise
+ limiter.ratelimit(key="test_id", _time_now_s=10)
+
+ def test_allowed_via_can_do_action_and_overriding_parameters(self):
+ """Test that we can override options of can_do_action that would otherwise fail
+ an action
+ """
+ # Create a Ratelimiter with a very low allowed rate_hz and burst_count
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # First attempt should be allowed
+ allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
+ self.assertTrue(allowed)
+ self.assertEqual(10.0, time_allowed)
+
+ # Second attempt, 1s later, will fail
+ allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
+ self.assertFalse(allowed)
+ self.assertEqual(10.0, time_allowed)
+
+ # But, if we allow 10 actions/sec for this request, we should be allowed
+ # to continue.
allowed, time_allowed = limiter.can_do_action(
- key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
+ ("test_id",), _time_now_s=1, rate_hz=10.0
)
+ self.assertTrue(allowed)
+ self.assertEqual(1.1, time_allowed)
- self.assertIn("test_id_1", limiter.message_counts)
-
+ # Similarly if we allow a burst of 10 actions
allowed, time_allowed = limiter.can_do_action(
- key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
+ ("test_id",), _time_now_s=1, burst_count=10
)
+ self.assertTrue(allowed)
+ self.assertEqual(1.0, time_allowed)
+
+ def test_allowed_via_ratelimit_and_overriding_parameters(self):
+ """Test that we can override options of the ratelimit method that would otherwise
+ fail an action
+ """
+ # Create a Ratelimiter with a very low allowed rate_hz and burst_count
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # First attempt should be allowed
+ limiter.ratelimit(key=("test_id",), _time_now_s=0)
+
+ # Second attempt, 1s later, will fail
+ with self.assertRaises(LimitExceededError) as context:
+ limiter.ratelimit(key=("test_id",), _time_now_s=1)
+ self.assertEqual(context.exception.retry_after_ms, 9000)
+
+ # But, if we allow 10 actions/sec for this request, we should be allowed
+ # to continue.
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+
+ # Similarly if we allow a burst of 10 actions
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+
+ def test_pruning(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter.can_do_action(key="test_id_1", _time_now_s=0)
+
+ self.assertIn("test_id_1", limiter.actions)
+
+ limiter.can_do_action(key="test_id_2", _time_now_s=10)
- self.assertNotIn("test_id_1", limiter.message_counts)
+ self.assertNotIn("test_id_1", limiter.actions)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 33105576af..ff12539041 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -536,7 +536,7 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
return {
"user_id": user_id,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {
"curve25519:" + device_id: "curve25519+key",
key_id(sk): encode_pubkey(sk),
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 854eb6c024..e1e144b2e7 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -222,7 +222,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_1 = {
"user_id": local_user,
"device_id": "abc",
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {
"ed25519:abc": "base64+ed25519+key",
"curve25519:abc": "base64+curve25519+key",
@@ -232,7 +232,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_2 = {
"user_id": local_user,
"device_id": "def",
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {
"ed25519:def": "base64+ed25519+key",
"curve25519:def": "base64+curve25519+key",
@@ -315,7 +315,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key = {
"user_id": local_user,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
"signatures": {local_user: {"ed25519:xyz": "something"}},
}
@@ -391,8 +391,8 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_id": local_user,
"device_id": device_id,
"algorithms": [
- "m.olm.curve25519-aes-sha256",
- "m.megolm.v1.aes-sha",
+ "m.olm.curve25519-aes-sha2",
+ "m.megolm.v1.aes-sha2",
],
"keys": {
"curve25519:xyz": "curve25519+key",
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 69c0af7716..a1f4bde347 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock, NonCallableMock
+from mock import Mock
from twisted.internet import defer
@@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
-
self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test")
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e3623321dc..2a377a4eb9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -136,6 +136,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
+ def test_auto_join_rooms_for_guests(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ self.hs.config.auto_join_rooms_for_guests = False
+ user_id = self.get_success(
+ self.handler.register_user(localpart="jeff", make_guest=True),
+ )
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 0)
+
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9f54b55acd..0c5cdbd33a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -14,6 +14,8 @@
# limitations under the License.
from mock import Mock
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.rest.client.v1 import login, room
@@ -75,18 +77,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- self.store.remove_from_user_dir = Mock()
- self.store.remove_from_user_in_public_room = Mock()
+ self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
self.get_success(self.handler.handle_user_deactivated(s_user_id))
self.store.remove_from_user_dir.not_called()
- self.store.remove_from_user_in_public_room.not_called()
def test_handle_user_deactivated_regular_user(self):
r_user_id = "@regular:test"
self.get_success(
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- self.store.remove_from_user_dir = Mock()
+ self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
self.get_success(self.handler.handle_user_deactivated(r_user_id))
self.store.remove_from_user_dir.called_once_with(r_user_id)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 32cb04645f..56497b8476 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, NonCallableMock
+from mock import Mock
from tests.replication._base import BaseStreamTestCase
@@ -21,12 +21,7 @@ from tests.replication._base import BaseStreamTestCase
class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
- )
-
- hs.get_ratelimiter().can_do_action.return_value = (True, 0)
+ hs = self.setup_test_homeserver(federation_client=Mock())
return hs
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
new file mode 100644
index 0000000000..faa7f381a9
--- /dev/null
+++ b/tests/rest/admin/test_device.py
@@ -0,0 +1,541 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login
+
+from tests import unittest
+
+
+class DeviceRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.other_user_device_id = res[0]["device_id"]
+
+ self.url = "/_synapse/admin/v2/users/%s/devices/%s" % (
+ urllib.parse.quote(self.other_user),
+ self.other_user_device_id,
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get a device of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("PUT", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("DELETE", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "DELETE", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = (
+ "/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
+ % self.other_user_device_id
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = (
+ "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
+ % self.other_user_device_id
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_unknown_device(self):
+ """
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
+ """
+ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
+ self.other_user
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Delete unknown device returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_update_device_too_long_display_name(self):
+ """
+ Update a device with a display name that is invalid (too long).
+ """
+ # Set iniital display name.
+ update = {"display_name": "new display"}
+ self.get_success(
+ self.handler.update_device(
+ self.other_user, self.other_user_device_id, update
+ )
+ )
+
+ # Request to update a device display name with a new value that is longer than allowed.
+ update = {
+ "display_name": "a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
+ }
+
+ body = json.dumps(update)
+ request, channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # Ensure the display name was not updated.
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new display", channel.json_body["display_name"])
+
+ def test_update_no_display_name(self):
+ """
+ Tests that a update for a device without JSON returns a 200
+ """
+ # Set iniital display name.
+ update = {"display_name": "new display"}
+ self.get_success(
+ self.handler.update_device(
+ self.other_user, self.other_user_device_id, update
+ )
+ )
+
+ request, channel = self.make_request(
+ "PUT", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Ensure the display name was not updated.
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new display", channel.json_body["display_name"])
+
+ def test_update_display_name(self):
+ """
+ Tests a normal successful update of display name
+ """
+ # Set new display_name
+ body = json.dumps({"display_name": "new displayname"})
+ request, channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check new display_name
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new displayname", channel.json_body["display_name"])
+
+ def test_get_device(self):
+ """
+ Tests that a normal lookup for a device is successfully
+ """
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ # Check that all fields are available
+ self.assertIn("user_id", channel.json_body)
+ self.assertIn("device_id", channel.json_body)
+ self.assertIn("display_name", channel.json_body)
+ self.assertIn("last_seen_ip", channel.json_body)
+ self.assertIn("last_seen_ts", channel.json_body)
+
+ def test_delete_device(self):
+ """
+ Tests that a remove of a device is successfully
+ """
+ # Count number of devies of an user.
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ number_devices = len(res)
+ self.assertEqual(1, number_devices)
+
+ # Delete device
+ request, channel = self.make_request(
+ "DELETE", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Ensure that the number of devices is decreased
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(number_devices - 1, len(res))
+
+
+class DevicesRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v2/users/%s/devices" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list devices of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_devices(self):
+ """
+ Tests that a normal lookup for devices is successfully
+ """
+ # Create devices
+ number_devices = 5
+ for n in range(number_devices):
+ self.login("user", "pass")
+
+ # Get devices
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_devices, len(channel.json_body["devices"]))
+ self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
+ # Check that all fields are available
+ for d in channel.json_body["devices"]:
+ self.assertIn("user_id", d)
+ self.assertIn("device_id", d)
+ self.assertIn("display_name", d)
+ self.assertIn("last_seen_ip", d)
+ self.assertIn("last_seen_ts", d)
+
+
+class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v2/users/%s/delete_devices" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to delete devices of an user without authentication.
+ """
+ request, channel = self.make_request("POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "POST", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
+ request, channel = self.make_request(
+ "POST", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
+
+ request, channel = self.make_request(
+ "POST", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_unknown_devices(self):
+ """
+ Tests that a remove of a device that does not exist returns 200.
+ """
+ body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ # Delete unknown devices returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_delete_devices(self):
+ """
+ Tests that a remove of devices is successfully
+ """
+
+ # Create devices
+ number_devices = 5
+ for n in range(number_devices):
+ self.login("user", "pass")
+
+ # Get devices
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(number_devices, len(res))
+
+ # Create list of device IDs
+ device_ids = []
+ for d in res:
+ device_ids.append(str(d["device_id"]))
+
+ # Delete devices
+ body = json.dumps({"devices": device_ids})
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 6c88ab06e2..cca5f548e6 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -22,9 +22,12 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
+from synapse.api.errors import HttpResponseException, ResourceLimitError
from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.unittest import override_config
class UserRegisterTestCase(unittest.HomeserverTestCase):
@@ -320,6 +323,52 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid user type", channel.json_body["error"])
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_register_mau_limit_reached(self):
+ """
+ Check we can register a user via the shared secret registration API
+ even if the MAU limit is reached.
+ """
+ handler = self.hs.get_registration_handler()
+ store = self.hs.get_datastore()
+
+ # Set monthly active users to the limit
+ store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
+ )
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
class UsersListTestCase(unittest.HomeserverTestCase):
@@ -368,6 +417,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
@@ -386,7 +436,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
request, channel = self.make_request(
@@ -409,7 +458,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
- self.hs.config.registration_shared_secret = None
request, channel = self.make_request(
"GET",
@@ -425,7 +473,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new admin user is created successfully.
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user (server admin)
@@ -473,7 +520,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new regular user is created successfully.
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
@@ -516,11 +562,192 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_create_user_mau_limit_reached_active_admin(self):
+ """
+ Check that an admin can register a new user via the admin API
+ even if the MAU limit is reached.
+ Admin user was active before creating user.
+ """
+
+ handler = self.hs.get_registration_handler()
+
+ # Sync to set admin user to active
+ # before limit of monthly active users is reached
+ request, channel = self.make_request(
+ "GET", "/sync", access_token=self.admin_user_tok
+ )
+ self.render(request)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"]
+ )
+
+ # Set monthly active users to the limit
+ self.store.get_monthly_active_count = Mock(
+ return_value=self.hs.config.max_mau_value
+ )
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123", "admin": False})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_create_user_mau_limit_reached_passive_admin(self):
+ """
+ Check that an admin can register a new user via the admin API
+ even if the MAU limit is reached.
+ Admin user was not active before creating user.
+ """
+
+ handler = self.hs.get_registration_handler()
+
+ # Set monthly active users to the limit
+ self.store.get_monthly_active_count = Mock(
+ return_value=self.hs.config.max_mau_value
+ )
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123", "admin": False})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ # Admin user is not blocked by mau anymore
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": True,
+ "notif_for_new_users": True,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_notif_for_new_users(self):
+ """
+ Check that a new regular user is created successfully and
+ got an email pusher.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ pushers = self.get_success(
+ self.store.get_pushers_by({"user_name": "@bob:test"})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual("@bob:test", pushers[0]["user_name"])
+
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": False,
+ "notif_for_new_users": False,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_no_notif_for_new_users(self):
+ """
+ Check that a new regular user is created successfully and
+ got not an email pusher.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ pushers = self.get_success(
+ self.store.get_pushers_by({"user_name": "@bob:test"})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
def test_set_password(self):
"""
Test setting a new password for another user.
"""
- self.hs.config.registration_shared_secret = None
# Change password
body = json.dumps({"password": "hahaha"})
@@ -539,7 +766,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Test setting the displayname of another user.
"""
- self.hs.config.registration_shared_secret = None
# Modify user
body = json.dumps({"displayname": "foobar"})
@@ -570,7 +796,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Test setting threepid for an other user.
"""
- self.hs.config.registration_shared_secret = None
# Delete old and add new threepid to user
body = json.dumps(
@@ -636,7 +861,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Test setting the admin flag on a user.
"""
- self.hs.config.registration_shared_secret = None
# Set a user as an admin
body = json.dumps({"admin": True})
@@ -668,7 +892,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Ensure an account can't accidentally be deactivated by using a str value
for the deactivated body parameter
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index b54b06482b..f75520877f 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -15,7 +15,7 @@
""" Tests REST events for /events paths."""
-from mock import Mock, NonCallableMock
+from mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room
@@ -40,11 +40,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
config["enable_registration"] = True
config["auto_join_rooms"] = []
- hs = self.setup_test_homeserver(
- config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
+ hs = self.setup_test_homeserver(config=config)
hs.get_handlers().federation_handler = Mock()
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index eb8f6264fd..9033f09fd2 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,8 +1,11 @@
import json
+import time
import urllib.parse
from mock import Mock
+import jwt
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
@@ -26,7 +29,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
-
self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
@@ -35,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs
+ @override_config(
+ {
+ "rc_login": {
+ "address": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the account login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "account": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_address(self):
- self.hs.config.rc_login_address.burst_count = 5
- self.hs.config.rc_login_address.per_second = 0.17
-
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
for i in range(0, 6):
@@ -77,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ "account": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account(self):
- self.hs.config.rc_login_account.burst_count = 5
- self.hs.config.rc_login_account.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -116,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ "failed_attempts": {"per_second": 0.17, "burst_count": 5},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account_failed_attempts(self):
- self.hs.config.rc_login_failed_attempts.burst_count = 5
- self.hs.config.rc_login_failed_attempts.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -473,3 +505,153 @@ class CASTestCase(unittest.HomeserverTestCase):
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
self.assertIn(b"SSO account deactivated", channel.result["body"])
+
+
+class JWTTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ jwt_secret = "secret"
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_secret
+ self.hs.config.jwt_algorithm = "HS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_secret):
+ return jwt.encode(token, secret, "HS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid_registered(self):
+ self.register_user("kermit", "monkey")
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_valid_unregistered(self):
+ channel = self.jwt_login({"sub": "frog"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, "notsecret")
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_jwt_expired(self):
+ channel = self.jwt_login({"sub": "frog", "exp": 864000})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "JWT expired")
+
+ def test_login_jwt_not_before(self):
+ now = int(time.time())
+ channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_sub(self):
+ channel = self.jwt_login({"username": "root"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_token(self):
+ params = json.dumps({"type": "m.login.jwt"})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
+
+
+# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
+# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
+# signed by the private key.
+class JWTPubKeyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ ]
+
+ # This key's pubkey is used as the jwt_secret setting of synapse. Valid
+ # tokens are signed by this and validated using the pubkey. It is generated
+ # with `openssl genrsa 512` (not a secure way to generate real keys, but
+ # good enough for tests!)
+ jwt_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
+ "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
+ "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
+ "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
+ "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
+ "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
+ "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ # Generated with `openssl rsa -in foo.key -pubout`, with the the above
+ # private key placed in foo.key (jwt_privatekey).
+ jwt_pubkey = "\n".join(
+ [
+ "-----BEGIN PUBLIC KEY-----",
+ "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
+ "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
+ "-----END PUBLIC KEY-----",
+ ]
+ )
+
+ # This key is used to sign tokens that shouldn't be accepted by synapse.
+ # Generated just like jwt_privatekey.
+ bad_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
+ "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
+ "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
+ "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
+ "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
+ "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
+ "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt_algorithm = "RS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_privatekey):
+ return jwt.encode(token, secret, "RS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid(self):
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 7dd86d0c27..4886bbb401 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -20,7 +20,7 @@
import json
-from mock import Mock, NonCallableMock
+from mock import Mock
from six.moves.urllib import parse as urlparse
from twisted.internet import defer
@@ -46,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
+ "red", http_client=None, federation_client=Mock(),
)
- self.ratelimiter = self.hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
self.hs.get_federation_handler = Mock(return_value=Mock())
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 4bc3aaf02d..18260bb90e 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -16,7 +16,7 @@
"""Tests REST events for /rooms paths."""
-from mock import Mock, NonCallableMock
+from mock import Mock
from twisted.internet import defer
@@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
+ "red", http_client=None, federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
-
hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False):
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 78f2e78943..ceca4041e1 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -33,6 +33,7 @@ from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest
+from tests.unittest import override_config
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@@ -142,10 +143,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+ @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self):
- self.hs.config.rc_registration.burst_count = 5
- self.hs.config.rc_registration.per_second = 0.17
-
for i in range(0, 6):
url = self.url + b"?kind=guest"
request, channel = self.make_request(b"POST", url, b"{}")
@@ -164,10 +163,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self):
- self.hs.config.rc_registration.burst_count = 5
- self.hs.config.rc_registration.per_second = 0.17
-
for i in range(0, 6):
params = {
"username": "kermit" + str(i),
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1809ceb839..1ca648ef2b 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -18,10 +18,16 @@ import os
import shutil
import tempfile
from binascii import unhexlify
+from io import BytesIO
+from typing import Optional
from mock import Mock
from six.moves.urllib import parse
+import attr
+import PIL.Image as Image
+from parameterized import parameterized_class
+
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@@ -94,6 +100,68 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
+@attr.s
+class _TestImage:
+ """An image for testing thumbnailing with the expected results
+
+ Attributes:
+ data: The raw image to thumbnail
+ content_type: The type of the image as a content type, e.g. "image/png"
+ extension: The extension associated with the format, e.g. ".png"
+ expected_cropped: The expected bytes from cropped thumbnailing, or None if
+ test should just check for success.
+ expected_scaled: The expected bytes from scaled thumbnailing, or None if
+ test should just check for a valid image returned.
+ """
+
+ data = attr.ib(type=bytes)
+ content_type = attr.ib(type=bytes)
+ extension = attr.ib(type=bytes)
+ expected_cropped = attr.ib(type=Optional[bytes])
+ expected_scaled = attr.ib(type=Optional[bytes])
+
+
+@parameterized_class(
+ ("test_image",),
+ [
+ # smol png
+ (
+ _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ ),
+ b"image/png",
+ b".png",
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000020000000200806"
+ b"000000737a7af40000001a49444154789cedc101010000008220"
+ b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
+ b"44ae426082"
+ ),
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000d49444154789c636060606000000005"
+ b"0001a5f645400000000049454e44ae426082"
+ ),
+ ),
+ ),
+ # small lossless webp
+ (
+ _TestImage(
+ unhexlify(
+ b"524946461a000000574542505650384c0d0000002f0000001007"
+ b"1011118888fe0700"
+ ),
+ b"image/webp",
+ b".webp",
+ None,
+ None,
+ ),
+ ),
+ ],
+)
class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True
@@ -151,13 +219,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.download_resource = self.media_repo.children[b"download"]
self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
- # smol png
- self.end_content = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
self.media_id = "example.com/12345"
def _req(self, content_disposition):
@@ -176,14 +237,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
headers = {
- b"Content-Length": [b"%d" % (len(self.end_content))],
- b"Content-Type": [b"image/png"],
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ b"Content-Type": [self.test_image.content_type],
}
if content_disposition:
headers[b"Content-Disposition"] = [content_disposition]
self.fetches[0][0].callback(
- (self.end_content, (len(self.end_content), headers))
+ (self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
@@ -196,12 +257,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
"""
- channel = self._req(b"inline; filename=out.png")
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [b"inline; filename=out" + self.test_image.extension],
)
def test_disposition_filenamestar_utf8escaped(self):
@@ -211,13 +275,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
response.
"""
filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
- channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+ channel = self._req(
+ b"inline; filename*=utf-8''" + filename + self.test_image.extension
+ )
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
self.assertEqual(
headers.getRawHeaders(b"Content-Disposition"),
- [b"inline; filename*=utf-8''" + filename + b".png"],
+ [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
)
def test_disposition_none(self):
@@ -228,27 +296,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self._req(None)
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
- expected_body = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000020000000200806"
- b"000000737a7af40000001a49444154789cedc101010000008220"
- b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
- b"44ae426082"
- )
-
- self._test_thumbnail("crop", expected_body)
+ self._test_thumbnail("crop", self.test_image.expected_cropped)
def test_thumbnail_scale(self):
- expected_body = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000d49444154789c636060606000000005"
- b"0001a5f645400000000049454e44ae426082"
- )
-
- self._test_thumbnail("scale", expected_body)
+ self._test_thumbnail("scale", self.test_image.expected_scaled)
def _test_thumbnail(self, method, expected_body):
params = "?width=32&height=32&method=" + method
@@ -259,13 +316,19 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.pump()
headers = {
- b"Content-Length": [b"%d" % (len(self.end_content))],
- b"Content-Type": [b"image/png"],
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ b"Content-Type": [self.test_image.content_type],
}
self.fetches[0][0].callback(
- (self.end_content, (len(self.end_content), headers))
+ (self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.result["body"], expected_body, channel.result["body"])
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 447fcb3a1c..9c04e92577 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -61,21 +61,27 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
user2_email = threepids[1]["address"]
user3 = "@user3:server"
- self.store.register_user(user_id=user1)
- self.store.register_user(user_id=user2)
- self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
- self.pump()
+ self.get_success(self.store.register_user(user_id=user1))
+ self.get_success(self.store.register_user(user_id=user2))
+ self.get_success(
+ self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
+ )
now = int(self.hs.get_clock().time_msec())
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ self.get_success(
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ )
# XXX why are we doing this here? this function is only run at startup
# so it is odd to re-run it here.
- self.store.db.runInteraction(
- "initialise", self.store._initialise_reserved_users, threepids
+ self.get_success(
+ self.store.db.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
)
- self.pump()
# the number of users we expect will be counted against the mau limit
# -1 because user3 is a support user and does not count
@@ -83,13 +89,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Check the number of active users. Ensure user3 (support user) is not counted
active_count = self.get_success(self.store.get_monthly_active_count())
- self.assertEquals(active_count, user_num)
+ self.assertEqual(active_count, user_num)
# Test each of the registered users is marked as active
- timestamp = self.store.user_last_seen_monthly_active(user1)
- self.assertTrue(self.get_success(timestamp))
- timestamp = self.store.user_last_seen_monthly_active(user2)
- self.assertTrue(self.get_success(timestamp))
+ timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1))
+ self.assertGreater(timestamp, 0)
+ timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2))
+ self.assertGreater(timestamp, 0)
# Test that users with reserved 3pids are not removed from the MAU table
# XXX some of this is redundant. poking things into the config shouldn't
@@ -98,77 +104,79 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.hs.config.max_mau_value = 0
self.reactor.advance(FORTY_DAYS)
self.hs.config.max_mau_value = 5
- self.store.reap_monthly_active_users()
- self.pump()
- active_count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(active_count), user_num)
+ self.get_success(self.store.reap_monthly_active_users())
+
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num)
# Add some more users and check they are counted as active
ru_count = 2
- self.store.upsert_monthly_active_user("@ru1:server")
- self.store.upsert_monthly_active_user("@ru2:server")
- self.pump()
- active_count = self.store.get_monthly_active_count()
- self.assertEqual(self.get_success(active_count), user_num + ru_count)
+
+ self.get_success(self.store.upsert_monthly_active_user("@ru1:server"))
+ self.get_success(self.store.upsert_monthly_active_user("@ru2:server"))
+
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num + ru_count)
# now run the reaper and check that the number of active users is reduced
# to max_mau_value
- self.store.reap_monthly_active_users()
- self.pump()
+ self.get_success(self.store.reap_monthly_active_users())
- active_count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(active_count), 3)
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, 3)
def test_can_insert_and_count_mau(self):
- count = self.store.get_monthly_active_count()
- self.assertEqual(0, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
- self.store.upsert_monthly_active_user("@user:server")
- self.pump()
+ d = self.store.upsert_monthly_active_user("@user:server")
+ self.get_success(d)
- count = self.store.get_monthly_active_count()
- self.assertEqual(1, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 1)
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
- result = self.store.user_last_seen_monthly_active(user_id1)
- self.assertFalse(self.get_success(result) == 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+ self.assertNotEqual(result, 0)
- self.store.upsert_monthly_active_user(user_id1)
- self.store.upsert_monthly_active_user(user_id2)
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user(user_id1))
+ self.get_success(self.store.upsert_monthly_active_user(user_id2))
- result = self.store.user_last_seen_monthly_active(user_id1)
- self.assertGreater(self.get_success(result), 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+ self.assertGreater(result, 0)
- result = self.store.user_last_seen_monthly_active(user_id3)
- self.assertNotEqual(self.get_success(result), 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id3))
+ self.assertNotEqual(result, 0)
@override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self):
initial_users = 10
for i in range(initial_users):
- self.store.upsert_monthly_active_user("@user%d:server" % i)
- self.pump()
+ self.get_success(
+ self.store.upsert_monthly_active_user("@user%d:server" % i)
+ )
- count = self.store.get_monthly_active_count()
- self.assertTrue(self.get_success(count), initial_users)
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, initial_users)
- self.store.reap_monthly_active_users()
- self.pump()
- count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, self.hs.config.max_mau_value)
self.reactor.advance(FORTY_DAYS)
- self.store.reap_monthly_active_users()
- self.pump()
- count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(count), 0)
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
# Note that below says mau_limit (no s), this is the name of the config
# value, although it gets stored on the config object as mau_limits.
@@ -182,7 +190,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
for i in range(initial_users):
user = "@user%d:server" % i
email = "user%d@matrix.org" % i
+
self.get_success(self.store.upsert_monthly_active_user(user))
+
# Need to ensure that the most recent entries in the
# monthly_active_users table are reserved
now = int(self.hs.get_clock().time_msec())
@@ -194,26 +204,37 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
- self.store.db.runInteraction(
+ d = self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
- count = self.store.get_monthly_active_count()
- self.assertTrue(self.get_success(count), initial_users)
+ self.get_success(d)
- users = self.store.get_registered_reserved_users()
- self.assertEquals(len(self.get_success(users)), reserved_user_number)
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, initial_users)
- self.get_success(self.store.reap_monthly_active_users())
- count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEqual(len(users), reserved_user_number)
+
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, self.hs.config.max_mau_value)
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
user_id = "@user_id:host"
- self.store.register_user(user_id=user_id, password_hash=None, make_guest=True)
+
+ d = self.store.register_user(
+ user_id=user_id, password_hash=None, make_guest=True
+ )
+ self.get_success(d)
+
self.store.upsert_monthly_active_user = Mock()
- self.store.populate_monthly_active_users(user_id)
- self.pump()
+
+ d = self.store.populate_monthly_active_users(user_id)
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
@@ -224,8 +245,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
- self.store.populate_monthly_active_users("user_id")
- self.pump()
+ d = self.store.populate_monthly_active_users("user_id")
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
@@ -235,16 +257,18 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
- self.store.populate_monthly_active_users("user_id")
- self.pump()
+
+ d = self.store.populate_monthly_active_users("user_id")
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_not_called()
def test_get_reserved_real_user_account(self):
# Test no reserved users, or reserved threepids
users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEquals(len(users), 0)
- # Test reserved users but no registered users
+ self.assertEqual(len(users), 0)
+ # Test reserved users but no registered users
user1 = "@user1:example.com"
user2 = "@user2:example.com"
@@ -254,63 +278,64 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user1_email},
{"medium": "email", "address": user2_email},
]
+
self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.db.runInteraction(
+ d = self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
+ self.get_success(d)
- self.pump()
users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEquals(len(users), 0)
+ self.assertEqual(len(users), 0)
- # Test reserved registed users
- self.store.register_user(user_id=user1, password_hash=None)
- self.store.register_user(user_id=user2, password_hash=None)
- self.pump()
+ # Test reserved registered users
+ self.get_success(self.store.register_user(user_id=user1, password_hash=None))
+ self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEquals(len(users), len(threepids))
+ self.assertEqual(len(users), len(threepids))
def test_support_user_not_add_to_mau_limits(self):
support_user_id = "@support:test"
- count = self.store.get_monthly_active_count()
- self.pump()
- self.assertEqual(self.get_success(count), 0)
- self.store.register_user(
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ d = self.store.register_user(
user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
+ self.get_success(d)
- self.store.upsert_monthly_active_user(support_user_id)
- count = self.store.get_monthly_active_count()
- self.pump()
- self.assertEqual(self.get_success(count), 0)
+ d = self.store.upsert_monthly_active_user(support_user_id)
+ self.get_success(d)
+
+ d = self.store.get_monthly_active_count()
+ count = self.get_success(d)
+ self.assertEqual(count, 0)
# Note that the max_mau_value setting should not matter.
@override_config(
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
)
def test_track_monthly_users_without_cap(self):
- count = self.store.get_monthly_active_count()
- self.assertEqual(0, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(0, count)
- self.store.upsert_monthly_active_user("@user1:server")
- self.store.upsert_monthly_active_user("@user2:server")
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user("@user1:server"))
+ self.get_success(self.store.upsert_monthly_active_user("@user2:server"))
- count = self.store.get_monthly_active_count()
- self.assertEqual(2, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(2, count)
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
self.store.upsert_monthly_active_user = Mock()
- self.store.populate_monthly_active_users("@user:sever")
- self.pump()
+ self.get_success(self.store.populate_monthly_active_users("@user:sever"))
self.store.upsert_monthly_active_user.assert_not_called()
@@ -325,33 +350,39 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
service2 = "service2"
native = "native"
- self.store.register_user(
- user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ )
)
- self.store.register_user(
- user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ )
)
- self.store.register_user(
- user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ self.get_success(
+ self.store.register_user(user_id=native_user1, password_hash=None)
)
- self.store.register_user(user_id=native_user1, password_hash=None)
- self.pump()
- count = self.store.get_monthly_active_count_by_service()
- self.assertEqual({}, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count_by_service())
+ self.assertEqual(count, {})
- self.store.upsert_monthly_active_user(native_user1)
- self.store.upsert_monthly_active_user(appservice1_user1)
- self.store.upsert_monthly_active_user(appservice1_user2)
- self.store.upsert_monthly_active_user(appservice2_user1)
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user(native_user1))
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
+ self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
- count = self.store.get_monthly_active_count()
- self.assertEqual(4, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 4)
- count = self.store.get_monthly_active_count_by_service()
- result = self.get_success(count)
+ d = self.store.get_monthly_active_count_by_service()
+ result = self.get_success(d)
- self.assertEqual(2, result[service1])
- self.assertEqual(1, result[service2])
- self.assertEqual(1, result[native])
+ self.assertEqual(result[service1], 2)
+ self.assertEqual(result[service2], 1)
+ self.assertEqual(result[native], 1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c5099dd039..c662195eec 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -206,3 +206,59 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# list.
self.reactor.advance(30)
self.assertEqual(self.resync_attempts, 2)
+
+ def test_cross_signing_keys_retry(self):
+ """Tests that resyncing a device list correctly processes cross-signing keys from
+ the remote server.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ # Register mock device list retrieval on the federation client.
+ federation_client = self.homeserver.get_federation_client()
+ federation_client.query_user_devices = Mock(
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ # Resync the device list.
+ device_handler = self.homeserver.get_device_handler()
+ self.get_success(
+ device_handler.device_list_updater.user_device_resync(remote_user_id),
+ )
+
+ # Retrieve the cross-signing keys for this user.
+ keys = self.get_success(
+ self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
+ )
+ self.assertTrue(remote_user_id in keys)
+
+ # Check that the master key is the one returned by the mock.
+ master_key = keys[remote_user_id]["master"]
+ self.assertEqual(len(master_key["keys"]), 1)
+ self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
+ self.assertTrue(remote_master_key in master_key["keys"].values())
+
+ # Check that the self-signing key is the one returned by the mock.
+ self_signing_key = keys[remote_user_id]["self_signing"]
+ self.assertEqual(len(self_signing_key["keys"]), 1)
+ self.assertTrue(
+ "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
+ )
+ self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 8a97f0998d..49667ed7f4 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -85,7 +85,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# Advance time by 31 days
self.reactor.advance(31 * 24 * 60 * 60)
- self.store.reap_monthly_active_users()
+ self.get_success(self.store.reap_monthly_active_users())
self.reactor.advance(0)
@@ -147,8 +147,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# Advance by 2 months so everyone falls out of MAU
self.reactor.advance(60 * 24 * 60 * 60)
- self.store.reap_monthly_active_users()
- self.reactor.advance(0)
+ self.get_success(self.store.reap_monthly_active_users())
# We can create as many new users as we want
token4 = self.create_user("kermit4")
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 852ef23185..ca3858b184 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -45,6 +45,38 @@ class LinearizerTestCase(unittest.TestCase):
with (yield d2):
pass
+ @defer.inlineCallbacks
+ def test_linearizer_is_queued(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ # Since d1 gets called immediately, "is_queued" should return false.
+ self.assertFalse(linearizer.is_queued(key))
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ # Now d2 is queued up behind successful completion of cm1
+ self.assertTrue(linearizer.is_queued(key))
+
+ with cm1:
+ self.assertFalse(d2.called)
+
+ # cm1 still not done, so d2 still queued.
+ self.assertTrue(linearizer.is_queued(key))
+
+ # And now d2 is called and nothing is in the queue again
+ self.assertFalse(linearizer.is_queued(key))
+
+ with (yield d2):
+ self.assertFalse(linearizer.is_queued(key))
+
+ self.assertFalse(linearizer.is_queued(key))
+
def test_lots_of_queued_things(self):
# we have one slow thing, and lots of fast things queued up behind it.
# it should *not* explode the stack.
|