diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index ee5217b074..34f72ae795 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -17,8 +17,6 @@ from mock import Mock
import pymacaroons
-from twisted.internet import defer
-
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -33,19 +31,17 @@ from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
-from tests.utils import mock_getRawHeaders, setup_test_homeserver
+from tests.test_utils import simple_async_mock
+from tests.utils import mock_getRawHeaders
-class AuthTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.state_handler = Mock()
+class AuthTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = Mock()
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.hs.get_datastore = Mock(return_value=self.store)
- self.hs.get_auth_handler().store = self.store
- self.auth = Auth(self.hs)
+ hs.get_datastore = Mock(return_value=self.store)
+ hs.get_auth_handler().store = self.store
+ self.auth = Auth(hs)
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
@@ -57,64 +53,59 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
- self.store.is_support_user = Mock(return_value=defer.succeed(False))
+ self.store.insert_client_ip = simple_async_mock(None)
+ self.store.is_support_user = simple_async_mock(False)
- @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
- self.store.get_user_by_access_token = Mock(
- return_value=defer.succeed(user_info)
- )
+ self.store.get_user_by_access_token = simple_async_mock(user_info)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, InvalidClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), InvalidClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
- self.store.get_user_by_access_token = Mock(
- return_value=defer.succeed(user_info)
- )
+ self.store.get_user_by_access_token = simple_async_mock(user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, MissingClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), MissingClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
- @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
- @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_good_ip(self):
from netaddr import IPSet
@@ -125,13 +116,13 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -144,42 +135,44 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, InvalidClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), InvalidClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, InvalidClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), InvalidClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- f = self.failureResultOf(d, MissingClientTokenError).value
+ f = self.get_failure(
+ self.auth.get_user_by_req(request), MissingClientTokenError
+ ).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
- @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
@@ -188,17 +181,15 @@ class AuthTestCase(unittest.TestCase):
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
- self.store.get_user_by_id = Mock(
- return_value=defer.succeed({"is_guest": False})
- )
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+ requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@@ -210,22 +201,18 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = defer.ensureDeferred(self.auth.get_user_by_req(request))
- self.failureResultOf(d, AuthError)
+ self.get_failure(self.auth.get_user_by_req(request), AuthError)
- @defer.inlineCallbacks
def test_get_user_from_macaroon(self):
- self.store.get_user_by_access_token = Mock(
- return_value=defer.succeed(
- TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
- )
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
user_id = "@baldrick:matrix.org"
@@ -237,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = yield defer.ensureDeferred(
+ user_info = self.get_success(
self.auth.get_user_by_access_token(macaroon.serialize())
)
self.assertEqual(user_id, user_info.user_id)
@@ -246,10 +233,9 @@ class AuthTestCase(unittest.TestCase):
# from the db.
self.assertEqual(user_info.device_id, "device")
- @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
- self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
- self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+ self.store.get_user_by_id = simple_async_mock({"is_guest": True})
+ self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -263,20 +249,17 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
- user_info = yield defer.ensureDeferred(
- self.auth.get_user_by_access_token(serialized)
- )
+ user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
self.assertEqual(user_id, user_info.user_id)
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
- @defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
- self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
- self.store.get_device = Mock(return_value=defer.succeed(None))
+ self.store.add_access_token_to_user = simple_async_mock(None)
+ self.store.get_device = simple_async_mock(None)
- token = yield defer.ensureDeferred(
+ token = self.get_success(
self.hs.get_auth_handler().get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
@@ -289,25 +272,24 @@ class AuthTestCase(unittest.TestCase):
puppets_user_id=None,
)
- def get_user(tok):
+ async def get_user(tok):
if token != tok:
- return defer.succeed(None)
- return defer.succeed(
- TokenLookupResult(
- user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
- )
+ return None
+ return TokenLookupResult(
+ user_id=USER_ID,
+ is_guest=False,
+ token_id=1234,
+ device_id="DEVICE",
)
self.store.get_user_by_access_token = get_user
- self.store.get_user_by_id = Mock(
- return_value=defer.succeed({"is_guest": False})
- )
+ self.store.get_user_by_id = simple_async_mock({"is_guest": False})
# check the token works
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield defer.ensureDeferred(
+ requester = self.get_success(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
@@ -323,17 +305,16 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [guest_tok.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- with self.assertRaises(InvalidClientCredentialsError) as cm:
- yield defer.ensureDeferred(
- self.auth.get_user_by_req(request, allow_guest=True)
- )
+ cm = self.get_failure(
+ self.auth.get_user_by_req(request, allow_guest=True),
+ InvalidClientCredentialsError,
+ )
- self.assertEqual(401, cm.exception.code)
- self.assertEqual("Guest access token used for regular user", cm.exception.msg)
+ self.assertEqual(401, cm.value.code)
+ self.assertEqual("Guest access token used for regular user", cm.value.msg)
self.store.get_user_by_id.assert_called_with(USER_ID)
- @defer.inlineCallbacks
def test_blocking_mau(self):
self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50
@@ -341,77 +322,61 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1
# Ensure no error thrown
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.get_success(self.auth.check_auth_blocking())
self.auth_blocking._limit_usage_by_mau = True
- self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(lots_of_users)
- )
+ self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
- with self.assertRaises(ResourceLimitError) as e:
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.assertEquals(e.exception.code, 403)
+ e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ self.assertEquals(e.value.code, 403)
# Ensure does not throw an error
- self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(small_number_of_users)
- )
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
+ self.get_success(self.auth.check_auth_blocking())
- @defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed
- yield defer.ensureDeferred(
- self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
- )
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+ self.store.get_monthly_active_count = simple_async_mock(100)
# Bots not allowed
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth.check_auth_blocking(user_type=UserTypes.BOT)
- )
- self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ self.get_failure(
+ self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+ )
+ self.store.get_monthly_active_count = simple_async_mock(100)
# Real users not allowed
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- @defer.inlineCallbacks
def test_reserved_threepid(self):
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
- self.store.get_monthly_active_count = lambda: defer.succeed(2)
+ self.store.get_monthly_active_count = simple_async_mock(2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
+ self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth.check_auth_blocking(threepid=unknown_threepid)
- )
+ self.get_failure(
+ self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+ )
- yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
+ self.get_success(self.auth.check_auth_blocking(threepid=threepid))
- @defer.inlineCallbacks
def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- with self.assertRaises(ResourceLimitError) as e:
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.assertEquals(e.exception.code, 403)
+ e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ self.assertEquals(e.value.code, 403)
- @defer.inlineCallbacks
def test_hs_disabled_no_server_notices_user(self):
"""Check that 'hs_disabled_message' works correctly when there is no
server_notices user.
@@ -422,16 +387,14 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- with self.assertRaises(ResourceLimitError) as e:
- yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.assertEquals(e.exception.code, 403)
+ e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ self.assertEquals(e.value.code, 403)
- @defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self):
self.auth_blocking._hs_disabled = True
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
+ self.get_success(self.auth.check_auth_blocking(user))
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 279c94a03d..ab7d290724 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -18,15 +18,12 @@
import jsonschema
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
-from tests.utils import setup_test_homeserver
user_localpart = "test_user"
@@ -39,9 +36,8 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs)
-class FilteringTestCase(unittest.TestCase):
- def setUp(self):
- hs = setup_test_homeserver(self.addCleanup)
+class FilteringTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.filtering = hs.get_filtering()
self.datastore = hs.get_datastore()
@@ -351,10 +347,9 @@ class FilteringTestCase(unittest.TestCase):
self.assertTrue(Filter(definition).check(event))
- @defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -362,7 +357,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@@ -371,11 +366,10 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events)
self.assertEquals(events, results)
- @defer.inlineCallbacks
def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json
)
@@ -387,7 +381,7 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
)
@@ -396,10 +390,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_presence(events=events)
self.assertEquals([], results)
- @defer.inlineCallbacks
def test_filter_room_state_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -407,7 +400,7 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@@ -416,10 +409,9 @@ class FilteringTestCase(unittest.TestCase):
results = user_filter.filter_room_state(events=events)
self.assertEquals(events, results)
- @defer.inlineCallbacks
def test_filter_room_state_no_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -429,7 +421,7 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield defer.ensureDeferred(
+ user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
@@ -454,11 +446,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
- @defer.inlineCallbacks
def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.filtering.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
@@ -468,7 +459,7 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
- yield defer.ensureDeferred(
+ self.get_success(
self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
)
@@ -476,17 +467,16 @@ class FilteringTestCase(unittest.TestCase):
),
)
- @defer.inlineCallbacks
def test_get_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
- filter_id = yield defer.ensureDeferred(
+ filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
)
- filter = yield defer.ensureDeferred(
+ filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index fe504d0869..483418192c 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,11 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
+ None,
+ "example.com",
+ id="foo",
+ rate_limited=True,
+ sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -68,7 +72,11 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
+ None,
+ "example.com",
+ id="foo",
+ rate_limited=False,
+ sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -113,12 +121,18 @@ class TestRatelimiter(unittest.TestCase):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
# First attempt should be allowed
- allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
+ allowed, time_allowed = limiter.can_do_action(
+ ("test_id",),
+ _time_now_s=0,
+ )
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
# Second attempt, 1s later, will fail
- allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
+ allowed, time_allowed = limiter.can_do_action(
+ ("test_id",),
+ _time_now_s=1,
+ )
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index d3ec24c975..2b7f09c14b 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -127,8 +127,7 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 150)
def test_cache_with_asterisk_in_name(self):
- """Some caches have asterisks in their name, test that they are set correctly.
- """
+ """Some caches have asterisks in their name, test that they are set correctly."""
config = {
"caches": {
@@ -164,7 +163,8 @@ class CacheConfigTests(TestCase):
t.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(
- max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+ max_size=t.caches.event_cache_size,
+ apply_cache_factor_from_config=False,
)
add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 1d65ea2f9c..30fcc4c1bf 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -400,7 +400,10 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
)
def build_perspectives_response(
- self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+ self,
+ server_name: str,
+ signing_key: SigningKey,
+ valid_until_ts: int,
) -> dict:
"""
Build a valid perspectives server response to a request for the given key
@@ -455,7 +458,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
VALID_UNTIL_TS = 200 * 1000
response = self.build_perspectives_response(
- SERVER_NAME, testkey, VALID_UNTIL_TS,
+ SERVER_NAME,
+ testkey,
+ VALID_UNTIL_TS,
)
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 3a80626224..ec85324c0c 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -43,7 +43,10 @@ class TestEventContext(unittest.HomeserverTestCase):
event, context = self.get_success(
create_event(
- self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
)
)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9ccd2d76b8..8186b8ca01 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -150,8 +150,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Artificially raise the complexity
- self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
- 600
+ self.hs.get_datastore().get_current_state_event_counts = (
+ lambda x: make_awaitable(600)
)
d = handler._remote_join(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 917762e6b6..ecc3faa572 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -279,7 +279,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
ret = self.get_success(
e2e_handler.upload_signatures_for_device_keys(
- u1, {u1: {"D1": d1_json, "D2": d2_json}},
+ u1,
+ {u1: {"D1": d1_json, "D2": d2_json}},
)
)
self.assertEqual(ret["failures"], {})
@@ -486,9 +487,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertGreaterEqual(content["stream_id"], prev_stream_id)
return content["stream_id"]
- def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
- """Check that the txn has an EDU with a signing key update.
- """
+ def check_signing_key_update_txn(
+ self,
+ txn: JsonDict,
+ ) -> None:
+ """Check that the txn has an EDU with a signing key update."""
edus = txn["edus"]
self.assertEqual(len(edus), 1)
@@ -502,7 +505,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.get_success(
self.hs.get_e2e_keys_handler().upload_keys_for_user(
- user_id, device_id, {"device_keys": device_dict},
+ user_id,
+ device_id,
+ {"device_keys": device_dict},
)
)
return sk
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 5c2b4de1a6..a01fdd0839 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -44,8 +44,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.token2 = self.login("user2", "password")
def test_single_public_joined_room(self):
- """Test that we write *all* events for a public room
- """
+ """Test that we write *all* events for a public room"""
room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True
)
@@ -116,8 +115,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
def test_single_left_room(self):
- """Tests that we don't see events in the room after we leave.
- """
+ """Tests that we don't see events in the room after we leave."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
self.helper.join(room_id, self.user2, tok=self.token2)
@@ -190,8 +188,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
def test_invite(self):
- """Tests that pending invites get handled correctly.
- """
+ """Tests that pending invites get handled correctly."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 53763cd0f9..d5d3fdd99a 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastore.return_value = self.mock_store
- self.mock_store.get_received_ts.return_value = defer.succeed(0)
- self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+ self.mock_store.get_received_ts.return_value = make_awaitable(0)
+ self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested=False),
]
- self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = defer.succeed([])
+ self.mock_store.get_user_by_id.return_value = make_awaitable([])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.side_effect = [
- defer.succeed((0, [event])),
- defer.succeed((0, [])),
+ make_awaitable((0, [event])),
+ make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = defer.succeed(None)
+ self.mock_store.get_user_by_id.return_value = make_awaitable(None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- defer.succeed((0, [event])),
- defer.succeed((0, [])),
+ make_awaitable((0, [event])),
+ make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
+ self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = defer.succeed(True)
+ self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_new_events_for_appservice.side_effect = [
- defer.succeed((0, [event])),
- defer.succeed((0, [])),
+ make_awaitable((0, [event])),
+ make_awaitable((0, [])),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.",
)
- @defer.inlineCallbacks
def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar"
room_alias = Mock()
@@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Mock(room_id=room_id, servers=servers)
)
- result = yield defer.ensureDeferred(
- self.handler.query_room_alias_exists(room_alias)
+ result = self.successResultOf(
+ defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
)
self.mock_as_api.query_alias.assert_called_once_with(
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index e24ce81284..0e42013bb9 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -16,28 +16,21 @@ from mock import Mock
import pymacaroons
-from twisted.internet import defer
-
-import synapse
-import synapse.api.errors
-from synapse.api.errors import ResourceLimitError
+from synapse.api.errors import AuthError, ResourceLimitError
from tests import unittest
from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
-class AuthTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.auth_handler = self.hs.get_auth_handler()
- self.macaroon_generator = self.hs.get_macaroon_generator()
+class AuthTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.auth_handler = hs.get_auth_handler()
+ self.macaroon_generator = hs.get_macaroon_generator()
# MAU tests
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
- self.auth_blocking = self.hs.get_auth()._auth_blocking
+ self.auth_blocking = hs.get_auth()._auth_blocking
self.auth_blocking._max_mau_value = 50
self.small_number_of_users = 1
@@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
- self.hs.get_clock().now = 5000
-
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
- @defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
- self.hs.get_clock().now = 1000
-
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = yield defer.ensureDeferred(
+ user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
- self.hs.get_clock().now = 6000
- with self.assertRaises(synapse.api.errors.AuthError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
- )
+ self.reactor.advance(6)
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+ AuthError,
+ )
- @defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = yield defer.ensureDeferred(
+ user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
@@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase):
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
- with self.assertRaises(synapse.api.errors.AuthError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
- )
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ ),
+ AuthError,
+ )
- @defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
)
- @defer.inlineCallbacks
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
- )
- )
+ self.get_failure(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ ),
+ ResourceLimitError,
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
- )
- )
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ ),
+ ResourceLimitError,
+ )
- @defer.inlineCallbacks
def test_mau_limits_parity(self):
+ # Ensure we're not at the unix epoch.
+ self.reactor.advance(1)
self.auth_blocking._limit_usage_by_mau = True
- # If not in monthly active cohort
+ # Set the server to be at the edge of too many users.
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
- )
- )
- self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=make_awaitable(self.auth_blocking._max_mau_value)
+ # If not in monthly active cohort
+ self.get_failure(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ ),
+ ResourceLimitError,
)
- with self.assertRaises(ResourceLimitError):
- yield defer.ensureDeferred(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
- )
- )
+ self.get_failure(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ ),
+ ResourceLimitError,
+ )
+
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=make_awaitable(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.clock.time_msec())
)
- self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=make_awaitable(self.auth_blocking._max_mau_value)
- )
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
- self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- return_value=make_awaitable(self.hs.get_clock().time_msec())
- )
- self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=make_awaitable(self.auth_blocking._max_mau_value)
- )
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
)
- @defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
self.auth_blocking._limit_usage_by_mau = True
@@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase):
return_value=make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
@@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 7baf224f7e..6f992291b8 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -16,7 +16,7 @@ from mock import Mock
from synapse.handlers.cas_handler import CasResponse
from tests.test_utils import simple_async_mock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
"server_url": SERVER_URL,
"service_url": BASE_URL,
}
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ cas_config.update(config.get("cas_config", {}))
config["cas_config"] = cas_config
return config
@@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase):
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
)
+ @override_config(
+ {
+ "cas_config": {
+ "required_attributes": {"userGroup": "staff", "department": None}
+ }
+ }
+ )
+ def test_required_attributes(self):
+ """The required attributes must be met from the CAS response."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # The response doesn't have the proper userGroup or department.
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # The response doesn't have any department.
+ cas_response = CasResponse("test_user", {"userGroup": "staff"})
+ request.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # Add the proper attributes and it should succeed.
+ cas_response = CasResponse(
+ "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
+ )
+ request.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=True
+ )
+
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
- return Mock(spec=["getClientIP", "getHeader"])
+ return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 5dfeccfeb6..821629bc38 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -260,7 +260,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# Create a new login for the user and dehydrated the device
device_id, access_token = self.get_success(
self.registration.register_device(
- user_id=user_id, device_id=None, initial_display_name="new device",
+ user_id=user_id,
+ device_id=None,
+ initial_display_name="new device",
)
)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ebc6a0866a..fadec16e13 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -133,7 +133,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
"""A user can create an alias for a room they're in."""
self.get_success(
self.handler.create_association(
- create_requester(self.test_user), self.room_alias, self.room_id,
+ create_requester(self.test_user),
+ self.room_alias,
+ self.room_id,
)
)
@@ -145,7 +147,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.get_failure(
self.handler.create_association(
- create_requester(self.test_user), self.room_alias, other_room_id,
+ create_requester(self.test_user),
+ self.room_alias,
+ other_room_id,
),
synapse.api.errors.SynapseError,
)
@@ -158,7 +162,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.get_success(
self.handler.create_association(
- create_requester(self.admin_user), self.room_alias, other_room_id,
+ create_requester(self.admin_user),
+ self.room_alias,
+ other_room_id,
)
)
@@ -277,8 +283,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
class CanonicalAliasTestCase(unittest.HomeserverTestCase):
- """Test modifications of the canonical alias when delete aliases.
- """
+ """Test modifications of the canonical alias when delete aliases."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -319,7 +324,10 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_canonical_alias(self, content):
"""Configure the canonical alias state on the room."""
self.helper.send_state(
- self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
+ self.room_id,
+ "m.room.canonical_alias",
+ content,
+ tok=self.admin_user_tok,
)
def _get_canonical_alias(self):
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 924f29f051..5e86c5e56b 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -18,42 +18,26 @@ import mock
from signedjson import key as key, sign as sign
-from twisted.internet import defer
-
-import synapse.handlers.e2e_keys
-import synapse.storage
-from synapse.api import errors
from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.errors import Codes, SynapseError
-from tests import unittest, utils
+from tests import unittest
-class E2eKeysHandlerTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.hs = None # type: synapse.server.HomeServer
- self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
- self.store = None # type: synapse.storage.Storage
+class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(federation_client=mock.Mock())
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, federation_client=mock.Mock()
- )
- self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
def test_query_local_devices_no_devices(self):
- """If the user has no devices, we expect an empty list.
- """
+ """If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname
- res = yield defer.ensureDeferred(
- self.handler.query_local_devices({local_user: None})
- )
+ res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- @defer.inlineCallbacks
def test_reupload_one_time_keys(self):
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
@@ -64,7 +48,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
@@ -73,14 +57,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
- @defer.inlineCallbacks
def test_change_one_time_keys(self):
"""attempts to change one-time-keys should be rejected"""
@@ -92,75 +75,66 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
- )
- )
- self.fail("No error when changing string key")
- except errors.SynapseError:
- pass
-
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
- )
- )
- self.fail("No error when replacing dict key with string")
- except errors.SynapseError:
- pass
-
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {"one_time_keys": {"alg1:k1": {"key": "key"}}},
- )
- )
- self.fail("No error when replacing string key with dict")
- except errors.SynapseError:
- pass
-
- try:
- yield defer.ensureDeferred(
- self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {
- "one_time_keys": {
- "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
- }
- },
- )
- )
- self.fail("No error when replacing dict key")
- except errors.SynapseError:
- pass
+ # Error when changing string key
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ ),
+ SynapseError,
+ )
+
+ # Error when replacing dict key with strin
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ ),
+ SynapseError,
+ )
+
+ # Error when replacing string key with dict
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+ ),
+ SynapseError,
+ )
+
+ # Error when replacing dict key
+ self.get_failure(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {
+ "one_time_keys": {
+ "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+ }
+ },
+ ),
+ SynapseError,
+ )
- @defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
- res2 = yield defer.ensureDeferred(
+ res2 = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -173,7 +147,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
@@ -181,12 +154,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
@@ -195,14 +168,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
# we should now have an unused alg1 key
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -213,13 +186,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
# we shouldn't have any unused fallback keys again
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
# claiming an OTK again should return the same fallback key
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -231,22 +204,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": otk}
)
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
@@ -256,7 +230,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
- @defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
@@ -270,9 +243,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys1)
- )
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
keys2 = {
"master_key": {
@@ -284,16 +255,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys2)
- )
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
- devices = yield defer.ensureDeferred(
+ devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- @defer.inlineCallbacks
def test_reupload_signatures(self):
"""re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname
@@ -326,9 +294,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
)
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys1)
- )
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
@@ -358,12 +324,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, "abc", {"device_keys": device_key_1}
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, "def", {"device_keys": device_key_2}
)
@@ -372,7 +338,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# sign the first device key and upload it
del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1}}
)
@@ -383,7 +349,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it
del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
)
@@ -391,7 +357,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
- devices = yield defer.ensureDeferred(
+ devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
@@ -399,7 +365,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
- @defer.inlineCallbacks
def test_self_signing_key_doesnt_show_up_as_device(self):
"""signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname
@@ -413,29 +378,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield defer.ensureDeferred(
- self.handler.upload_signing_keys_for_user(local_user, keys1)
- )
-
- res = None
- try:
- yield defer.ensureDeferred(
- self.hs.get_device_handler().check_device_registered(
- user_id=local_user,
- device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
- initial_device_display_name="new display name",
- )
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 400)
+ self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
- res = yield defer.ensureDeferred(
- self.handler.query_local_devices({local_user: None})
+ e = self.get_failure(
+ self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ ),
+ SynapseError,
)
+ res = e.value.code
+ self.assertEqual(res, 400)
+
+ res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- @defer.inlineCallbacks
def test_upload_signatures(self):
"""should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will
@@ -458,7 +416,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"device_keys": device_key}
)
@@ -501,7 +459,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key,
}
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
)
@@ -515,14 +473,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
}
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_signing_keys_for_user(
other_user, {"master_key": other_master_key}
)
)
# test various signature failures (see below)
- ret = yield defer.ensureDeferred(
+ ret = self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user,
{
@@ -602,20 +560,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
user_failures = ret["failures"][local_user]
+ self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
self.assertEqual(
- user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+ user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
)
- self.assertEqual(
- user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
- )
- self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+ self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
other_user_failures = ret["failures"][other_user]
+ self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
self.assertEqual(
- other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
- )
- self.assertEqual(
- other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+ other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
)
# test successful signatures
@@ -623,7 +577,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
- ret = yield defer.ensureDeferred(
+ ret = self.get_success(
self.handler.upload_signatures_for_device_keys(
local_user,
{
@@ -636,7 +590,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there
- ret = yield defer.ensureDeferred(
+ ret = self.get_success(
self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
)
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 45f201a399..d7498aa51a 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -19,14 +19,9 @@ import copy
import mock
-from twisted.internet import defer
+from synapse.api.errors import SynapseError
-import synapse.api.errors
-import synapse.handlers.e2e_room_keys
-import synapse.storage
-from synapse.api import errors
-
-from tests import unittest, utils
+from tests import unittest
# sample room_key data for use in the tests
room_keys = {
@@ -45,51 +40,38 @@ room_keys = {
}
-class E2eRoomKeysHandlerTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.hs = None # type: synapse.server.HomeServer
- self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(replication_layer=mock.Mock())
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, replication_layer=mock.Mock()
- )
- self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
- self.local_user = "@boris:" + self.hs.hostname
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_e2e_room_keys_handler()
+ self.local_user = "@boris:" + hs.hostname
- @defer.inlineCallbacks
def test_get_missing_current_version_info(self):
"""Check that we get a 404 if we ask for info about the current version
if there is no version.
"""
- res = None
- try:
- yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.get_version_info(self.local_user), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_get_missing_version_info(self):
"""Check that we get a 404 if we ask for info about a specific version
if it doesn't exist.
"""
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.get_version_info(self.local_user, "bogus_version")
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.get_version_info(self.local_user, "bogus_version"),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_create_version(self):
- """Check that we can create and then retrieve versions.
- """
- res = yield defer.ensureDeferred(
+ """Check that we can create and then retrieve versions."""
+ res = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -101,7 +83,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "1")
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
version_etag = res["etag"]
self.assertIsInstance(version_etag, str)
del res["etag"]
@@ -116,9 +98,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as a specific version
- res = yield defer.ensureDeferred(
- self.handler.get_version_info(self.local_user, "1")
- )
+ res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
self.assertEqual(res["etag"], version_etag)
del res["etag"]
self.assertDictEqual(
@@ -132,7 +112,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# upload a new one...
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -144,7 +124,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "2")
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -156,11 +136,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_update_version(self):
- """Check that we can update versions.
- """
- version = yield defer.ensureDeferred(
+ """Check that we can update versions."""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -171,7 +149,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.update_version(
self.local_user,
version,
@@ -185,7 +163,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -197,32 +175,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_update_missing_version(self):
- """Check that we get a 404 on updating nonexistent versions
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.update_version(
- self.local_user,
- "1",
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "1",
- },
- )
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on updating nonexistent versions"""
+ e = self.get_failure(
+ self.handler.update_version(
+ self.local_user,
+ "1",
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1",
+ },
+ ),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_update_omitted_version(self):
- """Check that the update succeeds if the version is missing from the body
- """
- version = yield defer.ensureDeferred(
+ """Check that the update succeeds if the version is missing from the body"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -233,7 +205,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.update_version(
self.local_user,
version,
@@ -245,7 +217,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as the current version
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
@@ -257,11 +229,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
},
)
- @defer.inlineCallbacks
def test_update_bad_version(self):
- """Check that we get a 400 if the version in the body doesn't match
- """
- version = yield defer.ensureDeferred(
+ """Check that we get a 400 if the version in the body doesn't match"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -272,52 +242,38 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "incorrect",
- },
- )
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect",
+ },
+ ),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 400)
- @defer.inlineCallbacks
def test_delete_missing_version(self):
- """Check that we get a 404 on deleting nonexistent versions
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.delete_version(self.local_user, "1")
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on deleting nonexistent versions"""
+ e = self.get_failure(
+ self.handler.delete_version(self.local_user, "1"), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_delete_missing_current_version(self):
- """Check that we get a 404 on deleting nonexistent current version
- """
- res = None
- try:
- yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on deleting nonexistent current version"""
+ e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_delete_version(self):
- """Check that we can create and then delete versions.
- """
- res = yield defer.ensureDeferred(
+ """Check that we can create and then delete versions."""
+ res = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -329,36 +285,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "1")
# check we can delete it
- yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
+ self.get_success(self.handler.delete_version(self.local_user, "1"))
# check that it's gone
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.get_version_info(self.local_user, "1")
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.get_version_info(self.local_user, "1"), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_get_missing_backup(self):
- """Check that we get a 404 on querying missing backup
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, "bogus_version")
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on querying missing backup"""
+ e = self.get_failure(
+ self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_get_missing_room_keys(self):
- """Check we get an empty response from an empty backup
- """
- version = yield defer.ensureDeferred(
+ """Check we get an empty response from an empty backup"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -369,33 +315,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest
- @defer.inlineCallbacks
def test_upload_room_keys_no_versions(self):
- """Check that we get a 404 on uploading keys when no versions are defined
- """
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
- )
- except errors.SynapseError as e:
- res = e.code
+ """Check that we get a 404 on uploading keys when no versions are defined"""
+ e = self.get_failure(
+ self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_upload_room_keys_bogus_version(self):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
- version = yield defer.ensureDeferred(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -406,22 +345,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.upload_room_keys(
- self.local_user, "bogus_version", room_keys
- )
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
+ SynapseError,
+ )
+ res = e.value.code
self.assertEqual(res, 404)
- @defer.inlineCallbacks
def test_upload_room_keys_wrong_version(self):
- """Check that we get a 403 on uploading keys for an old version
- """
- version = yield defer.ensureDeferred(
+ """Check that we get a 403 on uploading keys for an old version"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -432,7 +365,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- version = yield defer.ensureDeferred(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -443,20 +376,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "2")
- res = None
- try:
- yield defer.ensureDeferred(
- self.handler.upload_room_keys(self.local_user, "1", room_keys)
- )
- except errors.SynapseError as e:
- res = e.code
+ e = self.get_failure(
+ self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
+ )
+ res = e.value.code
self.assertEqual(res, 403)
- @defer.inlineCallbacks
def test_upload_room_keys_insert(self):
- """Check that we can insert and retrieve keys for a session
- """
- version = yield defer.ensureDeferred(
+ """Check that we can insert and retrieve keys for a session"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -467,17 +395,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
)
@@ -485,18 +411,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
self.assertDictEqual(res, room_keys)
- @defer.inlineCallbacks
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
- version = yield defer.ensureDeferred(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -507,12 +432,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
# get the etag to compare to future versions
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"]
self.assertEqual(res["count"], 1)
@@ -522,37 +447,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new"
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
# the etag should be the same since the session did not change
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should NOT be equal now, since the key changed
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"]
@@ -560,28 +481,24 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other"
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = yield defer.ensureDeferred(
- self.handler.get_room_keys(self.local_user, version)
- )
+ res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should be the same since the session did not change
- res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+ res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here
- @defer.inlineCallbacks
def test_delete_room_keys(self):
- """Check that we can insert and delete keys for a session
- """
- version = yield defer.ensureDeferred(
+ """Check that we can insert and delete keys for a session"""
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -593,13 +510,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(version, "1")
# check for bulk-delete
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- yield defer.ensureDeferred(
- self.handler.delete_room_keys(self.local_user, version)
- )
- res = yield defer.ensureDeferred(
+ self.get_success(self.handler.delete_room_keys(self.local_user, version))
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
@@ -607,15 +522,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
)
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
@@ -623,15 +538,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 983e368592..3af361195b 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -226,12 +226,20 @@ class FederationTestCase(unittest.HomeserverTestCase):
for i in range(3):
event = create_invite()
self.get_success(
- self.handler.on_invite_request(other_server, event, event.room_version,)
+ self.handler.on_invite_request(
+ other_server,
+ event,
+ event.room_version,
+ )
)
event = create_invite()
self.get_failure(
- self.handler.on_invite_request(other_server, event, event.room_version,),
+ self.handler.on_invite_request(
+ other_server,
+ event,
+ event.room_version,
+ ),
exc=LimitExceededError,
)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f955dfa490..a0d1ebdbe3 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -44,7 +44,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.info = self.get_success(
- self.hs.get_datastore().get_user_by_access_token(self.access_token,)
+ self.hs.get_datastore().get_user_by_access_token(
+ self.access_token,
+ )
)
self.token_id = self.info.token_id
@@ -169,8 +171,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
def test_allow_server_acl(self):
- """Test that sending an ACL that blocks everyone but ourselves works.
- """
+ """Test that sending an ACL that blocks everyone but ourselves works."""
self.helper.send_state(
self.room_id,
@@ -181,8 +182,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
)
def test_deny_server_acl_block_outselves(self):
- """Test that sending an ACL that blocks ourselves does not work.
- """
+ """Test that sending an ACL that blocks ourselves does not work."""
self.helper.send_state(
self.room_id,
EventTypes.ServerACL,
@@ -192,8 +192,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
)
def test_deny_redact_server_acl(self):
- """Test that attempting to redact an ACL is blocked.
- """
+ """Test that attempting to redact an ACL is blocked."""
body = self.helper.send_state(
self.room_id,
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index ad20400b1d..cf1de28fa9 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -24,7 +24,7 @@ from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
-from tests.test_utils import FakeResponse, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
try:
@@ -131,7 +131,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
-
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
@@ -151,7 +150,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit(self, values):
- return patch.dict(self.provider._provider_metadata, values)
+ """Modify the result that will be returned by the well-known query"""
+
+ async def patched_get_json(uri):
+ res = await get_json(uri)
+ if uri == WELL_KNOWN:
+ res.update(values)
+ return res
+
+ return patch.object(self.http_client, "get_json", patched_get_json)
def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
@@ -212,7 +219,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
- with self.metadata_edit({"jwks_uri": None}):
+ original = self.provider.load_metadata
+
+ async def patched_load_metadata():
+ m = (await original()).copy()
+ m.update({"jwks_uri": None})
+ return m
+
+ with patch.object(self.provider, "load_metadata", patched_load_metadata):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
@@ -222,55 +236,60 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
- @override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
h = self.provider
+ def force_load_metadata():
+ async def force_load():
+ return await h.load_metadata(force=True)
+
+ return get_awaitable_result(force_load())
+
# Default test config does not throw
- h._validate_metadata()
+ force_load_metadata()
with self.metadata_edit({"issuer": None}):
- self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"issuer": "http://insecure/"}):
- self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
- self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
with self.metadata_edit({"authorization_endpoint": None}):
self.assertRaisesRegex(
- ValueError, "authorization_endpoint", h._validate_metadata
+ ValueError, "authorization_endpoint", force_load_metadata
)
with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
self.assertRaisesRegex(
- ValueError, "authorization_endpoint", h._validate_metadata
+ ValueError, "authorization_endpoint", force_load_metadata
)
with self.metadata_edit({"token_endpoint": None}):
- self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
- self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
with self.metadata_edit({"jwks_uri": None}):
- self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
- self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+ self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
with self.metadata_edit({"response_types_supported": ["id_token"]}):
self.assertRaisesRegex(
- ValueError, "response_types_supported", h._validate_metadata
+ ValueError, "response_types_supported", force_load_metadata
)
with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
):
# should not throw, as client_secret_basic is the default auth method
- h._validate_metadata()
+ force_load_metadata()
with self.metadata_edit(
{"token_endpoint_auth_methods_supported": ["client_secret_post"]}
@@ -278,7 +297,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRaisesRegex(
ValueError,
"token_endpoint_auth_methods_supported",
- h._validate_metadata,
+ force_load_metadata,
)
# Tests for configs that require the userinfo endpoint
@@ -287,28 +306,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
h._user_profile_method = "userinfo_endpoint"
self.assertTrue(h._uses_userinfo)
- # Revert the profile method and do not request the "openid" scope.
+ # Revert the profile method and do not request the "openid" scope: this should
+ # mean that we check for a userinfo endpoint
h._user_profile_method = "auto"
h._scopes = []
self.assertTrue(h._uses_userinfo)
- self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+ with self.metadata_edit({"userinfo_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)
- with self.metadata_edit(
- {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
- ):
- # Shouldn't raise with a valid userinfo, even without
- h._validate_metadata()
+ with self.metadata_edit({"jwks_uri": None}):
+ # Shouldn't raise with a valid userinfo, even without jwks
+ force_load_metadata()
@override_config({"oidc_config": {"skip_verification": True}})
def test_skip_verification(self):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
- self.provider._validate_metadata()
+ get_awaitable_result(self.provider.load_metadata())
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
- req = Mock(spec=["addCookie"])
+ req = Mock(spec=["cookies"])
+ req.cookies = []
+
url = self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
@@ -327,19 +348,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(len(params["state"]), 1)
self.assertEqual(len(params["nonce"]), 1)
- # Check what is in the cookie
- # note: python3.5 mock does not have the .called_once() method
- calls = req.addCookie.call_args_list
- self.assertEqual(len(calls), 1) # called once
- # For some reason, call.args does not work with python3.5
- args = calls[0][0]
- kwargs = calls[0][1]
+ # Check what is in the cookies
+ self.assertEqual(len(req.cookies), 2) # two cookies
+ cookie_header = req.cookies[0]
# The cookie name and path don't really matter, just that it has to be coherent
# between the callback & redirect handlers.
- self.assertEqual(args[0], b"oidc_session")
- self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
- cookie = args[1]
+ parts = [p.strip() for p in cookie_header.split(b";")]
+ self.assertIn(b"Path=/_synapse/client/oidc", parts)
+ name, cookie = parts[0].split(b"=")
+ self.assertEqual(name, b"oidc_session")
macaroon = pymacaroons.Macaroon.deserialize(cookie)
state = self.handler._token_generator._get_value_from_macaroon(
@@ -470,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(spec=["args", "getCookie", "cookies"])
# Missing cookie
request.args = {}
@@ -493,7 +511,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Mismatching session
session = self._generate_oidc_session_token(
- state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+ state="state",
+ nonce="nonce",
+ client_redirect_url="http://client/redirect",
)
request.args = {}
request.args[b"state"] = [b"mismatching state"]
@@ -548,7 +568,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Internal server error with no JSON body
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
- code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b"Not JSON",
)
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
@@ -568,7 +590,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
- return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad request",
+ body=b"{}",
+ )
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
@@ -576,7 +602,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
# 2xx error with "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
- code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ code=200,
+ phrase=b"OK",
+ body=b'{"error": "some_error"}',
)
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
@@ -613,7 +641,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
state = "state"
client_redirect_url = "http://client/redirect"
session = self._generate_oidc_session_token(
- state=state, nonce="nonce", client_redirect_url=client_redirect_url,
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
)
request = _build_callback_request("code", state, session)
@@ -876,7 +906,9 @@ async def _make_callback_with_userinfo(
session = handler._token_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
- idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+ idp_id="oidc",
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
),
)
request = _build_callback_request("code", state, session)
@@ -910,13 +942,14 @@ def _build_callback_request(
spec=[
"args",
"getCookie",
- "addCookie",
+ "cookies",
"requestHeaders",
"getClientIP",
"getHeader",
]
)
+ request.cookies = []
request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index f816594ee4..a98a65ae67 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -231,8 +231,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_no_local_user_fallback_login(self):
- """localdb_enabled can block login with the local password
- """
+ """localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
@@ -251,8 +250,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_no_local_user_fallback_ui_auth(self):
- """localdb_enabled can block ui auth with the local password
- """
+ """localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
# allow login via the auth provider
@@ -594,7 +592,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
def _delete_device(
- self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+ self,
+ access_token: str,
+ device: str,
+ body: Union[JsonDict, bytes] = b"",
) -> FakeChannel:
"""Delete an individual device."""
channel = self.make_request(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0794b32c9c..be2ee26f07 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -589,8 +589,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
def _add_new_user(self, room_id, user_id):
- """Add new user to the room by creating an event and poking the federation API.
- """
+ """Add new user to the room by creating an event and poking the federation API."""
hostname = get_domain_from_id(user_id)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 75275f0e4f..909984b3be 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -13,25 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from mock import Mock
-from twisted.internet import defer
-
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID
from tests import unittest
from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
-class ProfileTestCase(unittest.TestCase):
+class ProfileTestCase(unittest.HomeserverTestCase):
""" Tests profile management. """
- @defer.inlineCallbacks
- def setUp(self):
+ def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
self.mock_registry = Mock()
@@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase):
self.mock_registry.register_query_handler = register_query_handler
- hs = yield setup_test_homeserver(
- self.addCleanup,
+ hs = self.setup_test_homeserver(
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
)
+ return hs
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test")
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
+ self.get_success(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
- self.hs = hs
- @defer.inlineCallbacks
def test_get_my_name(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
- displayname = yield defer.ensureDeferred(
- self.handler.get_displayname(self.frank)
- )
+ displayname = self.get_success(self.handler.get_displayname(self.frank))
self.assertEquals("Frank", displayname)
- @defer.inlineCallbacks
def test_set_my_name(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
@@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
)
),
@@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase):
)
# Set displayname again
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
@@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
)
),
@@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase):
)
# Set displayname to an empty string
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), ""
)
)
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_displayname(self.frank.localpart)
- )
- )
+ (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
)
- @defer.inlineCallbacks
def test_set_my_name_if_disabled(self):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
self.assertEquals(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
)
),
@@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase):
)
# Setting displayname a second time is forbidden
- d = defer.ensureDeferred(
+ self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
- )
+ ),
+ SynapseError,
)
- yield self.assertFailure(d, SynapseError)
-
- @defer.inlineCallbacks
def test_set_my_name_noauth(self):
- d = defer.ensureDeferred(
+ self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
- )
+ ),
+ AuthError,
)
- yield self.assertFailure(d, AuthError)
-
- @defer.inlineCallbacks
def test_get_other_name(self):
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
- displayname = yield defer.ensureDeferred(
- self.handler.get_displayname(self.alice)
- )
+ displayname = self.get_success(self.handler.get_displayname(self.alice))
self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with(
@@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase):
ignore_backoff=True,
)
- @defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield defer.ensureDeferred(self.store.create_profile("caroline"))
- yield defer.ensureDeferred(
- self.store.set_profile_displayname("caroline", "Caroline", 1)
- )
+ self.get_success(self.store.create_profile("caroline"))
+ self.get_success(self.store.set_profile_displayname("caroline", "Caroline", 1))
- response = yield defer.ensureDeferred(
+ response = self.get_success(
self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
)
@@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals({"displayname": "Caroline"}, response)
- @defer.inlineCallbacks
def test_get_my_avatar(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png", 1
)
)
- avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
+ avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url)
- @defer.inlineCallbacks
def test_set_my_avatar(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
@@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_avatar_url(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
"http://my.server/pic.gif",
)
# Set avatar again
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
@@ -230,56 +201,44 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_avatar_url(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
"http://my.server/me.png",
)
# Set avatar to an empty string
- yield defer.ensureDeferred(
+ self.get_success(
self.handler.set_avatar_url(
- self.frank, synapse.types.create_requester(self.frank), "",
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "",
)
)
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_avatar_url(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
)
- @defer.inlineCallbacks
def test_set_my_avatar_if_disabled(self):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png", 1
)
)
self.assertEquals(
- (
- yield defer.ensureDeferred(
- self.store.get_profile_avatar_url(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
"http://my.server/me.png",
)
# Set avatar a second time is forbidden
- d = defer.ensureDeferred(
+ self.get_failure(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
- )
+ ),
+ SynapseError,
)
-
- yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index a8d6c0f617..029af2853e 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+ @override_config(
+ {
+ "saml2_config": {
+ "attribute_requirements": [
+ {"attribute": "userGroup", "value": "staff"},
+ {"attribute": "department", "value": "sales"},
+ ],
+ },
+ }
+ )
+ def test_attribute_requirements(self):
+ """The required attributes must be met from the SAML response."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # The response doesn't have the proper userGroup or department.
+ saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # The response doesn't have the proper department.
+ saml_response = FakeAuthnResponse(
+ {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
+ )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # Add the proper attributes and it should succeed.
+ saml_response = FakeAuthnResponse(
+ {
+ "uid": "test_user",
+ "username": "test_user",
+ "userGroup": ["staff", "admin"],
+ "department": ["sales"],
+ }
+ )
+ request.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=True
+ )
+
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
- return Mock(spec=["getClientIP", "getHeader"])
+ return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 96e5bdac4a..24e7138196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -143,14 +143,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
- ([], 0)
+ self.datastore.get_new_device_msgs_for_remote = (
+ lambda *args, **kargs: make_awaitable(([], 0))
)
- self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
- None
+ self.datastore.delete_device_msgs_for_remote = (
+ lambda *args, **kargs: make_awaitable(None)
)
- self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
- None
+ self.datastore.set_received_txn_response = (
+ lambda *args, **kwargs: make_awaitable(None)
)
def test_started_typing_local(self):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 2afd1970e6..ddfe6950e1 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -203,7 +203,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Check that the room has an encryption state event
event_content = self.helper.get_state(
- room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
)
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@@ -212,7 +214,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Check that the room has an encryption state event
event_content = self.helper.get_state(
- room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
)
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@@ -230,7 +234,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Check that the room has an encryption state event
event_content = self.helper.get_state(
- room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
)
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index b758b29b2a..3972abb038 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -518,8 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.successResultOf(test_d)
def test_get_well_known(self):
- """Test the behaviour when the .well-known delegates elsewhere
- """
+ """Test the behaviour when the .well-known delegates elsewhere"""
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -1135,8 +1134,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertIsNone(r.delegated_server)
def test_srv_fallbacks(self):
- """Test that other SRV results are tried if the first one fails.
- """
+ """Test that other SRV results are tried if the first one fails."""
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[
Server(host=b"target.com", port=8443),
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 27206ca3db..edacd1b566 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -100,7 +100,10 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
- expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
+ expected_requester,
+ event_dict,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
# Create and send a state event
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index d5dce1f83f..f6a6aed35e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -79,7 +79,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
- self.worker_hs, "client", "test", clock, repl_handler,
+ self.worker_hs,
+ "client",
+ "test",
+ clock,
+ repl_handler,
)
self._client_transport = None
@@ -228,7 +232,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- "localhost", 6379, self.connect_any_redis_attempts,
+ "localhost",
+ 6379,
+ self.connect_any_redis_attempts,
)
self.hs.get_tcp_replication().start_replication(self.hs)
@@ -246,8 +252,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
def create_test_resource(self):
- """Overrides `HomeserverTestCase.create_test_resource`.
- """
+ """Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
# subclassses.
@@ -296,7 +301,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if instance_loc.host not in self.reactor.lookups:
raise Exception(
"Host does not have an IP for instance_map[%r].host = %r"
- % (instance_name, instance_loc.host,)
+ % (
+ instance_name,
+ instance_loc.host,
+ )
)
self.reactor.add_tcp_client_callback(
@@ -315,7 +323,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if not worker_hs.config.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
+ worker_hs,
+ "client",
+ "test",
+ self.clock,
+ repl_handler,
)
server = self.server_factory.buildProtocol(None)
@@ -485,8 +497,7 @@ class _PushHTTPChannel(HTTPChannel):
self._pull_to_push_producer.stop()
def checkPersistence(self, request, version):
- """Check whether the connection can be re-used
- """
+ """Check whether the connection can be re-used"""
# We hijack this to always say no for ease of wiring stuff up in
# `handle_http_replication_attempt`.
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
@@ -494,8 +505,7 @@ class _PushHTTPChannel(HTTPChannel):
class _PullToPushProducer:
- """A push producer that wraps a pull producer.
- """
+ """A push producer that wraps a pull producer."""
def __init__(
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
@@ -512,39 +522,33 @@ class _PullToPushProducer:
self._start_loop()
def _start_loop(self):
- """Start the looping call to
- """
+ """Start the looping call to"""
if not self._looping_call:
# Start a looping call which runs every tick.
self._looping_call = self._clock.looping_call(self._run_once, 0)
def stop(self):
- """Stops calling resumeProducing.
- """
+ """Stops calling resumeProducing."""
if self._looping_call:
self._looping_call.stop()
self._looping_call = None
def pauseProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self.stop()
def resumeProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self._start_loop()
def stopProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self.stop()
self._producer.stopProducing()
def _run_once(self):
- """Calls resumeProducing on producer once.
- """
+ """Calls resumeProducing on producer once."""
try:
self._producer.resumeProducing()
@@ -559,25 +563,21 @@ class _PullToPushProducer:
class FakeRedisPubSubServer:
- """A fake Redis server for pub/sub.
- """
+ """A fake Redis server for pub/sub."""
def __init__(self):
self._subscribers = set()
def add_subscriber(self, conn):
- """A connection has called SUBSCRIBE
- """
+ """A connection has called SUBSCRIBE"""
self._subscribers.add(conn)
def remove_subscriber(self, conn):
- """A connection has called UNSUBSCRIBE
- """
+ """A connection has called UNSUBSCRIBE"""
self._subscribers.discard(conn)
def publish(self, conn, channel, msg) -> int:
- """A connection want to publish a message to subscribers.
- """
+ """A connection want to publish a message to subscribers."""
for sub in self._subscribers:
sub.send(["message", channel, msg])
@@ -588,8 +588,7 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
- """A connection from a client talking to the fake Redis server.
- """
+ """A connection from a client talking to the fake Redis server."""
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
@@ -613,8 +612,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args):
- """Received a Redis command from the client.
- """
+ """Received a Redis command from the client."""
# We currently only support pub/sub.
if command == b"PUBLISH":
@@ -635,8 +633,7 @@ class FakeRedisPubSubProtocol(Protocol):
raise Exception("Unknown command")
def send(self, msg):
- """Send a message back to the client.
- """
+ """Send a message back to the client."""
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index c0ee1cfbd6..0ceb0f935c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -66,7 +66,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success(
self.master_store.store_room(
- ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
+ ROOM_ID,
+ USER_ID,
+ is_public=False,
+ room_version=RoomVersions.V1,
)
)
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 6a5116dd2a..153634d4ee 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -23,8 +23,7 @@ from tests.replication._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase):
def test_update_function_room_account_data_limit(self):
- """Test replication with many room account data updates
- """
+ """Test replication with many room account data updates"""
store = self.hs.get_datastore()
# generate lots of account data updates
@@ -70,8 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
def test_update_function_global_account_data_limit(self):
- """Test replication with many global account data updates
- """
+ """Test replication with many global account data updates"""
store = self.hs.get_datastore()
# generate lots of account data updates
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index bad0df08cf..77856fc304 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -129,7 +129,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
pls["users"][OTHER_USER] = 50
self.helper.send_state(
- self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ self.room_id,
+ EventTypes.PowerLevels,
+ pls,
+ tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
@@ -255,8 +258,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self):
- """Test replication with many state events over several stream ids.
- """
+ """Test replication with many state events over several stream ids."""
# we want to generate lots of state changes, but for this test, we want to
# spread out the state changes over a few stream IDs.
@@ -282,7 +284,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
pls["users"].update({u: 50 for u in user_ids})
self.helper.send_state(
- self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ self.room_id,
+ EventTypes.PowerLevels,
+ pls,
+ tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index d1c15caeb0..1fe9d5b4d0 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -28,8 +28,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
- """Create a new direct TCP replication connection
- """
+ """Create a new direct TCP replication connection"""
proto = self.factory.buildProtocol(("127.0.0.1", 0))
transport = StringTransport()
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index f35a5235e1..f8fd8a843c 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -79,8 +79,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
)
def test_no_auth(self):
- """With no authentication the request should finish.
- """
+ """With no authentication the request should finish."""
channel = self._test_register()
self.assertEqual(channel.code, 200)
@@ -89,8 +88,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
@override_config({"main_replication_secret": "my-secret"})
def test_missing_auth(self):
- """If the main process expects a secret that is not provided, an error results.
- """
+ """If the main process expects a secret that is not provided, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@@ -101,15 +99,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
}
)
def test_unauthorized(self):
- """If the main process receives the wrong secret, an error results.
- """
+ """If the main process receives the wrong secret, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@override_config({"worker_replication_secret": "my-secret"})
def test_authorized(self):
- """The request should finish when the worker provides the authentication header.
- """
+ """The request should finish when the worker provides the authentication header."""
channel = self._test_register()
self.assertEqual(channel.code, 200)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 4608b65a0c..5da1d5dc4d 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -35,8 +35,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
return config
def test_register_single_worker(self):
- """Test that registration works when using a single client reader worker.
- """
+ """Test that registration works when using a single client reader worker."""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
@@ -66,8 +65,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
def test_register_multi_worker(self):
- """Test that registration works when using multiple client reader workers.
- """
+ """Test that registration works when using multiple client reader workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index d1feca961f..7ff11cde10 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -36,8 +36,7 @@ test_server_connection_factory = None
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
- """Checks running multiple media repos work correctly.
- """
+ """Checks running multiple media repos work correctly."""
servlets = [
admin.register_servlets_for_client_rest_resource,
@@ -124,8 +123,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request
def test_basic(self):
- """Test basic fetching of remote media from a single worker.
- """
+ """Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
@@ -223,16 +221,14 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
def _count_remote_media(self) -> int:
- """Count the number of files in our remote media directory.
- """
+ """Count the number of files in our remote media directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_content"
)
return sum(len(files) for _, _, files in os.walk(path))
def _count_remote_thumbnails(self) -> int:
- """Count the number of files in our remote thumbnails directory.
- """
+ """Count the number of files in our remote thumbnails directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 800ad94a04..f118fe32af 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -27,8 +27,7 @@ logger = logging.getLogger(__name__)
class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
- """Checks pusher sharding works
- """
+ """Checks pusher sharding works"""
servlets = [
admin.register_servlets_for_client_rest_resource,
@@ -88,11 +87,10 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
return event_id
def test_send_push_single_worker(self):
- """Test that registration works when using a pusher worker.
- """
+ """Test that registration works when using a pusher worker."""
http_client_mock = Mock(spec_set=["post_json_get_json"])
- http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
- {}
+ http_client_mock.post_json_get_json.side_effect = (
+ lambda *_, **__: defer.succeed({})
)
self.make_worker_hs(
@@ -119,11 +117,10 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
def test_send_push_multiple_workers(self):
- """Test that registration works when using sharded pusher workers.
- """
+ """Test that registration works when using sharded pusher workers."""
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
- http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
- {}
+ http_client_mock1.post_json_get_json.side_effect = (
+ lambda *_, **__: defer.succeed({})
)
self.make_worker_hs(
@@ -137,8 +134,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
- http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
- {}
+ http_client_mock2.post_json_get_json.side_effect = (
+ lambda *_, **__: defer.succeed({})
)
self.make_worker_hs(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 8d494ebc03..c9b773fbd2 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -29,8 +29,7 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
- """Checks event persisting sharding works
- """
+ """Checks event persisting sharding works"""
# Event persister sharding requires postgres (due to needing
# `MutliWriterIdGenerator`).
@@ -63,8 +62,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
return conf
def _create_room(self, room_id: str, user_id: str, tok: str):
- """Create a room with given room_id
- """
+ """Create a room with given room_id"""
# We control the room ID generation by patching out the
# `_generate_room_id` method
@@ -91,11 +89,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""
self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker1"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker1"},
)
self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker2"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker2"},
)
persisted_on_1 = False
@@ -139,15 +139,18 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""
self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker1"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker1"},
)
worker_hs2 = self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "worker2"},
+ "synapse.app.generic_worker",
+ {"worker_name": "worker2"},
)
sync_hs = self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "sync"},
+ "synapse.app.generic_worker",
+ {"worker_name": "sync"},
)
sync_hs_site = self._hs_to_site[sync_hs]
@@ -323,7 +326,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
- room_id2, vector_clock_token, prev_batch2,
+ room_id2,
+ vector_clock_token,
+ prev_batch2,
),
access_token=access_token,
)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 9d22c04073..057e27372e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -130,8 +130,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
)
def _get_groups_user_is_in(self, access_token):
- """Returns the list of groups the user is in (given their access token)
- """
+ """Returns the list of groups the user is in (given their access token)"""
channel = self.make_request(
"GET", "/joined_groups".encode("ascii"), access_token=access_token
)
@@ -142,8 +141,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
- """Test /quarantine_media admin API.
- """
+ """Test /quarantine_media admin API."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -237,7 +235,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt quarantine media APIs as non-admin
url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
channel = self.make_request(
- "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ "POST",
+ url.encode("ascii"),
+ access_token=non_admin_user_tok,
)
# Expect a forbidden error
@@ -250,7 +250,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# And the roomID/userID endpoint
url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
channel = self.make_request(
- "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ "POST",
+ url.encode("ascii"),
+ access_token=non_admin_user_tok,
)
# Expect a forbidden error
@@ -294,7 +296,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
urllib.parse.quote(server_name),
urllib.parse.quote(media_id),
)
- channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=admin_user_tok,
+ )
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -346,7 +352,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
room_id
)
- channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=admin_user_tok,
+ )
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
self.assertEqual(
@@ -391,7 +401,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user
)
channel = self.make_request(
- "POST", url.encode("ascii"), access_token=admin_user_tok,
+ "POST",
+ url.encode("ascii"),
+ access_token=admin_user_tok,
)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -437,7 +449,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user
)
channel = self.make_request(
- "POST", url.encode("ascii"), access_token=admin_user_tok,
+ "POST",
+ url.encode("ascii"),
+ access_token=admin_user_tok,
)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 248c4442c3..2a1bcf1760 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -70,21 +70,27 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error is returned.
"""
channel = self.make_request(
- "GET", self.url, access_token=self.other_user_token,
+ "GET",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
channel = self.make_request(
- "PUT", self.url, access_token=self.other_user_token,
+ "PUT",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
channel = self.make_request(
- "DELETE", self.url, access_token=self.other_user_token,
+ "DELETE",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -99,17 +105,29 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -123,17 +141,29 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -146,16 +176,28 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
- channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
# Delete unknown device returns status 200
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -190,7 +232,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -207,12 +253,20 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
)
- channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -233,7 +287,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
@@ -242,7 +300,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
Tests that a normal lookup for a device is successfully
"""
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -264,7 +326,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
# Delete device
channel = self.make_request(
- "DELETE", self.url, access_token=self.admin_user_tok,
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -306,7 +370,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("GET", self.url, access_token=other_user_token,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -316,7 +384,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -327,7 +399,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -339,7 +415,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
# Get devices
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -355,7 +435,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.login("user", "pass")
# Get devices
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
@@ -404,7 +488,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("POST", self.url, access_token=other_user_token,)
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=other_user_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -414,7 +502,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
- channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -425,7 +517,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
- channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index d0090faa4f..e30ffe4fa0 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -51,19 +51,23 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# Two rooms and two users. Every user sends and reports every room event
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id1, user_tok=self.other_user_tok,
+ room_id=self.room_id1,
+ user_tok=self.other_user_tok,
)
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id2, user_tok=self.other_user_tok,
+ room_id=self.room_id2,
+ user_tok=self.other_user_tok,
)
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id1, user_tok=self.admin_user_tok,
+ room_id=self.room_id1,
+ user_tok=self.admin_user_tok,
)
for i in range(5):
self._create_event_and_report(
- room_id=self.room_id2, user_tok=self.admin_user_tok,
+ room_id=self.room_id2,
+ user_tok=self.admin_user_tok,
)
self.url = "/_synapse/admin/v1/event_reports"
@@ -82,7 +86,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.other_user_tok,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -92,7 +100,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events
"""
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -106,7 +118,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -121,7 +135,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -136,7 +152,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -213,7 +231,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# fetch the most recent first, largest timestamp
channel = self.make_request(
- "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?dir=b",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -229,7 +249,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# fetch the oldest first, smallest timestamp
channel = self.make_request(
- "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?dir=f",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -249,7 +271,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -262,7 +286,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -274,7 +300,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -288,7 +316,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -299,7 +329,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -310,7 +342,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -322,7 +356,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
channel = self.make_request(
- "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -331,8 +367,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
def _create_event_and_report(self, room_id, user_tok):
- """Create and report events
- """
+ """Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
@@ -345,8 +380,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
- """Checks that all attributes are present in an event report
- """
+ """Checks that all attributes are present in an event report"""
for c in content:
self.assertIn("id", c)
self.assertIn("received_ts", c)
@@ -381,7 +415,8 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
self._create_event_and_report(
- room_id=self.room_id1, user_tok=self.other_user_tok,
+ room_id=self.room_id1,
+ user_tok=self.other_user_tok,
)
# first created event report gets `id`=2
@@ -401,7 +436,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.other_user_tok,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -411,7 +450,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
Testing get a reported event
"""
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._check_fields(channel.json_body)
@@ -479,8 +522,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertEqual("Event report not found", channel.json_body["error"])
def _create_event_and_report(self, room_id, user_tok):
- """Create and report events
- """
+ """Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
@@ -493,8 +535,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
- """Checks that all attributes are present in a event report
- """
+ """Checks that all attributes are present in a event report"""
self.assertIn("id", content)
self.assertIn("received_ts", content)
self.assertIn("room_id", content)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 51a7731693..31db472cd3 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -63,7 +63,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- channel = self.make_request("DELETE", url, access_token=self.other_user_token,)
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.other_user_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -74,7 +78,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
- channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -85,7 +93,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
- channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
@@ -139,12 +151,17 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
# Delete media
- channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- media_id, channel.json_body["deleted_media"][0],
+ media_id,
+ channel.json_body["deleted_media"][0],
)
# Attempt to access media
@@ -207,7 +224,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.other_user_token = self.login("user", "pass")
channel = self.make_request(
- "POST", self.url, access_token=self.other_user_token,
+ "POST",
+ self.url,
+ access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -220,7 +239,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
channel = self.make_request(
- "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
+ "POST",
+ url + "?before_ts=1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -230,7 +251,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
If the parameter `before_ts` is missing, an error is returned.
"""
- channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -243,7 +268,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
If parameters are invalid, an error is returned.
"""
channel = self.make_request(
- "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
+ "POST",
+ self.url + "?before_ts=-1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -304,7 +331,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- media_id, channel.json_body["deleted_media"][0],
+ media_id,
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -340,7 +368,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -374,7 +403,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -417,7 +447,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
@@ -461,7 +492,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
- server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ server_and_media_id.split("/")[1],
+ channel.json_body["deleted_media"][0],
)
self._access_media(server_and_media_id, False)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 2a217b1ce0..b55160b70a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -127,8 +127,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self._assert_peek(room_id, expect_code=403)
def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
+ """Assert that the admin user can (or cannot) peek into the room."""
url = "rooms/%s/initialSync" % (room_id,)
channel = self.make_request(
@@ -186,7 +185,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+ "POST",
+ self.url,
+ json.dumps({}),
+ access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -199,7 +201,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ "POST",
+ url,
+ json.dumps({}),
+ access_token=self.admin_user_tok,
)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
@@ -212,12 +217,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms/invalidroom/delete"
channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ "POST",
+ url,
+ json.dumps({}),
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- "invalidroom is not a legal room ID", channel.json_body["error"],
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
)
def test_new_room_user_does_not_exist(self):
@@ -254,7 +263,8 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- "User must be our own: @not:exist.bla", channel.json_body["error"],
+ "User must be our own: @not:exist.bla",
+ channel.json_body["error"],
)
def test_block_is_not_bool(self):
@@ -491,8 +501,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id, expect=True):
- """Assert that the room is blocked or not
- """
+ """Assert that the room is blocked or not"""
d = self.store.is_room_blocked(room_id)
if expect:
self.assertTrue(self.get_success(d))
@@ -500,20 +509,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertIsNone(self.get_success(d))
def _has_no_members(self, room_id):
- """Assert there is now no longer anyone in the room
- """
+ """Assert there is now no longer anyone in the room"""
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room)
def _is_member(self, room_id, user_id):
- """Test that user is member of the room
- """
+ """Test that user is member of the room"""
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertIn(user_id, users_in_room)
def _is_purged(self, room_id):
- """Test that the following tables have been purged of all rows related to the room.
- """
+ """Test that the following tables have been purged of all rows related to the room."""
for table in PURGE_TABLES:
count = self.get_success(
self.store.db_pool.simple_select_one_onecol(
@@ -527,8 +533,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
+ """Assert that the admin user can (or cannot) peek into the room."""
url = "rooms/%s/initialSync" % (room_id,)
channel = self.make_request(
@@ -548,8 +553,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
+ """Test /purge_room admin API."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -594,8 +598,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
+ """Test /room admin API."""
servlets = [
synapse.rest.admin.register_servlets,
@@ -623,7 +626,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
# Check request completed successfully
@@ -685,7 +690,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set the name of the rooms so we get a consistent returned ordering
for idx, room_id in enumerate(room_ids):
self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ room_id,
+ "m.room.name",
+ {"name": str(idx)},
+ tok=self.admin_user_tok,
)
# Request the list of rooms
@@ -704,7 +712,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
"name",
)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"]
@@ -744,7 +754,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -788,13 +800,18 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set a name for the room
self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ room_id,
+ "m.room.name",
+ {"name": test_room_name},
+ tok=self.admin_user_tok,
)
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -860,7 +877,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
+ order_type: str,
+ expected_room_list: List[str],
+ reverse: bool = False,
):
"""Request the list of rooms in a certain order. Assert that order is what
we expect
@@ -875,7 +894,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
if reverse:
url += "&dir=b"
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -907,13 +928,22 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ room_id_1,
+ "m.room.name",
+ {"name": "A"},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ room_id_2,
+ "m.room.name",
+ {"name": "B"},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ room_id_3,
+ "m.room.name",
+ {"name": "C"},
+ tok=self.admin_user_tok,
)
# Set room canonical room aliases
@@ -990,10 +1020,16 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set the name for each room
self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ room_id_1,
+ "m.room.name",
+ {"name": room_name_1},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ room_id_2,
+ "m.room.name",
+ {"name": room_name_2},
+ tok=self.admin_user_tok,
)
def _search_test(
@@ -1011,7 +1047,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
@@ -1071,15 +1109,23 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Set the name for each room
self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ room_id_1,
+ "m.room.name",
+ {"name": room_name_1},
+ tok=self.admin_user_tok,
)
self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ room_id_2,
+ "m.room.name",
+ {"name": room_name_2},
+ tok=self.admin_user_tok,
)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1109,7 +1155,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"])
@@ -1121,7 +1169,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"])
@@ -1131,7 +1181,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.helper.leave(room_id_1, user_1, tok=user_tok_1)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"])
@@ -1160,7 +1212,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1171,7 +1225,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1187,7 +1243,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
@@ -1342,7 +1400,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
@@ -1389,7 +1449,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if server admin is a member of the room
channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.admin_user_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1411,7 +1473,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1440,7 +1504,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if user is a member of the room
channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ "GET",
+ "/_matrix/client/r0/joined_rooms",
+ access_token=self.second_tok,
)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1555,8 +1621,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
)
def test_public_room(self):
- """Test that getting admin in a public room works.
- """
+ """Test that getting admin in a public room works."""
room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
)
@@ -1581,10 +1646,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
)
def test_private_room(self):
- """Test that getting admin in a private room works and we get invited.
- """
+ """Test that getting admin in a private room works and we get invited."""
room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False,
+ self.creator,
+ tok=self.creator_tok,
+ is_public=False,
)
channel = self.make_request(
@@ -1608,8 +1674,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
)
def test_other_user(self):
- """Test that giving admin in a public room works to a non-admin user works.
- """
+ """Test that giving admin in a public room works to a non-admin user works."""
room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
)
@@ -1634,8 +1699,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
)
def test_not_enough_power(self):
- """Test that we get a sensible error if there are no local room admins.
- """
+ """Test that we get a sensible error if there are no local room admins."""
room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
)
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index f48be3d65a..1f1d11f527 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -55,7 +55,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
- "GET", self.url, json.dumps({}), access_token=self.other_user_tok,
+ "GET",
+ self.url,
+ json.dumps({}),
+ access_token=self.other_user_tok,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -67,7 +70,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
# unkown order_by
channel = self.make_request(
- "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?order_by=bar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -75,7 +80,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# negative from
channel = self.make_request(
- "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -83,7 +90,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# negative limit
channel = self.make_request(
- "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -91,7 +100,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# negative from_ts
channel = self.make_request(
- "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from_ts=-1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -99,7 +110,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# negative until_ts
channel = self.make_request(
- "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?until_ts=-1234",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -117,7 +130,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# empty search term
channel = self.make_request(
- "GET", self.url + "?search_term=", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?search_term=",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -125,7 +140,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# invalid search order
channel = self.make_request(
- "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -138,7 +155,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_users_with_media(10, 2)
channel = self.make_request(
- "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -154,7 +173,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_users_with_media(20, 2)
channel = self.make_request(
- "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -170,7 +191,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_users_with_media(20, 2)
channel = self.make_request(
- "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -190,7 +213,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -201,7 +226,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -212,7 +239,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -223,7 +252,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# Set `from` to value of `next_token` for request remaining entries
# Check `next_token` does not appear
channel = self.make_request(
- "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -237,7 +268,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
if users have no media created
"""
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -264,10 +299,14 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# order by user_id
self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"])
self._order_test(
- "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f",
+ "user_id",
+ ["@user_a:test", "@user_b:test", "@user_c:test"],
+ "f",
)
self._order_test(
- "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b",
+ "user_id",
+ ["@user_c:test", "@user_b:test", "@user_a:test"],
+ "b",
)
# order by displayname
@@ -275,32 +314,46 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"displayname", ["@user_c:test", "@user_b:test", "@user_a:test"]
)
self._order_test(
- "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f",
+ "displayname",
+ ["@user_c:test", "@user_b:test", "@user_a:test"],
+ "f",
)
self._order_test(
- "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b",
+ "displayname",
+ ["@user_a:test", "@user_b:test", "@user_c:test"],
+ "b",
)
# order by media_length
self._order_test(
- "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "media_length",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
)
self._order_test(
- "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+ "media_length",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "f",
)
self._order_test(
- "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+ "media_length",
+ ["@user_b:test", "@user_c:test", "@user_a:test"],
+ "b",
)
# order by media_count
self._order_test(
- "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "media_count",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
)
self._order_test(
- "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+ "media_count",
+ ["@user_a:test", "@user_c:test", "@user_b:test"],
+ "f",
)
self._order_test(
- "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+ "media_count",
+ ["@user_b:test", "@user_c:test", "@user_a:test"],
+ "b",
)
def test_from_until_ts(self):
@@ -313,14 +366,20 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
ts1 = self.clock.time_msec()
# list all media when filter is not set
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
# result is 0
channel = self.make_request(
- "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from_ts=%s" % (ts1,),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
@@ -342,7 +401,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# filter media until `ts2` and earlier
channel = self.make_request(
- "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?until_ts=%s" % (ts2,),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
@@ -351,7 +412,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self._create_users_with_media(20, 1)
# check without filter get all users
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -376,7 +441,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
# filter and get empty result
channel = self.make_request(
- "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?search_term=foobar",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
@@ -441,7 +508,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
if dir is not None and dir in ("b", "f"):
url += "&dir=%s" % (dir,)
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 59e58a38f7..7eb6f6317a 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -528,9 +528,14 @@ class UsersListTestCase(unittest.HomeserverTestCase):
search_field: Field which is to request: `name` or `user_id`
expected_http_code: The expected http code for the request
"""
- url = self.url + "?%s=%s" % (search_field, search_term,)
+ url = self.url + "?%s=%s" % (
+ search_field,
+ search_term,
+ )
channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
@@ -590,7 +595,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# negative limit
channel = self.make_request(
- "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -598,7 +605,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# negative from
channel = self.make_request(
- "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -606,7 +615,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# invalid guests
channel = self.make_request(
- "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?guests=not_bool",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -614,7 +625,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# invalid deactivated
channel = self.make_request(
- "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?deactivated=not_bool",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -630,7 +643,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._create_users(number_users - 1)
channel = self.make_request(
- "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -649,7 +664,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._create_users(number_users - 1)
channel = self.make_request(
- "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -668,7 +685,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._create_users(number_users - 1)
channel = self.make_request(
- "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -689,7 +708,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -700,7 +721,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -711,7 +734,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -723,7 +748,9 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
channel = self.make_request(
- "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -753,7 +780,10 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
for i in range(1, number_users + 1):
self.register_user(
- "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+ "user%d" % i,
+ "pass%d" % i,
+ admin=False,
+ displayname="Name %d" % i,
)
@@ -808,7 +838,10 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
- "POST", url, access_token=self.other_user_token, content=b"{}",
+ "POST",
+ url,
+ access_token=self.other_user_token,
+ content=b"{}",
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -862,7 +895,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -886,7 +921,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -905,7 +942,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -929,7 +968,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -947,8 +988,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", False)
def _is_erased(self, user_id: str, expect: bool) -> None:
- """Assert that the user is erased or not
- """
+ """Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
if expect:
self.assertTrue(self.get_success(d))
@@ -982,13 +1022,20 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@bob:test"
- channel = self.make_request("GET", url, access_token=self.other_user_token,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.other_user_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
- "PUT", url, access_token=self.other_user_token, content=b"{}",
+ "PUT",
+ url,
+ access_token=self.other_user_token,
+ content=b"{}",
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -1041,7 +1088,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1086,7 +1137,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1311,7 +1366,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1342,7 +1399,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1365,7 +1424,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1405,7 +1466,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1503,7 +1566,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1532,7 +1597,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_other_user, access_token=self.admin_user_tok,
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1561,7 +1628,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("bob", channel.json_body["displayname"])
# Get user
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1581,7 +1652,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Check user is not deactivated
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1591,8 +1666,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["deactivated"])
def _is_erased(self, user_id, expect):
- """Assert that the user is erased or not
- """
+ """Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
if expect:
self.assertTrue(self.get_success(d))
@@ -1632,7 +1706,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("GET", self.url, access_token=other_user_token,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1642,7 +1720,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns an empty list
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1654,7 +1736,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1666,7 +1752,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
if user has no memberships
"""
# Get rooms
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1683,7 +1773,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.helper.create_room_as(self.other_user, tok=other_user_tok)
# Get rooms
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
@@ -1726,7 +1820,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
@@ -1766,7 +1864,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("GET", self.url, access_token=other_user_token,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1776,7 +1878,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1787,7 +1893,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1798,7 +1908,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
# Get pushers
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1825,7 +1939,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
)
# Get pushers
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
@@ -1874,7 +1992,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("GET", self.url, access_token=other_user_token,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1884,7 +2006,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1895,7 +2021,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
- channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1910,7 +2040,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self._create_media(other_user_tok, number_media)
channel = self.make_request(
- "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1929,7 +2061,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self._create_media(other_user_tok, number_media)
channel = self.make_request(
- "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1948,7 +2082,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self._create_media(other_user_tok, number_media)
channel = self.make_request(
- "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1963,7 +2099,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1975,7 +2113,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1993,7 +2133,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=20",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -2004,7 +2146,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=21",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -2015,7 +2159,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
- "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?limit=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -2027,7 +2173,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
channel = self.make_request(
- "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ "GET",
+ self.url + "?from=19",
+ access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -2041,7 +2189,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
if user has no media created
"""
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -2056,7 +2208,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media)
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
@@ -2083,8 +2239,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
)
def _check_fields(self, content):
- """Checks that all attributes are present in content
- """
+ """Checks that all attributes are present in content"""
for m in content:
self.assertIn("media_id", m)
self.assertIn("media_type", m)
@@ -2097,8 +2252,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
class UserTokenRestTestCase(unittest.HomeserverTestCase):
- """Test for /_synapse/admin/v1/users/<user>/login
- """
+ """Test for /_synapse/admin/v1/users/<user>/login"""
servlets = [
synapse.rest.admin.register_servlets,
@@ -2129,16 +2283,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
return channel.json_body["access_token"]
def test_no_auth(self):
- """Try to login as a user without authentication.
- """
+ """Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
- """Try to login as a user as a non-admin user.
- """
+ """Try to login as a user as a non-admin user."""
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.other_user_tok
)
@@ -2146,8 +2298,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
def test_send_event(self):
- """Test that sending event as a user works.
- """
+ """Test that sending event as a user works."""
# Create a room.
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
@@ -2161,8 +2312,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(event.sender, self.other_user)
def test_devices(self):
- """Tests that logging in as a user doesn't create a new device for them.
- """
+ """Tests that logging in as a user doesn't create a new device for them."""
# Login in as the user
self._get_token()
@@ -2176,8 +2326,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["devices"]), 1)
def test_logout(self):
- """Test that calling `/logout` with the token works.
- """
+ """Test that calling `/logout` with the token works."""
# Login in as the user
puppet_token = self._get_token()
@@ -2267,8 +2416,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
}
)
def test_consent(self):
- """Test that sending a message is not subject to the privacy policies.
- """
+ """Test that sending a message is not subject to the privacy policies."""
# Have the admin user accept the terms.
self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
@@ -2343,11 +2491,19 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.register_user("user2", "pass")
other_user2_token = self.login("user2", "pass")
- channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
+ channel = self.make_request(
+ "GET",
+ self.url1,
+ access_token=other_user2_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
+ channel = self.make_request(
+ "GET",
+ self.url2,
+ access_token=other_user2_token,
+ )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -2358,11 +2514,19 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
- channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url1,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
- channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ url2,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
@@ -2370,12 +2534,20 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
The lookup should succeed for an admin.
"""
- channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url1,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
- channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
+ channel = self.make_request(
+ "GET",
+ self.url2,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -2386,12 +2558,20 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("GET", self.url1, access_token=other_user_token,)
+ channel = self.make_request(
+ "GET",
+ self.url1,
+ access_token=other_user_token,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
- channel = self.make_request("GET", self.url2, access_token=other_user_token,)
+ channel = self.make_request(
+ "GET",
+ self.url2,
+ access_token=other_user_token,
+ )
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 913ea3c98e..5256c11fe6 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -73,7 +73,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
# Mod the mod
room_power_levels = self.helper.get_state(
- self.room_id, "m.room.power_levels", tok=self.admin_access_token,
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
)
# Update existing power levels with mod at PL50
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index f0707646bb..e0c74591b6 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -181,8 +181,7 @@ class RedactionsTestCase(HomeserverTestCase):
)
def test_redact_event_as_moderator_ratelimit(self):
- """Tests that the correct ratelimiting is applied to redactions
- """
+ """Tests that the correct ratelimiting is applied to redactions"""
message_ids = []
# as a regular user, send messages to redact
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 10b1fbac69..b8285f3240 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -252,7 +252,8 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["backfill"])
self.hs = self.setup_test_homeserver(
- config=config, federation_client=mock_federation_client,
+ config=config,
+ federation_client=mock_federation_client,
)
return self.hs
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 0ebdf1415b..d2cce44032 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -260,7 +260,10 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ self.banned_user_id,
+ room_id,
+ "m.room.member",
+ self.banned_user_id,
)
)
self.assertEqual(
@@ -292,7 +295,10 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ self.banned_user_id,
+ room_id,
+ "m.room.member",
+ self.banned_user_id,
)
)
self.assertEqual(
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 0a5ca317ea..2ae896db1e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -150,6 +150,8 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
event_id = resp["event_id"]
channel = self.make_request(
- "GET", "/events/" + event_id, access_token=self.token,
+ "GET",
+ "/events/" + event_id,
+ access_token=self.token,
)
self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 49543d9acb..fb29eaed6f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -611,7 +611,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# matrix access token, mxid, and device id.
login_token = params[2][1]
chan = self.make_request(
- "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
@@ -619,7 +621,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
- "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)
@@ -719,7 +722,8 @@ class CASTestCase(unittest.HomeserverTestCase):
mocked_http_client.get_raw.side_effect = get_raw
self.hs = self.setup_test_homeserver(
- config=config, proxied_http_client=mocked_http_client,
+ config=config,
+ proxied_http_client=mocked_http_client,
)
return self.hs
@@ -1244,7 +1248,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# looks ok.
username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
self.assertIn(
- session_id, username_mapping_sessions, "session id not found in map",
+ session_id,
+ username_mapping_sessions,
+ "session id not found in map",
)
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
@@ -1299,7 +1305,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
chan = self.make_request(
- "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index e59fa70baa..f3448c94dd 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,163 +14,11 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
-import json
-
-from mock import Mock
-
-from twisted.internet import defer
-
-import synapse.types
-from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
-from ....utils import MockHttpResource, setup_test_homeserver
-
-myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/r0"
-
-
-class MockHandlerProfileTestCase(unittest.TestCase):
- """ Tests rest layer of profile management.
-
- Todo: move these into ProfileTestCase
- """
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.mock_handler = Mock(
- spec=[
- "get_displayname",
- "set_displayname",
- "get_avatar_url",
- "set_avatar_url",
- "check_profile_query_allowed",
- ]
- )
-
- self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
- Mock()
- )
-
- hs = yield setup_test_homeserver(
- self.addCleanup,
- "test",
- federation_http_client=None,
- resource_for_client=self.mock_resource,
- federation=Mock(),
- federation_client=Mock(),
- profile_handler=self.mock_handler,
- )
-
- async def _get_user_by_req(request=None, allow_guest=False):
- return synapse.types.create_requester(myid)
-
- hs.get_auth().get_user_by_req = _get_user_by_req
-
- profile.register_servlets(hs, self.mock_resource)
-
- @defer.inlineCallbacks
- def test_get_my_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Frank")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Frank"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
-
- @defer.inlineCallbacks
- def test_set_my_name_noauth(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = AuthError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@4567:test"),
- b'{"displayname": "Frank Jr."}',
- )
-
- self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_other_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Bob")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Bob"}, response)
-
- @defer.inlineCallbacks
- def test_set_other_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = SynapseError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@opaque:elsewhere"),
- b'{"displayname":"bob"}',
- )
-
- self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_my_avatar(self):
- mocked_get = self.mock_handler.get_avatar_url
- mocked_get.return_value = defer.succeed("http://my.server/me.png")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/avatar_url" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_avatar(self):
- mocked_set = self.mock_handler.set_avatar_url
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/avatar_url" % (myid),
- b'{"avatar_url": "http://my.server/pic.gif"}',
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
-
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
+ self.other = self.register_user("other", "pass", displayname="Bob")
+
+ def test_get_displayname(self):
+ res = self._get_displayname()
+ self.assertEqual(res, "owner")
def test_set_displayname(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test"}),
+ content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "test")
+ def test_set_displayname_noauth(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner,),
+ content={"displayname": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400"""
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test" * 100}),
+ content={"displayname": "test" * 100},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "owner")
- def get_displayname(self):
- channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
+ def test_get_displayname_other(self):
+ res = self._get_displayname(self.other)
+ self.assertEquals(res, "Bob")
+
+ def test_set_displayname_other(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.other,),
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_get_avatar_url(self):
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_set_avatar_url(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertEqual(res, "http://my.server/pic.gif")
+
+ def test_set_avatar_url_noauth(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_set_avatar_url_too_long(self):
+ """Attempts to set a stupid avatar_url should get a 400"""
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif" * 100},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_get_avatar_url_other(self):
+ res = self._get_avatar_url(self.other)
+ self.assertIsNone(res)
+
+ def test_set_avatar_url_other(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.other,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def _get_displayname(self, name=None):
+ channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (name or self.owner,)
+ )
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
+ def _get_avatar_url(self, name=None):
+ channel = self.make_request(
+ "GET", "/profile/%s/avatar_url" % (name or self.owner,)
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body.get("avatar_url")
+
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2548b3a80c..ed65f645fc 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -46,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.hs.get_federation_handler = Mock()
@@ -1480,7 +1482,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 2, [result["result"]["content"] for result in results],
+ len(results),
+ 2,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1515,7 +1519,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 4, [result["result"]["content"] for result in results],
+ len(results),
+ 4,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1562,7 +1568,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 1, [result["result"]["content"] for result in results],
+ len(results),
+ 1,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 38c51525a3..329dbd06de 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.rest.client.v1 import room
from synapse.types import UserID
@@ -39,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
@@ -60,32 +60,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_datastore().insert_client_ip = _insert_client_ip
- def get_room_members(room_id):
- if room_id == self.room_id:
- return defer.succeed([self.user])
- else:
- return defer.succeed([])
-
- @defer.inlineCallbacks
- def fetch_room_distributions_into(
- room_id, localusers=None, remotedomains=None, ignore_user=None
- ):
- members = yield get_room_members(room_id)
- for member in members:
- if ignore_user is not None and member == ignore_user:
- continue
-
- if hs.is_mine(member):
- if localusers is not None:
- localusers.add(member)
- else:
- if remotedomains is not None:
- remotedomains.add(member.domain)
-
- hs.get_room_member_handler().fetch_room_distributions_into = (
- fetch_room_distributions_into
- )
-
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index b1333df82d..8231a423f3 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -166,9 +166,12 @@ class RestHelper:
json.dumps(data).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
self.auth_user_id = temp_id
@@ -201,9 +204,12 @@ class RestHelper:
json.dumps(content).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -251,9 +257,12 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -447,7 +456,10 @@ class RestHelper:
return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
def complete_oidc_auth(
- self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+ self,
+ oauth_uri: str,
+ cookies: Mapping[str, str],
+ user_info_dict: JsonDict,
) -> FakeChannel:
"""Mock out an OIDC authentication flow
@@ -491,7 +503,9 @@ class RestHelper:
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ code=200,
+ phrase=b"OK",
+ body=json.dumps(resp_obj).encode("utf-8"),
)
return resp
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 177dc476da..e72b61963d 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -75,8 +75,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
- """Test basic password reset flow
- """
+ """Test basic password reset flow"""
old_password = "monkey"
new_password = "kangeroo"
@@ -114,8 +113,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self):
- """Test that we ratelimit /requestToken for the same email.
- """
+ """Test that we ratelimit /requestToken for the same email."""
old_password = "monkey"
new_password = "kangeroo"
@@ -203,8 +201,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", old_password)
def test_cant_reset_password_without_clicking_link(self):
- """Test that we do actually need to click the link in the email
- """
+ """Test that we do actually need to click the link in the email"""
old_password = "monkey"
new_password = "kangeroo"
@@ -299,7 +296,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
if channel.code != 200:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.result["body"],
+ channel.code,
+ channel.result["reason"],
+ channel.result["body"],
)
return channel.json_body["sid"]
@@ -566,8 +565,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_ip(self):
- """Tests that adding emails is ratelimited by IP
- """
+ """Tests that adding emails is ratelimited by IP"""
# We expect to be able to set three emails before getting ratelimited.
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
@@ -580,8 +578,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(cm.exception.code, 429)
def test_add_email_if_disabled(self):
- """Test adding email to profile when doing so is disallowed
- """
+ """Test adding email to profile when doing so is disallowed"""
self.hs.config.enable_3pid_changes = False
client_secret = "foobar"
@@ -611,15 +608,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self):
- """Test deleting an email from profile
- """
+ """Test deleting an email from profile"""
# Add a threepid
self.get_success(
self.store.user_add_threepid(
@@ -641,15 +639,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self):
- """Test deleting an email from profile when disallowed
- """
+ """Test deleting an email from profile when disallowed"""
self.hs.config.enable_3pid_changes = False
# Add a threepid
@@ -675,7 +674,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -683,8 +684,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
def test_cant_add_email_without_clicking_link(self):
- """Test that we do actually need to click the link in the email
- """
+ """Test that we do actually need to click the link in the email"""
client_secret = "foobar"
session_id = self._request_token(self.email, client_secret)
@@ -710,7 +710,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -743,7 +745,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -788,7 +792,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Ensure not providing a next_link parameter still works
self._request_token(
- "something@example.com", "some_secret", next_link=None, expect_code=200,
+ "something@example.com",
+ "some_secret",
+ next_link=None,
+ expect_code=200,
)
self._request_token(
@@ -846,17 +853,27 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if next_link:
body["next_link"] = next_link
- channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
+ channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ body,
+ )
if channel.code != expect_code:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.result["body"],
+ channel.code,
+ channel.result["reason"],
+ channel.result["body"],
)
return channel.json_body.get("sid")
def _request_token_invalid_email(
- self, email, expected_errcode, expected_error, client_secret="foobar",
+ self,
+ email,
+ expected_errcode,
+ expected_error,
+ client_secret="foobar",
):
channel = self.make_request(
"POST",
@@ -895,8 +912,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
return match.group(0)
def _add_email(self, request_email, expected_email):
- """Test adding an email to profile
- """
+ """Test adding an email to profile"""
previous_email_attempts = len(self.email_attempts)
client_secret = "foobar"
@@ -926,7 +942,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 3f50c56745..501f09203f 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -102,7 +102,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
- 401, {"username": "user", "type": "m.login.password", "password": "bar"},
+ 401,
+ {"username": "user", "type": "m.login.password", "password": "bar"},
)
# Grab the session
@@ -191,7 +192,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
) -> FakeChannel:
"""Delete an individual device."""
channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=access_token,
+ "DELETE",
+ "devices/" + device,
+ body,
+ access_token=access_token,
)
# Ensure the response is sane.
@@ -204,7 +208,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Note that this uses the delete_devices endpoint so that we can modify
# the payload half-way through some tests.
channel = self.make_request(
- "POST", "delete_devices", body, access_token=self.user_tok,
+ "POST",
+ "delete_devices",
+ body,
+ access_token=self.user_tok,
)
# Ensure the response is sane.
@@ -417,7 +424,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
# and now the delete request should succeed.
self.delete_device(
- self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
+ self.user_tok,
+ self.device_id,
+ 200,
+ body={"auth": {"session": session_id}},
)
@skip_unless(HAS_OIDC, "requires OIDC")
@@ -443,8 +453,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self):
- """A user that had a password and then logged in with SSO should get both flows
- """
+ """A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
@@ -459,8 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self):
- """If the user tries to authenticate with the wrong SSO user, they get an error
- """
+ """If the user tries to authenticate with the wrong SSO user, they get an error"""
# log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index fba34def30..5ebc5707a5 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -91,7 +91,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_TOO_SHORT,
+ channel.result,
)
def test_password_no_digit(self):
@@ -100,7 +102,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_DIGIT,
+ channel.result,
)
def test_password_no_symbol(self):
@@ -109,7 +113,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_SYMBOL,
+ channel.result,
)
def test_password_no_uppercase(self):
@@ -118,7 +124,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_UPPERCASE,
+ channel.result,
)
def test_password_no_lowercase(self):
@@ -127,7 +135,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_LOWERCASE,
+ channel.result,
)
def test_password_compliant(self):
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index bd574077e7..7c457754f1 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -83,14 +83,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_deny_membership(self):
- """Test that we deny relations on membership events
- """
+ """Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
def test_deny_double_react(self):
- """Test that we deny relations on membership events
- """
+ """Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -98,8 +96,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
- """Tests that calling pagination API correctly the latest relations.
- """
+ """Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
@@ -174,8 +171,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation_pagination_groups(self):
- """Test that we can paginate annotation groups correctly.
- """
+ """Test that we can paginate annotation groups correctly."""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
@@ -240,8 +236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(sent_groups, found_groups)
def test_aggregation_pagination_within_group(self):
- """Test that we can paginate within an annotation group.
- """
+ """Test that we can paginate within an annotation group."""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
@@ -311,8 +306,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation(self):
- """Test that annotations get correctly aggregated.
- """
+ """Test that annotations get correctly aggregated."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -344,8 +338,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_aggregation_redactions(self):
- """Test that annotations get correctly aggregated after a redaction.
- """
+ """Test that annotations get correctly aggregated after a redaction."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -379,8 +372,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_aggregation_must_be_annotation(self):
- """Test that aggregations must be annotations.
- """
+ """Test that aggregations must be annotations."""
channel = self.make_request(
"GET",
@@ -437,8 +429,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_edit(self):
- """Test that a simple edit works.
- """
+ """Test that a simple edit works."""
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 7f68032d9d..899f4902d7 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -481,13 +481,19 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Check that room name changes increase the unread counter.
self.helper.send_state(
- self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ self.room_id,
+ "m.room.name",
+ {"name": "my super room"},
+ tok=self.tok2,
)
self._check_unread_count(1)
# Check that room topic changes increase the unread counter.
self.helper.send_state(
- self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ self.room_id,
+ "m.room.topic",
+ {"topic": "welcome!!!"},
+ tok=self.tok2,
)
self._check_unread_count(2)
@@ -497,7 +503,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Check that custom events with a body increase the unread counter.
self.helper.send_event(
- self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ self.room_id,
+ "org.matrix.custom_type",
+ {"body": "hello"},
+ tok=self.tok2,
)
self._check_unread_count(4)
@@ -536,14 +545,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
"""Syncs and compares the unread count with the expected value."""
channel = self.make_request(
- "GET", self.url % self.next_batch, access_token=self.tok,
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
room_entry = channel.json_body["rooms"]["join"][self.room_id]
self.assertEqual(
- room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ room_entry["org.matrix.msc2654.unread_count"],
+ expected_count,
+ room_entry,
)
# Store the next batch for the next request.
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py
new file mode 100644
index 0000000000..d890d11863
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_upgrade_room.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+from tests.server import FakeChannel
+
+
+class UpgradeRoomTest(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ self.creator = self.register_user("creator", "pass")
+ self.creator_token = self.login(self.creator, "pass")
+
+ self.other = self.register_user("user", "pass")
+ self.other_token = self.login(self.other, "pass")
+
+ self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token)
+ self.helper.join(self.room_id, self.other, tok=self.other_token)
+
+ def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel:
+ # We never want a cached response.
+ self.reactor.advance(5 * 60 + 1)
+
+ return self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id,
+ # This will upgrade a room to the same version, but that's fine.
+ content={"new_version": DEFAULT_ROOM_VERSION},
+ access_token=token or self.creator_token,
+ )
+
+ def test_upgrade(self):
+ """
+ Upgrading a room should work fine.
+ """
+ channel = self._upgrade_room()
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertIn("replacement_room", channel.json_body)
+
+ def test_not_in_room(self):
+ """
+ Upgrading a room should work fine.
+ """
+ # THe user isn't in the room.
+ roomless = self.register_user("roomless", "pass")
+ roomless_token = self.login(roomless, "pass")
+
+ channel = self._upgrade_room(roomless_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ def test_power_levels(self):
+ """
+ Another user can upgrade the room if their power level is increased.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["users"][self.other] = 100
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def test_power_levels_user_default(self):
+ """
+ Another user can upgrade the room if the default power level for users is increased.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["users_default"] = 100
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def test_power_levels_tombstone(self):
+ """
+ Another user can upgrade the room if they can send the tombstone event.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["events"]["m.room.tombstone"] = 0
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ self.assertNotIn(self.other, power_levels["users"])
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 5e90d656f7..9d0d0ef414 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -180,7 +180,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
async def post_json(destination, path, data):
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
- path, "/_matrix/key/v2/query",
+ path,
+ "/_matrix/key/v2/query",
)
channel = FakeChannel(self.site, self.reactor)
@@ -188,7 +189,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
- b"POST", path.encode("utf-8"), b"1.1",
+ b"POST",
+ path.encode("utf-8"),
+ b"1.1",
)
channel.await_result()
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index c279eb49e3..0789b12392 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -167,7 +167,16 @@ class _TestImage:
),
),
# an empty file
- (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
+ (
+ _TestImage(
+ b"",
+ b"image/gif",
+ b".gif",
+ None,
+ None,
+ False,
+ ),
+ ),
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
@@ -469,8 +478,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
return config
def test_upload_innocent(self):
- """Attempt to upload some innocent data that should be allowed.
- """
+ """Attempt to upload some innocent data that should be allowed."""
image_data = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
diff --git a/tests/server.py b/tests/server.py
index 6419c445ec..d4ece5c448 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -347,8 +347,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
- """Fake L{IReactorTCP.connectTCP}.
- """
+ """Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
host, port, factory, timeout=timeout, bindAddress=None
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index fea54464af..d40d65b06a 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -353,7 +353,11 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
tok = self.login(localpart, "password")
# Sync with the user's token to mark the user as active.
- channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,)
+ channel = self.make_request(
+ "GET",
+ "/sync?timeout=0",
+ access_token=tok,
+ )
# Also retrieves the list of invites for this user. We don't care about that
# one except if we're processing the last user, which should have received an
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 77c72834f2..66e3cafe8e 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -382,8 +382,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
def test_mainline_sort(self):
- """Tests that the mainline ordering works correctly.
- """
+ """Tests that the mainline ordering works correctly."""
events = [
FakeEvent(
@@ -660,15 +659,27 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
# C -|-> B -> A
a = FakeEvent(
- id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="A",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([], [])
b = FakeEvent(
- id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="B",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([a.event_id], [])
c = FakeEvent(
- id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="C",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([b.event_id], [])
persisted_events = {a.event_id: a, b.event_id: b}
@@ -694,19 +705,35 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
# D -> C -|-> B -> A
a = FakeEvent(
- id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="A",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([], [])
b = FakeEvent(
- id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="B",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([a.event_id], [])
c = FakeEvent(
- id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="C",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([b.event_id], [])
d = FakeEvent(
- id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="D",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([c.event_id], [])
persisted_events = {a.event_id: a, b.event_id: b}
@@ -737,23 +764,43 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
# |
a = FakeEvent(
- id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="A",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([], [])
b = FakeEvent(
- id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="B",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([a.event_id], [])
c = FakeEvent(
- id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="C",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([b.event_id], [])
d = FakeEvent(
- id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="D",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([c.event_id], [])
e = FakeEvent(
- id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ id="E",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key="",
+ content={},
).to_event([c.event_id, b.event_id], [])
persisted_events = {a.event_id: a, b.event_id: b}
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 673e1fe3e3..38444e48e2 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -96,7 +96,9 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# No ignored_users key.
self.get_success(
self.store.add_account_data_for_user(
- self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+ self.user,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {},
)
)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 02aae1c13d..1b4fae0bb5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -67,7 +67,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
async def update(progress, count):
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
- count, target_background_update_duration_ms / duration_ms, places=0,
+ count,
+ target_background_update_duration_ms / duration_ms,
+ places=0,
)
await self.updates._end_background_update("test_update")
return count
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index c13a57dad1..7791138688 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -43,8 +43,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.room_id = info["room_id"]
def run_background_update(self):
- """Re run the background update to clean up the extremities.
- """
+ """Re run the background update to clean up the extremities."""
# Make sure we don't clash with in progress updates.
self.assertTrue(
self.store.db_pool.updates._all_done, "Background updates are still ongoing"
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index a69117c5a9..34e6526097 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -41,7 +41,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
device_id = "MY_DEVICE"
# Insert a user IP
- self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", device_id
@@ -214,7 +220,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
device_id = "MY_DEVICE"
# Insert a user IP
- self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", device_id
@@ -303,7 +315,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
device_id = "MY_DEVICE"
# Insert a user IP
- self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", device_id
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 0c46ad595b..16daa66cc9 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -90,7 +90,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
"content": {"tag": "power"},
},
).build(
- prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id],
)
)
@@ -226,7 +227,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
self.assertFalse(
link_map.exists_path_from(
- chain_map[create.event_id], chain_map[event.event_id],
+ chain_map[create.event_id],
+ chain_map[event.event_id],
),
)
@@ -287,7 +289,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
"content": {"tag": "power"},
},
).build(
- prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id],
)
)
@@ -373,7 +376,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
def persist(
- self, events: List[EventBase],
+ self,
+ events: List[EventBase],
):
"""Persist the given events and check that the links generated match
those given.
@@ -394,7 +398,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
persist_events_store._persist_event_auth_chain_txn(txn, events)
self.get_success(
- persist_events_store.db_pool.runInteraction("_persist", _persist,)
+ persist_events_store.db_pool.runInteraction(
+ "_persist",
+ _persist,
+ )
)
def fetch_chains(
@@ -447,8 +454,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
class LinkMapTestCase(unittest.TestCase):
def test_simple(self):
- """Basic tests for the LinkMap.
- """
+ """Basic tests for the LinkMap."""
link_map = _LinkMap()
link_map.add_link((1, 1), (2, 1), new=False)
@@ -490,8 +496,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester = create_requester(self.user_id)
def _generate_room(self) -> Tuple[str, List[Set[str]]]:
- """Insert a room without a chain cover index.
- """
+ """Insert a room without a chain cover index."""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Mark the room as not having a chain cover index
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 9d04a066d8..06000f81a6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -215,7 +215,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
],
)
- self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "insert",
+ insert_event,
+ )
+ )
# Now actually test that various combinations give the right result:
@@ -370,7 +375,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
- txn, [FakeEvent("b", room_id, auth_graph["b"])],
+ txn,
+ [FakeEvent("b", room_id, auth_graph["b"])],
)
self.store.db_pool.simple_update_txn(
@@ -380,7 +386,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": True},
)
- self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "insert",
+ insert_event,
+ )
+ )
# Now actually test that various combinations give the right result:
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index c0595963dd..485f1ee033 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -84,7 +84,9 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}, False,
+ event.event_id,
+ {user_id: action},
+ False,
)
)
yield defer.ensureDeferred(
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 71210ce606..ed898b8dbb 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -68,16 +68,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
self.assert_extremities([self.remote_event_1.event_id])
def persist_event(self, event, state=None):
- """Persist the event, with optional state
- """
+ """Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(event, old_state=state)
)
self.get_success(self.persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities):
- """Assert the current extremities for the room
- """
+ """Assert the current extremities for the room"""
extremities = self.get_success(
self.store.get_prev_events_for_room(self.room_id)
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 3e2fd4da01..aad6bc907e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -86,7 +86,11 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def _insert(txn):
txn.execute(
- "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ "INSERT INTO foobar VALUES (?, ?)",
+ (
+ stream_id,
+ instance_name,
+ ),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
txn.execute(
@@ -138,8 +142,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_out_of_order_finish(self):
- """Test that IDs persisted out of order are correctly handled
- """
+ """Test that IDs persisted out of order are correctly handled"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
@@ -246,8 +249,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
def test_get_next_txn(self):
- """Test that the `get_next_txn` function works correctly.
- """
+ """Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
@@ -386,8 +388,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
def test_writer_config_change(self):
- """Test that changing the writer config correctly works.
- """
+ """Test that changing the writer config correctly works."""
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
@@ -434,8 +435,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
def test_sequence_consistency(self):
- """Test that we error out if the table and sequence diverges.
- """
+ """Test that we error out if the table and sequence diverges."""
# Prefill with some rows
self._insert_row_with_id("master", 3)
@@ -452,8 +452,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
- """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
- """
+ """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
@@ -494,12 +493,15 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_row(self, instance_name: str, stream_id: int):
- """Insert one row as the given instance with given stream_id.
- """
+ """Insert one row as the given instance with given stream_id."""
def _insert(txn):
txn.execute(
- "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ "INSERT INTO foobar VALUES (?, ?)",
+ (
+ stream_id,
+ instance_name,
+ ),
)
txn.execute(
"""
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 8d97b6d4cd..5858c7fcc4 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -198,7 +198,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
def test_reap_monthly_active_users_reserved_users(self):
- """ Tests that reaping correctly handles reaping where reserved users are
+ """Tests that reaping correctly handles reaping where reserved users are
present"""
threepids = self.hs.config.mau_limits_reserved_threepids
initial_users = len(threepids)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index a6303bf0ee..b2a0e60856 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -299,8 +299,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
def test_redact_censor(self):
- """Test that a redacted event gets censored in the DB after a month
- """
+ """Test that a redacted event gets censored in the DB after a month"""
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -370,8 +369,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assert_dict({"content": {}}, json.loads(event_json))
def test_redact_redaction(self):
- """Tests that we can redact a redaction and can fetch it again.
- """
+ """Tests that we can redact a redaction and can fetch it again."""
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -404,8 +402,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
def test_store_redacted_redaction(self):
- """Tests that we can store a redacted redaction.
- """
+ """Tests that we can store a redacted redaction."""
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index c8c7a90e5d..abbaed7cdc 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -145,7 +145,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
try:
yield defer.ensureDeferred(
self.store.validate_threepid_session(
- "fake_sid", "fake_client_secret", "fake_token", 0,
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
)
)
except ThreepidValidationError as e:
@@ -158,7 +161,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
try:
yield defer.ensureDeferred(
self.store.validate_threepid_session(
- "fake_sid", "fake_client_secret", "fake_token", 0,
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
)
)
except ThreepidValidationError as e:
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 69b4c5d6c2..3f2691ee6b 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -85,7 +85,10 @@ class EventAuthTestCase(unittest.TestCase):
# king should be able to send state
event_auth.check(
- RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
+ RoomVersions.V1,
+ _random_state_event(king),
+ auth_events,
+ do_sig_check=False,
)
def test_alias_event(self):
@@ -99,7 +102,10 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send aliases
event_auth.check(
- RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
+ RoomVersions.V1,
+ _alias_event(creator),
+ auth_events,
+ do_sig_check=False,
)
# Reject an event with no state key.
@@ -122,7 +128,10 @@ class EventAuthTestCase(unittest.TestCase):
# Note that the member does *not* need to be in the room.
event_auth.check(
- RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
+ RoomVersions.V1,
+ _alias_event(other),
+ auth_events,
+ do_sig_check=False,
)
def test_msc2432_alias_event(self):
@@ -136,7 +145,10 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send aliases
event_auth.check(
- RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
+ RoomVersions.V6,
+ _alias_event(creator),
+ auth_events,
+ do_sig_check=False,
)
# No particular checks are done on the state key.
@@ -156,7 +168,10 @@ class EventAuthTestCase(unittest.TestCase):
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
+ RoomVersions.V6,
+ _alias_event(other),
+ auth_events,
+ do_sig_check=False,
)
def test_msc2209(self):
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 51660b51d5..75d28a42df 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -242,7 +242,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
)
channel = self.make_request(
- "POST", "/register", request_data, access_token=token,
+ "POST",
+ "/register",
+ request_data,
+ access_token=token,
)
if channel.code != 200:
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 759e4cd048..f696fcf89e 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -21,7 +21,7 @@ from tests import unittest
def get_sample_labels_value(sample):
- """ Extract the labels and values of a sample.
+ """Extract the labels and values of a sample.
prometheus_client 0.5 changed the sample type to a named tuple with more
members than the plain tuple had in 0.4 and earlier. This function can
diff --git a/tests/test_server.py b/tests/test_server.py
index 815da18e65..55cde7f62f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -166,7 +166,10 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver)
res.register_paths(
- "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+ "GET",
+ [re.compile("^/_matrix/foo$")],
+ _callback,
+ "test_servlet",
)
# The path was registered as GET, but this is a HEAD request.
diff --git a/tests/unittest.py b/tests/unittest.py
index 767d5d6077..ca7031c724 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -255,7 +255,10 @@ class HomeserverTestCase(TestCase):
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastore().add_access_token_to_user(
- self.helper.auth_user_id, "some_fake_token", None, None,
+ self.helper.auth_user_id,
+ "some_fake_token",
+ None,
+ None,
)
)
diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
new file mode 100644
index 0000000000..f349b5ced0
--- /dev/null
+++ b/tests/util/caches/test_cached_call.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+
+from synapse.util.caches.cached_call import CachedCall, RetryOnExceptionCachedCall
+
+from tests.test_utils import get_awaitable_result
+from tests.unittest import TestCase
+
+
+class CachedCallTestCase(TestCase):
+ def test_get(self):
+ """
+ Happy-path test case: makes a couple of calls and makes sure they behave
+ correctly
+ """
+ d = Deferred()
+
+ async def f():
+ return await d
+
+ slow_call = Mock(side_effect=f)
+
+ cached_call = CachedCall(slow_call)
+
+ # the mock should not yet have been called
+ slow_call.assert_not_called()
+
+ # now fire off a couple of calls
+ completed_results = []
+
+ async def r():
+ res = await cached_call.get()
+ completed_results.append(res)
+
+ r1 = defer.ensureDeferred(r())
+ r2 = defer.ensureDeferred(r())
+
+ # neither result should be complete yet
+ self.assertNoResult(r1)
+ self.assertNoResult(r2)
+
+ # and the mock should have been called *once*, with no params
+ slow_call.assert_called_once_with()
+
+ # allow the deferred to complete, which should complete both the pending results
+ d.callback(123)
+ self.assertEqual(completed_results, [123, 123])
+ self.successResultOf(r1)
+ self.successResultOf(r2)
+
+ # another call to the getter should complete immediately
+ slow_call.reset_mock()
+ r3 = get_awaitable_result(cached_call.get())
+ self.assertEqual(r3, 123)
+ slow_call.assert_not_called()
+
+ def test_fast_call(self):
+ """
+ Test the behaviour when the underlying function completes immediately
+ """
+
+ async def f():
+ return 12
+
+ fast_call = Mock(side_effect=f)
+ cached_call = CachedCall(fast_call)
+
+ # the mock should not yet have been called
+ fast_call.assert_not_called()
+
+ # run the call a couple of times, which should complete immediately
+ self.assertEqual(get_awaitable_result(cached_call.get()), 12)
+ self.assertEqual(get_awaitable_result(cached_call.get()), 12)
+
+ # the mock should have been called once
+ fast_call.assert_called_once_with()
+
+
+class RetryOnExceptionCachedCallTestCase(TestCase):
+ def test_get(self):
+ # set up the RetryOnExceptionCachedCall around a function which will fail
+ # (after a while)
+ d = Deferred()
+
+ async def f1():
+ await d
+ raise ValueError("moo")
+
+ slow_call = Mock(side_effect=f1)
+ cached_call = RetryOnExceptionCachedCall(slow_call)
+
+ # the mock should not yet have been called
+ slow_call.assert_not_called()
+
+ # now fire off a couple of calls
+ completed_results = []
+
+ async def r():
+ try:
+ await cached_call.get()
+ except Exception as e1:
+ completed_results.append(e1)
+
+ r1 = defer.ensureDeferred(r())
+ r2 = defer.ensureDeferred(r())
+
+ # neither result should be complete yet
+ self.assertNoResult(r1)
+ self.assertNoResult(r2)
+
+ # and the mock should have been called *once*, with no params
+ slow_call.assert_called_once_with()
+
+ # complete the deferred, which should make the pending calls fail
+ d.callback(0)
+ self.assertEqual(len(completed_results), 2)
+ for e in completed_results:
+ self.assertIsInstance(e, ValueError)
+ self.assertEqual(e.args, ("moo",))
+
+ # reset the mock to return a successful result, and make another pair of calls
+ # to the getter
+ d = Deferred()
+
+ async def f2():
+ return await d
+
+ slow_call.reset_mock()
+ slow_call.side_effect = f2
+ r3 = defer.ensureDeferred(cached_call.get())
+ r4 = defer.ensureDeferred(cached_call.get())
+
+ self.assertNoResult(r3)
+ self.assertNoResult(r4)
+ slow_call.assert_called_once_with()
+
+ # let that call complete, and check the results
+ d.callback(123)
+ self.assertEqual(self.successResultOf(r3), 123)
+ self.assertEqual(self.successResultOf(r4), 123)
+
+ # and now more calls to the getter should complete immediately
+ slow_call.reset_mock()
+ self.assertEqual(get_awaitable_result(cached_call.get()), 123)
+ slow_call.assert_not_called()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index ecd9efc4df..c24c33ee91 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -232,7 +232,10 @@ class DeferredCacheTestCase(TestCase):
def test_eviction_iterable(self):
cache = DeferredCache(
- "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
+ "test",
+ max_entries=3,
+ apply_cache_factor_from_config=False,
+ iterable=True,
)
cache.prefill(1, ["one", "two"])
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index cf1e3203a4..afb11b9caf 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -143,8 +143,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
def test_cache_with_sync_exception(self):
- """If the wrapped function throws synchronously, things should continue to work
- """
+ """If the wrapped function throws synchronously, things should continue to work"""
class Cls:
@cached()
@@ -165,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d, SynapseError)
def test_cache_with_async_exception(self):
- """The wrapped function returns a failure
- """
+ """The wrapped function returns a failure"""
class Cls:
result = None
@@ -282,7 +280,8 @@ class DescriptorTestCase(unittest.TestCase):
try:
d = obj.fn(1)
self.assertEqual(
- current_context(), SENTINEL_CONTEXT,
+ current_context(),
+ SENTINEL_CONTEXT,
)
yield d
self.fail("No exception thrown")
@@ -374,8 +373,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
- """If the wrapped function throws synchronously, things should continue to work
- """
+ """If the wrapped function throws synchronously, things should continue to work"""
class Cls:
@descriptors.cached(iterable=True)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 1ef0af8e8f..e931a7ec18 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -24,28 +24,32 @@ class ChunkSeqTests(TestCase):
parts = chunk_seq("123", 8)
self.assertEqual(
- list(parts), ["123"],
+ list(parts),
+ ["123"],
)
def test_long_seq(self):
parts = chunk_seq("abcdefghijklmnop", 8)
self.assertEqual(
- list(parts), ["abcdefgh", "ijklmnop"],
+ list(parts),
+ ["abcdefgh", "ijklmnop"],
)
def test_uneven_parts(self):
parts = chunk_seq("abcdefghijklmnop", 5)
self.assertEqual(
- list(parts), ["abcde", "fghij", "klmno", "p"],
+ list(parts),
+ ["abcde", "fghij", "klmno", "p"],
)
def test_empty_input(self):
parts = chunk_seq([], 5)
self.assertEqual(
- list(parts), [],
+ list(parts),
+ [],
)
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 13b753e367..9ed01f7e0c 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -70,7 +70,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertTrue("user@foo.com" not in cache._entity_to_key)
self.assertEqual(
- cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"],
+ cache.get_all_entities_changed(2),
+ ["bar@baz.net", "user@elsewhere.org"],
)
self.assertIsNone(cache.get_all_entities_changed(1))
@@ -80,7 +81,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
self.assertEqual(
- cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"],
+ cache.get_all_entities_changed(2),
+ ["user@elsewhere.org", "bar@baz.net"],
)
self.assertIsNone(cache.get_all_entities_changed(1))
@@ -222,7 +224,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Query a subset of the entries mid-way through the stream. We should
# only get back the subset.
self.assertEqual(
- cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
+ cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
+ {"bar@baz.net"},
)
def test_max_pos(self):
diff --git a/tests/utils.py b/tests/utils.py
index d76a46bc89..ebb76b3b16 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -269,7 +269,10 @@ def setup_test_homeserver(
db_conn.close()
hs = homeserver_to_use(
- name, config=config, version_string="Synapse/tests", reactor=reactor,
+ name,
+ config=config,
+ version_string="Synapse/tests",
+ reactor=reactor,
)
# Install @cache_in_self attributes
@@ -371,7 +374,7 @@ class MockHttpResource:
def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
):
- """ Fire an HTTP event.
+ """Fire an HTTP event.
Args:
http_method : The HTTP method
@@ -534,8 +537,7 @@ class MockClock:
async def create_room(hs, room_id: str, creator_id: str):
- """Creates and persist a creation event for the given room
- """
+ """Creates and persist a creation event for the given room"""
persistence_store = hs.get_storage().persistence
store = hs.get_datastore()
|