diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ba7148ec01..2a0b7c1b56 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler
+from tests.test_utils import make_awaitable
from tests.utils import MockClock
from .. import unittest
@@ -32,10 +33,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api = Mock()
self.mock_scheduler = Mock()
hs = Mock()
- hs.get_datastore = Mock(return_value=self.mock_store)
- self.mock_store.get_received_ts.return_value = 0
- hs.get_application_service_api = Mock(return_value=self.mock_as_api)
- hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
+ hs.get_datastore.return_value = self.mock_store
+ self.mock_store.get_received_ts.return_value = defer.succeed(0)
+ self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+ hs.get_application_service_api.return_value = self.mock_as_api
+ hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
@@ -48,18 +50,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested=False),
]
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_user_by_id = Mock(return_value=[])
+ self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_user_by_id.return_value = defer.succeed([])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.side_effect = [
- (0, [event]),
- (0, []),
+ defer.succeed((0, [event])),
+ defer.succeed((0, [])),
]
- self.mock_as_api.push = Mock()
- yield self.handler.notify_interested_services(0)
+ yield defer.ensureDeferred(self.handler.notify_interested_services(0))
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@@ -68,36 +70,34 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
- services[0].is_interested_in_user = Mock(return_value=True)
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_user_by_id = Mock(return_value=None)
+ services[0].is_interested_in_user.return_value = True
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_user_by_id.return_value = defer.succeed(None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.push = Mock()
- self.mock_as_api.query_user = Mock()
+ self.mock_as_api.query_user.return_value = defer.succeed(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- (0, [event]),
- (0, []),
+ defer.succeed((0, [event])),
+ defer.succeed((0, [])),
]
- yield self.handler.notify_interested_services(0)
+ yield defer.ensureDeferred(self.handler.notify_interested_services(0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@defer.inlineCallbacks
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
- services[0].is_interested_in_user = Mock(return_value=True)
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
+ services[0].is_interested_in_user.return_value = True
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.push = Mock()
- self.mock_as_api.query_user = Mock()
+ self.mock_as_api.query_user.return_value = defer.succeed(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- (0, [event]),
- (0, []),
+ defer.succeed((0, [event])),
+ defer.succeed((0, [])),
]
- yield self.handler.notify_interested_services(0)
+ yield defer.ensureDeferred(self.handler.notify_interested_services(0))
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been.",
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar"
room_alias = Mock()
- room_alias.to_string = Mock(return_value=room_alias_str)
+ room_alias.to_string.return_value = room_alias_str
room_id = "!alpha:bet"
servers = ["aperture"]
@@ -118,12 +118,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_interested_in_alias=False),
]
- self.mock_store.get_app_services = Mock(return_value=services)
- self.mock_store.get_association_from_room_alias = Mock(
- return_value=Mock(room_id=room_id, servers=servers)
+ self.mock_as_api.query_alias.return_value = make_awaitable(True)
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
+ Mock(room_id=room_id, servers=servers)
)
- result = yield self.handler.query_room_alias_exists(room_alias)
+ result = yield defer.ensureDeferred(
+ self.handler.query_room_alias_exists(room_alias)
+ )
self.mock_as_api.query_alias.assert_called_once_with(
interested_service, room_alias_str
@@ -133,14 +136,14 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def _mkservice(self, is_interested):
service = Mock()
- service.is_interested = Mock(return_value=is_interested)
+ service.is_interested.return_value = make_awaitable(is_interested)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
def _mkservice_alias(self, is_interested_in_alias):
service = Mock()
- service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
+ service.is_interested_in_alias.return_value = is_interested_in_alias
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c01b04e1dc..c7efd3822d 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -24,10 +24,11 @@ from synapse.api.errors import ResourceLimitError
from synapse.handlers.auth import AuthHandler
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class AuthHandlers(object):
+class AuthHandlers:
def __init__(self, hs):
self.auth_handler = AuthHandler(hs)
@@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ side_effect=lambda: make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
@@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.large_number_of_users)
+ side_effect=lambda: make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
)
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.auth_blocking._max_mau_value)
+ side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ side_effect=lambda: make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
yield defer.ensureDeferred(
@@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.small_number_of_users)
+ side_effect=lambda: make_awaitable(self.small_number_of_users)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 62b47f6574..6aa322bf3a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- res = self.handler.get_device(user1, "abc")
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a
@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
- res = self.handler.update_device("user_id", "unknown_device_id", update)
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.update_device("user_id", "unknown_device_id", update),
+ synapse.api.errors.NotFoundError,
)
def _record_users(self):
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
from synapse.types import RoomAlias, create_requester
from tests import unittest
+from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index e1e144b2e7..210ddcbb88 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,17 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import mock
-import signedjson.key as key
-import signedjson.sign as sign
+from signedjson import key as key, sign as sign
from twisted.internet import defer
import synapse.handlers.e2e_keys
import synapse.storage
from synapse.api import errors
+from synapse.api.constants import RoomEncryptionAlgorithms
from tests import unittest, utils
@@ -47,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"""If the user has no devices, we expect an empty list.
"""
local_user = "@boris:" + self.hs.hostname
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -61,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
@@ -85,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ )
)
self.fail("No error when changing string key")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ )
)
self.fail("No error when replacing dict key with string")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+ )
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {
- "one_time_keys": {
- "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
- }
- },
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {
+ "one_time_keys": {
+ "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+ }
+ },
+ )
)
self.fail("No error when replacing dict key")
except errors.SynapseError:
@@ -134,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz"
keys = {"alg1:k1": "key1"}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
- res2 = yield self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ res2 = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
)
self.assertEqual(
res2,
@@ -164,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
keys2 = {
"master_key": {
@@ -176,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys2)
+ )
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@@ -216,13 +241,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
)
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
"user_id": local_user,
"device_id": "abc",
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"ed25519:abc": "base64+ed25519+key",
"curve25519:abc": "base64+curve25519+key",
@@ -232,7 +262,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_2 = {
"user_id": local_user,
"device_id": "def",
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {
"ed25519:def": "base64+ed25519+key",
"curve25519:def": "base64+curve25519+key",
@@ -240,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
- yield self.handler.upload_keys_for_user(
- local_user, "abc", {"device_keys": device_key_1}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "abc", {"device_keys": device_key_1}
+ )
)
- yield self.handler.upload_keys_for_user(
- local_user, "def", {"device_keys": device_key_2}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "def", {"device_keys": device_key_2}
+ )
)
# sign the first device key and upload it
del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1}}
+ )
)
# sign the second device key and upload both device keys. The server
@@ -259,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it
del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ )
)
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -287,20 +328,26 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
res = None
try:
- yield self.hs.get_device_handler().check_device_registered(
- user_id=local_user,
- device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
- initial_device_display_name="new display name",
+ yield defer.ensureDeferred(
+ self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ )
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 400)
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -315,7 +362,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key = {
"user_id": local_user,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
"keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
"signatures": {local_user: {"ed25519:xyz": "something"}},
}
@@ -323,8 +373,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
)
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"device_keys": device_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"device_keys": device_key}
+ )
)
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
@@ -364,7 +416,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key,
}
- yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ )
# set up another user with a master key. This user will be signed by
# the first user
@@ -376,76 +430,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
}
- yield self.handler.upload_signing_keys_for_user(
- other_user, {"master_key": other_master_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(
+ other_user, {"master_key": other_master_key}
+ )
)
# test various signature failures (see below)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- device_id: {
- "user_id": local_user,
- "device_id": device_id,
- "algorithms": [
- "m.olm.curve25519-aes-sha2",
- "m.megolm.v1.aes-sha2",
- ],
- "keys": {
- "curve25519:xyz": "curve25519+key",
- # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
- "ed25519:xyz": device_pubkey,
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ device_id: {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "curve25519:xyz": "curve25519+key",
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ "ed25519:xyz": device_pubkey,
+ },
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ # fails because device is unknown
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": local_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- },
- # fails because device is unknown
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": local_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ master_pubkey: {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ "signatures": {
+ local_user: {"ed25519:" + device_pubkey: "something"}
+ },
},
},
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- master_pubkey: {
- "user_id": local_user,
- "usage": ["master"],
- "keys": {"ed25519:" + master_pubkey: master_pubkey},
- "signatures": {
- local_user: {"ed25519:" + device_pubkey: "something"}
+ other_user: {
+ # fails because the device is not the user's master-signing key
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": other_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
- },
- },
- other_user: {
- # fails because the device is not the user's master-signing key
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": other_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
- },
- },
- other_master_pubkey: {
- # fails because the key doesn't match what the server has
- # should fail with UNKNOWN
- "user_id": other_user,
- "usage": ["master"],
- "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
- "something": "random",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ other_master_pubkey: {
+ # fails because the key doesn't match what the server has
+ # should fail with UNKNOWN
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:" + other_master_pubkey: other_master_pubkey
+ },
+ "something": "random",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
},
},
- },
+ )
)
user_failures = ret["failures"][local_user]
@@ -470,19 +538,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {device_id: device_key, master_pubkey: master_key},
- other_user: {other_master_pubkey: other_master_key},
- },
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {device_id: device_key, master_pubkey: master_key},
+ other_user: {other_master_pubkey: other_master_key},
+ },
+ )
)
self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there
- ret = yield self.handler.query_devices(
- {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ ret = yield defer.ensureDeferred(
+ self.handler.query_devices(
+ {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ )
)
self.assertEqual(
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 70f172eb02..3362050ce0 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user)
+ yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -87,15 +89,21 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_create_version(self):
"""Check that we can create and then retrieve versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
version_etag = res["etag"]
+ self.assertIsInstance(version_etag, str)
del res["etag"]
self.assertDictEqual(
res,
@@ -108,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as a specific version
- res = yield self.handler.get_version_info(self.local_user, "1")
+ res = yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
self.assertEqual(res["etag"], version_etag)
del res["etag"]
self.assertDictEqual(
@@ -122,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# upload a new one...
- res = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "2")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -148,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_version(self):
"""Check that we can update versions.
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": version,
- },
+ res = yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ },
+ )
)
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -184,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- "1",
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "1",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ "1",
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -201,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_omitted_version(self):
"""Check that the update succeeds if the version is missing from the body
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
)
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
@@ -233,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "incorrect",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -260,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.delete_version(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -271,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user)
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -280,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_version(self):
"""Check that we can create and then delete versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can delete it
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
# check that it's gone
res = None
try:
- yield self.handler.get_version_info(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -303,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_room_keys(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -312,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
@@ -330,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "no_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
)
except errors.SynapseError as e:
res = e.code
@@ -342,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "bogus_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(
+ self.local_user, "bogus_version", room_keys
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -361,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- version = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "2")
res = None
try:
- yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 403)
@@ -387,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, room_keys)
@@ -414,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
# get the etag to compare to future versions
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"]
self.assertEqual(res["count"], 1)
@@ -433,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should NOT be equal now, since the key changed
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"]
@@ -463,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here
@@ -480,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
# check for bulk-delete
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(self.local_user, version)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(self.local_user, version)
+ )
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..89ec5fcb31 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -75,7 +75,17 @@ COMMON_CONFIG = {
COOKIE_NAME = b"oidc_session"
COOKIE_PATH = "/_synapse/oidc"
-MockedMappingProvider = Mock(OidcMappingProvider)
+
+class TestMappingProvider(OidcMappingProvider):
+ @staticmethod
+ def parse_config(config):
+ return
+
+ def get_remote_user_id(self, userinfo):
+ return userinfo["sub"]
+
+ async def map_user_attributes(self, userinfo, token):
+ return {"localpart": userinfo["username"], "display_name": None}
def simple_async_mock(return_value=None, raises=None):
@@ -123,7 +133,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config["issuer"] = ISSUER
oidc_config["scopes"] = SCOPES
oidc_config["user_mapping_provider"] = {
- "module": __name__ + ".MockedMappingProvider"
+ "module": __name__ + ".TestMappingProvider",
}
config["oidc_config"] = oidc_config
@@ -374,12 +384,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +406,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +417,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +451,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
@@ -568,3 +590,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
with self.assertRaises(OidcError) as exc:
yield defer.ensureDeferred(self.handler._exchange_code(code))
self.assertEqual(exc.exception.error, "some_error")
+
+ def test_map_userinfo_to_user(self):
+ """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ userinfo = {
+ "sub": "test_user",
+ "username": "test_user",
+ }
+ # The token doesn't matter with the default user mapping provider.
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Some providers return an integer ID.
+ userinfo = {
+ "sub": 1234,
+ "username": "test_user_2",
+ }
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_2:test")
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c6e..8e95e53d9e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,10 +24,11 @@ from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class ProfileHandlers(object):
+class ProfileHandlers:
def __init__(self, hs):
self.profile_handler = MasterProfileHandler(hs)
@@ -63,16 +64,20 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
- displayname = yield self.handler.get_displayname(self.frank)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.frank)
+ )
self.assertEquals("Frank", displayname)
@@ -101,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
@defer.inlineCallbacks
@@ -109,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
# Setting displayname a second time is forbidden
@@ -136,11 +153,13 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_other_name(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
- displayname = yield self.handler.get_displayname(self.alice)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.alice)
+ )
self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with(
@@ -152,22 +171,27 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield defer.ensureDeferred(self.store.create_profile("caroline"))
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname("caroline", "Caroline")
+ )
- response = yield self.query_handlers["profile"](
- {"user_id": "@caroline:test", "field": "displayname"}
+ response = yield defer.ensureDeferred(
+ self.query_handlers["profile"](
+ {"user_id": "@caroline:test", "field": "displayname"}
+ )
)
self.assertEquals({"displayname": "Caroline"}, response)
@defer.inlineCallbacks
def test_get_my_avatar(self):
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
-
- avatar_url = yield self.handler.get_avatar_url(self.frank)
+ avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url)
@@ -182,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/pic.gif",
)
@@ -196,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
@@ -205,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index ca32f993a3..eddf5e2498 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -15,17 +15,21 @@
from mock import Mock
-from twisted.internet import defer
-
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
+from tests.test_utils import make_awaitable
+from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
+
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -96,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value - 1)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -104,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ side_effect=lambda: make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -112,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -122,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
+ side_effect=lambda: make_awaitable(self.lots_of_users)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -145,9 +149,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -185,7 +189,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_real_user = Mock(return_value=defer.succeed(False))
+ self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -193,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias = RoomAlias.from_string(room_alias_str)
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=defer.succeed(1))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -212,12 +216,218 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=defer.succeed(2))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_rooms_federated": False,
+ }
+ )
+ def test_auto_create_auto_join_rooms_federated(self):
+ """
+ Auto-created rooms that are private require an invite to go to the user
+ (instead of directly joining it).
+ """
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly not federated.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertFalse(room["federatable"])
+ self.assertFalse(room["public"])
+ self.assertEqual(room["join_rules"], "public")
+ self.assertIsNone(room["guest_access"])
+
+ # The user should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
+ )
+ def test_auto_join_mxid_localpart(self):
+ """
+ Ensure the user still needs up in the room created by a different user.
+ """
+ # Ensure the support user exists.
+ inviter = "@support:test"
+
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly a public room.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertEqual(room["join_rules"], "public")
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ # Register a second user, which should also end up in the room.
+ user_id = self.get_success(self.handler.register_user(localpart="bob"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_room_preset": "private_chat",
+ "auto_join_mxid_localpart": "support",
+ }
+ )
+ def test_auto_create_auto_join_room_preset(self):
+ """
+ Auto-created rooms that are private require an invite to go to the user
+ (instead of directly joining it).
+ """
+ # Ensure the support user exists.
+ inviter = "@support:test"
+
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly a private room.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertFalse(room["public"])
+ self.assertEqual(room["join_rules"], "invite")
+ self.assertEqual(room["guest_access"], "can_join")
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ # Register a second user, which should also end up in the room.
+ user_id = self.get_success(self.handler.register_user(localpart="bob"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_room_preset": "private_chat",
+ "auto_join_mxid_localpart": "support",
+ }
+ )
+ def test_auto_create_auto_join_room_preset_guest(self):
+ """
+ Auto-created rooms that are private require an invite to go to the user
+ (instead of directly joining it).
+
+ This should also work for guests.
+ """
+ inviter = "@support:test"
+
+ room_alias_str = "#room:test"
+ user_id = self.get_success(
+ self.handler.register_user(localpart="jeff", make_guest=True)
+ )
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room is properly a private room.
+ room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ self.assertFalse(room["public"])
+ self.assertEqual(room["join_rules"], "invite")
+ self.assertEqual(room["guest_access"], "can_join")
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ @override_config(
+ {
+ "auto_join_rooms": ["#room:test"],
+ "autocreate_auto_join_room_preset": "private_chat",
+ "auto_join_mxid_localpart": "support",
+ }
+ )
+ def test_auto_create_auto_join_room_preset_invalid_permissions(self):
+ """
+ Auto-created rooms that are private require an invite, check that
+ registration doesn't completely break if the inviter doesn't have proper
+ permissions.
+ """
+ inviter = "@support:test"
+
+ # Register an initial user to create the room and such (essentially this
+ # is a subset of test_auto_create_auto_join_room_preset).
+ room_alias_str = "#room:test"
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+ # Ensure the room was created.
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ # Ensure the room exists.
+ self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+
+ # Both users should be in the room.
+ rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+ self.assertIn(room_id["room_id"], rooms)
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertIn(room_id["room_id"], rooms)
+
+ # Lower the permissions of the inviter.
+ event_creation_handler = self.hs.get_event_creation_handler()
+ requester = create_requester(inviter)
+ event, context = self.get_success(
+ event_creation_handler.create_event(
+ requester,
+ {
+ "type": "m.room.power_levels",
+ "state_key": "",
+ "room_id": room_id["room_id"],
+ "content": {"invite": 100, "users": {inviter: 0}},
+ "sender": inviter,
+ },
+ )
+ )
+ self.get_success(
+ event_creation_handler.send_nonmember_event(requester, event, context)
+ )
+
+ # Register a second user, which won't be be in the room (or even have an invite)
+ # since the inviter no longer has the proper permissions.
+ user_id = self.get_success(self.handler.register_user(localpart="bob"))
+
+ # This user should not be in any rooms.
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(user_id)
+ )
+ self.assertEqual(rooms, set())
+ self.assertEqual(invited_rooms, [])
+
def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
@@ -266,6 +476,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..a609f148c0 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -15,7 +15,7 @@
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
-from synapse.storage.data_stores.main import stats
+from synapse.storage.databases.main import stats
from tests import unittest
@@ -42,36 +42,36 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms",
+ "update_name": "populate_stats_process_rooms_2",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- def get_all_room_state(self):
- return self.store.db.simple_select_list(
+ async def get_all_room_state(self):
+ return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
return self.get_success(
- self.store.db.simple_select_one(
+ self.store.db_pool.simple_select_one(
table + "_historical",
{id_col: stat_id, end_ts: end_ts},
cols,
@@ -109,10 +109,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_initial_room(self):
@@ -146,10 +146,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r = self.get_success(self.get_all_room_state())
@@ -186,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_update_one(
+ self.store.db_pool.simple_update_one(
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": 0},
@@ -196,17 +196,17 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now, before the table is actually ingested, add some more events.
@@ -217,28 +217,31 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Now do the initial ingestion.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
- {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+ {
+ "update_name": "populate_stats_process_rooms_2",
+ "progress_json": "{}",
+ },
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
self.reactor.advance(86401)
@@ -253,7 +256,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# self.handler.notify_new_event()
# We need to let the delta processor advanceā¦
- self.pump(10 * 60)
+ self.reactor.advance(10 * 60)
# Get the slices! There should be two -- day 1, and day 2.
r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
@@ -346,6 +349,37 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ def test_updating_profile_information_does_not_increase_joined_members_count(self):
+ """
+ Check that the joined_members count does not increase when a user changes their
+ profile information (which is done by sending another join membership event into
+ the room.
+ """
+ self._perform_background_initial_update()
+
+ # Create a user and room
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ # Get the current room stats
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ # Send a profile update into the room
+ new_profile = {"displayname": "bob"}
+ self.helper.change_membership(
+ r1, u1, u1, "join", extra_data=new_profile, tok=u1token
+ )
+
+ # Get the new room stats
+ r1stats_post = self._get_current_stats("room", r1)
+
+ # Ensure that the user count did not changed
+ self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
+ self.assertEqual(
+ r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
+ )
+
def test_send_state_event_nonoverwriting(self):
"""
When we send a non-overwriting state event, it increments total_events AND current_state_events
@@ -669,15 +703,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# preparation stage of the initial background update
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"room_stats_current", {"1": 1}, "test_delete_stats"
)
)
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"user_stats_current", {"1": 1}, "test_delete_stats"
)
)
@@ -689,29 +723,29 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# now do the background updates
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms",
+ "update_name": "populate_stats_process_rooms_2",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -722,10 +756,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r1stats_complete = self._get_current_stats("room", r1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 2fa8d4739b..7bf15c4ba9 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,9 +21,10 @@ from mock import ANY, Mock, call
from twisted.internet import defer
from synapse.api.errors import AuthError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import register_federation_servlets
@@ -115,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+ self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
(0, [])
)
@@ -126,9 +127,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
+ return None
hs.get_auth().check_user_in_room = check_user_in_room
@@ -137,24 +139,26 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- def get_current_users_in_room(room_id):
- return {str(u) for u in self.room_members}
+ def get_users_in_room(room_id):
+ return defer.succeed({str(u) for u in self.room_members})
- hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
+ self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.return_value = (
+ self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
- defer.succeed(1)
+ lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+ self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
([], 0)
)
- self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
- self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
+ None
+ )
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)
@@ -163,9 +167,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -190,9 +197,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def test_started_typing_remote_send(self):
self.room_members = [U_APPLE, U_ONION]
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -265,9 +275,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(
+ self.get_success(
self.handler.stopped_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
)
)
@@ -305,9 +317,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
@@ -344,9 +359,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# SYN-230 - see if we can still set after timeout
- self.successResultOf(
+ self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c15bce5bef..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -17,12 +17,13 @@ from mock import Mock
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
+from tests.unittest import override_config
class UserDirectoryTestCase(unittest.HomeserverTestCase):
@@ -147,9 +148,97 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
+ @override_config({"encryption_enabled_by_default_for_room_type": "all"})
+ def test_encrypted_by_default_config_option_all(self):
+ """Tests that invite-only and non-invite-only rooms have encryption enabled by
+ default when the config option encryption_enabled_by_default_for_room_type is "all".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
+ def test_encrypted_by_default_config_option_invite(self):
+ """Tests that only new, invite-only rooms have encryption enabled by default when
+ the config option encryption_enabled_by_default_for_room_type is "invite".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "off"})
+ def test_encrypted_by_default_config_option_off(self):
+ """Tests that neither new invite-only nor non-invite-only rooms have encryption
+ enabled by default when the config option
+ encryption_enabled_by_default_for_room_type is "off".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -180,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -193,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
@@ -250,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@@ -261,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@@ -273,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@@ -285,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@@ -295,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@@ -305,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@@ -348,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
@@ -387,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
|