diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 8ab56ec94c..ee5217b074 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,7 +19,6 @@ import pymacaroons
from twisted.internet import defer
-import synapse.handlers.auth
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -30,26 +29,22 @@ from synapse.api.errors import (
MissingClientTokenError,
ResourceLimitError,
)
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
from tests.utils import mock_getRawHeaders, setup_test_homeserver
-class TestHandlers:
- def __init__(self, hs):
- self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
-
-
class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.state_handler = Mock()
self.store = Mock()
- self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
+ self.hs = yield setup_test_homeserver(self.addCleanup)
self.hs.get_datastore = Mock(return_value=self.store)
- self.hs.handlers = TestHandlers(self.hs)
+ self.hs.get_auth_handler().store = self.store
self.auth = Auth(self.hs)
# AuthBlocking reads from the hs' config on initialization. We need to
@@ -67,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
+ user_info = TokenLookupResult(
+ user_id=self.test_user, token_id=5, device_id="device"
+ )
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -90,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto"}
+ user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -227,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(
- {"name": "@baldrick:matrix.org", "device_id": "device"}
+ TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
)
@@ -243,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
- user = user_info["user"]
- self.assertEqual(UserID.from_string(user_id), user)
+ self.assertEqual(user_id, user_info.user_id)
# TODO: device_id should come from the macaroon, but currently comes
# from the db.
- self.assertEqual(user_info["device_id"], "device")
+ self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
@@ -270,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
- user = user_info["user"]
- is_guest = user_info["is_guest"]
- self.assertEqual(UserID.from_string(user_id), user)
- self.assertTrue(is_guest)
+ self.assertEqual(user_id, user_info.user_id)
+ self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
@@ -283,24 +277,25 @@ class AuthTestCase(unittest.TestCase):
self.store.get_device = Mock(return_value=defer.succeed(None))
token = yield defer.ensureDeferred(
- self.hs.handlers.auth_handler.get_access_token_for_user_id(
+ self.hs.get_auth_handler().get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
)
self.store.add_access_token_to_user.assert_called_with(
- USER_ID, token, "DEVICE", None
+ user_id=USER_ID,
+ token=token,
+ device_id="DEVICE",
+ valid_until_ms=None,
+ puppets_user_id=None,
)
def get_user(tok):
if token != tok:
return defer.succeed(None)
return defer.succeed(
- {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ TokenLookupResult(
+ user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
+ )
)
self.store.get_user_by_access_token = get_user
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index d2d535d23c..279c94a03d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
-
import jsonschema
from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
+from tests.utils import setup_test_homeserver
user_localpart = "test_user"
@@ -42,22 +40,9 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.TestCase):
- @defer.inlineCallbacks
def setUp(self):
- self.mock_federation_resource = MockHttpResource()
-
- self.mock_http_client = Mock(spec=[])
- self.mock_http_client.put_json = DeferredMockCallable()
-
- hs = yield setup_test_homeserver(
- self.addCleanup,
- handlers=None,
- http_client=self.mock_http_client,
- keyring=Mock(),
- )
-
+ hs = setup_test_homeserver(self.addCleanup)
self.filtering = hs.get_filtering()
-
self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1e1f30d790..fe504d0869 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=True,
+ None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=False,
+ None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 641093d349..e0ca288829 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer
+from tests.server import make_request
from tests.unittest import HomeserverTestCase
@@ -22,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=GenericWorkerServer
+ federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -55,10 +56,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = site.resource.children[b"_matrix"].children[b"client"]
- request, channel = self.make_request("PUT", "presence/a/status")
- self.render(request)
+ channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
@@ -77,10 +76,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = site.resource.children[b"_matrix"].children[b"client"]
- request, channel = self.make_request("PUT", "presence/a/status")
- self.render(request)
+ channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 0f016c32eb..467033e201 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -20,13 +20,14 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
+from tests.server import make_request
from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=GenericWorkerServer
+ federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -66,16 +67,15 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+ site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/openid/userinfo"
+ channel = make_request(
+ self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
- self.render(request)
self.assertEqual(channel.code, 401)
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=SynapseHomeServer
+ federation_http_client=None, homeserver_to_use=SynapseHomeServer
)
return hs
@@ -115,15 +115,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+ site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/openid/userinfo"
+ channel = make_request(
+ self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
- self.render(request)
self.assertEqual(channel.code, 401)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 236b608d58..0bffeb1150 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
id="unique_identifier",
+ sender="@as:test",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 68a4caabbf..97f8cad0dd 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -60,7 +60,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events # txn made and saved
+ service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -81,7 +81,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events # txn made and saved
+ service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@@ -106,7 +106,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events
+ service=service, events=events, ephemeral=[]
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@@ -202,26 +202,28 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
- self.queuer.enqueue(service, event)
- self.txn_ctrl.send.assert_called_once_with(service, [event])
+ self.queuer.enqueue_event(service, event)
+ self.txn_ctrl.send.assert_called_once_with(service, [event], [])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y, z: make_deferred_yieldable(d)
+ )
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
- self.queuer.enqueue(service, event)
+ self.queuer.enqueue_event(service, event)
# Send more events: expect send() to NOT be called multiple times.
- self.queuer.enqueue(service, event2)
- self.queuer.enqueue(service, event3)
- self.txn_ctrl.send.assert_called_with(service, [event])
+ self.queuer.enqueue_event(service, event2)
+ self.queuer.enqueue_event(service, event3)
+ self.txn_ctrl.send.assert_called_with(service, [event], [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
- self.txn_ctrl.send.assert_called_with(service, [event2, event3])
+ self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@@ -239,21 +241,99 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer]
- def do_send(x, y):
+ def do_send(x, y, z):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
- self.queuer.enqueue(srv1, srv_1_event)
- self.queuer.enqueue(srv1, srv_1_event2)
- self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event])
- self.queuer.enqueue(srv2, srv_2_event)
- self.queuer.enqueue(srv2, srv_2_event2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event])
+ self.queuer.enqueue_event(srv1, srv_1_event)
+ self.queuer.enqueue_event(srv1, srv_1_event2)
+ self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
+ self.queuer.enqueue_event(srv2, srv_2_event)
+ self.queuer.enqueue_event(srv2, srv_2_event2)
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2])
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
+ self.assertEquals(3, self.txn_ctrl.send.call_count)
+
+ def test_send_large_txns(self):
+ srv_1_defer = defer.Deferred()
+ srv_2_defer = defer.Deferred()
+ send_return_list = [srv_1_defer, srv_2_defer]
+
+ def do_send(x, y, z):
+ return make_deferred_yieldable(send_return_list.pop(0))
+
+ self.txn_ctrl.send = Mock(side_effect=do_send)
+
+ service = Mock(id=4, name="service")
+ event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
+ for event in event_list:
+ self.queuer.enqueue_event(service, event)
+
+ # Expect the first event to be sent immediately.
+ self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
+ srv_1_defer.callback(service)
+ # Then send the next 100 events
+ self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
+ srv_2_defer.callback(service)
+ # Then the final 99 events
+ self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
+
+ def test_send_single_ephemeral_no_queue(self):
+ # Expect the event to be sent immediately.
+ service = Mock(id=4, name="service")
+ event_list = [Mock(name="event")]
+ self.queuer.enqueue_ephemeral(service, event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+
+ def test_send_multiple_ephemeral_no_queue(self):
+ # Expect the event to be sent immediately.
+ service = Mock(id=4, name="service")
+ event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
+ self.queuer.enqueue_ephemeral(service, event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+
+ def test_send_single_ephemeral_with_queue(self):
+ d = defer.Deferred()
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y, z: make_deferred_yieldable(d)
+ )
+ service = Mock(id=4)
+ event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
+ event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
+ event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
+
+ # Send an event and don't resolve it just yet.
+ self.queuer.enqueue_ephemeral(service, event_list_1)
+ # Send more events: expect send() to NOT be called multiple times.
+ self.queuer.enqueue_ephemeral(service, event_list_2)
+ self.queuer.enqueue_ephemeral(service, event_list_3)
+ self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
+ self.assertEquals(1, self.txn_ctrl.send.call_count)
+ # Resolve txn_ctrl.send
+ d.callback(service)
+ # Expect the queued events to be sent
+ self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
+ self.assertEquals(2, self.txn_ctrl.send.call_count)
+
+ def test_send_large_txns_ephemeral(self):
+ d = defer.Deferred()
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y, z: make_deferred_yieldable(d)
+ )
+ # Expect the event to be sent immediately.
+ service = Mock(id=4, name="service")
+ first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
+ second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
+ event_list = first_chunk + second_chunk
+ self.queuer.enqueue_ephemeral(service, event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
+ d.callback(service)
+ self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
+ self.assertEquals(2, self.txn_ctrl.send.call_count)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8ff1460c0d..1d65ea2f9c 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock()
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
"""
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
@@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
}
- mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
}
- mock_fetcher1 = keyring.KeyFetcher()
+ mock_fetcher1 = Mock()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
- mock_fetcher2 = keyring.KeyFetcher()
+ mock_fetcher2 = Mock()
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
- hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ hs = self.setup_test_homeserver(federation_http_client=self.http_client)
return hs
def test_get_keys_from_server(self):
@@ -396,7 +396,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
]
return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
+ federation_http_client=self.http_client, config=config
)
def build_perspectives_response(
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 1471cc1a28..9ccd2d76b8 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -48,10 +48,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Get the room complexity
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
@@ -61,10 +60,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index da933ecd75..cfeccc0577 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -46,12 +46,11 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/get_missing_events/(?P<room_id>[^/]*)/?"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
- self.render(request)
self.assertEquals(400, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@@ -96,10 +95,9 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.inject_room_member(room_1, "@user:other.example.com", "join")
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(
@@ -129,10 +127,9 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/federation/transport/__init__.py
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 72e22d655f..85500e169c 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,40 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
from tests import unittest
from tests.unittest import override_config
-class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
- class Authenticator:
- def authenticate_request(self, request, content):
- return defer.succeed("otherserver.nottld")
-
- ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
- server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
-
+class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
@override_config({"allow_public_rooms_over_federation": False})
def test_blocked_public_room_list_over_federation(self):
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/publicRooms"
+ """Test that unauthenticated requests to the public rooms directory 403 when
+ allow_public_rooms_over_federation is False.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/publicRooms",
+ federation_auth_origin=b"example.com",
)
- self.render(request)
self.assertEquals(403, channel.code)
@override_config({"allow_public_rooms_over_federation": True})
def test_open_public_room_list_over_federation(self):
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/publicRooms"
+ """Test that unauthenticated requests to the public rooms directory 200 when
+ allow_public_rooms_over_federation is True.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/v1/publicRooms",
+ federation_auth_origin=b"example.com",
)
- self.render(request)
self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index fc37c4328c..5c2b4de1a6 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -35,7 +35,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
self.user1 = self.register_user("user1", "password")
self.token1 = self.login("user1", "password")
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 2a0b7c1b56..53763cd0f9 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -18,6 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler
+from synapse.types import RoomStreamToken
from tests.test_utils import make_awaitable
from tests.utils import MockClock
@@ -41,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
- @defer.inlineCallbacks
def test_notify_interested_services(self):
interested_service = self._mkservice(is_interested=True)
services = [
@@ -61,12 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
- @defer.inlineCallbacks
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
@@ -80,10 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
- @defer.inlineCallbacks
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
@@ -97,7 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been.",
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 97877c2e42..e24ce81284 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -21,24 +21,17 @@ from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.api.errors import ResourceLimitError
-from synapse.handlers.auth import AuthHandler
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class AuthHandlers:
- def __init__(self, hs):
- self.auth_handler = AuthHandler(hs)
-
-
class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
- self.hs.handlers = AuthHandlers(self.hs)
- self.auth_handler = self.hs.handlers.auth_handler
+ self.hs = yield setup_test_homeserver(self.addCleanup)
+ self.auth_handler = self.hs.get_auth_handler()
self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests
@@ -59,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
- self.hs.clock.now = 5000
+ self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -85,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
- self.hs.clock.now = 1000
+ self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred(
@@ -94,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
- self.hs.clock.now = 6000
+ self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..bd7a1b6891
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ cas_config = {
+ "enabled": True,
+ "server_url": SERVER_URL,
+ "service_url": BASE_URL,
+ }
+ config["cas_config"] = cas_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_cas_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ def test_map_cas_user_to_user(self):
+ """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None
+ )
+
+ def test_map_cas_user_to_existing_user(self):
+ """Existing users can log in with CAS account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # Map a user via SSO.
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None
+ )
+
+ # Subsequent calls should map to the same mxid.
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None
+ )
+
+ def test_map_cas_user_to_invalid_localpart(self):
+ """CAS automaps invalid characters to base-64 encoding."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ cas_response = CasResponse("föö", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@f=c3=b6=c3=b6:test", request, "redirect_uri", None
+ )
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "get_user_agent"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 969d44c787..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +27,7 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
return hs
@@ -224,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
)
self.reactor.advance(1000)
+
+
+class DehydrationTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
+ self.handler = hs.get_device_handler()
+ self.registration = hs.get_registration_handler()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_dehydrate_and_rehydrate_device(self):
+ user_id = "@boris:dehydration"
+
+ self.get_success(self.store.register_user(user_id, "foobar"))
+
+ # First check if we can store and fetch a dehydrated device
+ stored_dehydrated_device_id = self.get_success(
+ self.handler.store_dehydrated_device(
+ user_id=user_id,
+ device_data={"device_data": {"foo": "bar"}},
+ initial_device_display_name="dehydrated device",
+ )
+ )
+
+ retrieved_device_id, device_data = self.get_success(
+ self.handler.get_dehydrated_device(user_id=user_id)
+ )
+
+ self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+ self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+ # Create a new login for the user and dehydrated the device
+ device_id, access_token = self.get_success(
+ self.registration.register_device(
+ user_id=user_id, device_id=None, initial_display_name="new device",
+ )
+ )
+
+ # Trying to claim a nonexistent device should throw an error
+ self.get_failure(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id="not the right device ID",
+ ),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # dehydrating the right devices should succeed and change our device ID
+ # to the dehydrated device's ID
+ res = self.get_success(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id=retrieved_device_id,
+ )
+ )
+
+ self.assertEqual(res, {"success": True})
+
+ # make sure that our device ID has changed
+ user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
+
+ self.assertEqual(user_info.device_id, retrieved_device_id)
+
+ # make sure the device has the display name that was set from the login
+ res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
+
+ self.assertEqual(res["display_name"], "new device")
+
+ # make sure that the device ID that we were initially assigned no longer exists
+ self.get_failure(
+ self.handler.get_device(user_id, device_id),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # make sure that there's no device available for dehydrating now
+ ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+
+ self.assertIsNone(ret)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index bc0c5aefdc..a39f898608 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,13 +42,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.mock_registry.register_query_handler = register_query_handler
hs = self.setup_test_homeserver(
- http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.store = hs.get_datastore()
@@ -110,7 +108,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -173,7 +171,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
# Create user
@@ -289,7 +287,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
# Create user
@@ -407,23 +405,21 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
def test_denied(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
def test_allowed(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23unofficial_test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
@@ -435,14 +431,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.room_list_handler = hs.get_room_list_handler()
- self.directory_handler = hs.get_handlers().directory_handler
+ self.directory_handler = hs.get_directory_handler()
return hs
@@ -451,8 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = True
# Room list is enabled so we should get some results
- request, channel = self.make_request("GET", b"publicRooms")
- self.render(request)
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@@ -460,15 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = False
# Room list disabled so we should get no results
- request, channel = self.make_request("GET", b"publicRooms")
- self.render(request)
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
# Room list disabled so we shouldn't be allowed to publish rooms
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 366dcfb670..924f29f051 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -33,13 +33,15 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
+ self.store = None # type: synapse.storage.Storage
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, handlers=None, federation_client=mock.Mock()
+ self.addCleanup, federation_client=mock.Mock()
)
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+ self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_query_local_devices_no_devices(self):
@@ -172,6 +174,89 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_fallback_key(self):
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ fallback_key = {"alg1:k1": "key1"}
+ otk = {"alg1:k2": "key2"}
+
+ # we shouldn't have any unused fallback keys yet
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key},
+ )
+ )
+
+ # we should now have an unused alg1 key
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, ["alg1"])
+
+ # claiming an OTK when no OTKs are available should return the fallback
+ # key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # we shouldn't have any unused fallback keys again
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ # claiming an OTK again should return the same fallback key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # if the user uploads a one-time key, the next claim should fetch the
+ # one-time key, and then go back to the fallback
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": otk}
+ )
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ @defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 7adde9b9de..45f201a399 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -54,7 +54,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, handlers=None, replication_layer=mock.Mock()
+ self.addCleanup, replication_layer=mock.Mock()
)
self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
self.local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 96fea58673..0b24b89a2e 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -37,8 +37,8 @@ class FederationTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(http_client=None)
- self.handler = hs.get_handlers().federation_handler
+ hs = self.setup_test_homeserver(federation_http_client=None)
+ self.handler = hs.get_federation_handler()
self.store = hs.get_datastore()
return hs
@@ -59,7 +59,6 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
d = self.handler.on_exchange_third_party_invite_request(
- room_id=room_id,
event_dict={
"type": EventTypes.Member,
"room_id": room_id,
@@ -127,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -179,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -199,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
- with LoggingContext(request="send_join"):
+ with LoggingContext("send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
new file mode 100644
index 0000000000..f955dfa490
--- /dev/null
+++ b/tests/handlers/test_message.py
@@ -0,0 +1,212 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Tuple
+
+from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import create_requester
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+class EventCreationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = self.hs.get_event_creation_handler()
+ self.persist_event_storage = self.hs.get_storage().persistence
+
+ self.user_id = self.register_user("tester", "foobar")
+ self.access_token = self.login("tester", "foobar")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ self.info = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(self.access_token,)
+ )
+ self.token_id = self.info.token_id
+
+ self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+
+ def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+ """Create a new event with the given transaction ID. All events produced
+ by this method will be considered duplicates.
+ """
+
+ # We create a new event with a random body, as otherwise we'll produce
+ # *exactly* the same event with the same hash, and so same event ID.
+ return self.get_success(
+ self.handler.create_event(
+ self.requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": self.room_id,
+ "sender": self.requester.user.to_string(),
+ "content": {"msgtype": "m.text", "body": random_string(5)},
+ },
+ txn_id=txn_id,
+ )
+ )
+
+ def test_duplicated_txn_id(self):
+ """Test that attempting to handle/persist an event with a transaction ID
+ that has already been persisted correctly returns the old event and does
+ *not* produce duplicate messages.
+ """
+
+ txn_id = "something_suitably_random"
+
+ event1, context = self._create_duplicate_event(txn_id)
+
+ ret_event1 = self.get_success(
+ self.handler.handle_new_client_event(self.requester, event1, context)
+ )
+ stream_id1 = ret_event1.internal_metadata.stream_ordering
+
+ self.assertEqual(event1.event_id, ret_event1.event_id)
+
+ event2, context = self._create_duplicate_event(txn_id)
+
+ # We want to test that the deduplication at the persit event end works,
+ # so we want to make sure we test with different events.
+ self.assertNotEqual(event1.event_id, event2.event_id)
+
+ ret_event2 = self.get_success(
+ self.handler.handle_new_client_event(self.requester, event2, context)
+ )
+ stream_id2 = ret_event2.internal_metadata.stream_ordering
+
+ # Assert that the returned values match those from the initial event
+ # rather than the new one.
+ self.assertEqual(ret_event1.event_id, ret_event2.event_id)
+ self.assertEqual(stream_id1, stream_id2)
+
+ # Let's test that calling `persist_event` directly also does the right
+ # thing.
+ event3, context = self._create_duplicate_event(txn_id)
+ self.assertNotEqual(event1.event_id, event3.event_id)
+
+ ret_event3, event_pos3, _ = self.get_success(
+ self.persist_event_storage.persist_event(event3, context)
+ )
+
+ # Assert that the returned values match those from the initial event
+ # rather than the new one.
+ self.assertEqual(ret_event1.event_id, ret_event3.event_id)
+ self.assertEqual(stream_id1, event_pos3.stream)
+
+ # Let's test that calling `persist_events` directly also does the right
+ # thing.
+ event4, context = self._create_duplicate_event(txn_id)
+ self.assertNotEqual(event1.event_id, event3.event_id)
+
+ events, _ = self.get_success(
+ self.persist_event_storage.persist_events([(event3, context)])
+ )
+ ret_event4 = events[0]
+
+ # Assert that the returned values match those from the initial event
+ # rather than the new one.
+ self.assertEqual(ret_event1.event_id, ret_event4.event_id)
+
+ def test_duplicated_txn_id_one_call(self):
+ """Test that we correctly handle duplicates that we try and persist at
+ the same time.
+ """
+
+ txn_id = "something_else_suitably_random"
+
+ # Create two duplicate events to persist at the same time
+ event1, context1 = self._create_duplicate_event(txn_id)
+ event2, context2 = self._create_duplicate_event(txn_id)
+
+ # Ensure their event IDs are different to start with
+ self.assertNotEqual(event1.event_id, event2.event_id)
+
+ events, _ = self.get_success(
+ self.persist_event_storage.persist_events(
+ [(event1, context1), (event2, context2)]
+ )
+ )
+
+ # Check that we've deduplicated the events.
+ self.assertEqual(len(events), 2)
+ self.assertEqual(events[0].event_id, events[1].event_id)
+
+
+class ServerAclValidationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("tester", "foobar")
+ self.access_token = self.login("tester", "foobar")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ def test_allow_server_acl(self):
+ """Test that sending an ACL that blocks everyone but ourselves works.
+ """
+
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.ServerACL,
+ body={"allow": [self.hs.hostname]},
+ tok=self.access_token,
+ expect_code=200,
+ )
+
+ def test_deny_server_acl_block_outselves(self):
+ """Test that sending an ACL that blocks ourselves does not work.
+ """
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.ServerACL,
+ body={},
+ tok=self.access_token,
+ expect_code=400,
+ )
+
+ def test_deny_redact_server_acl(self):
+ """Test that attempting to redact an ACL is blocked.
+ """
+
+ body = self.helper.send_state(
+ self.room_id,
+ EventTypes.ServerACL,
+ body={"allow": [self.hs.hostname]},
+ tok=self.access_token,
+ expect_code=200,
+ )
+ event_id = body["event_id"]
+
+ # Redaction of event should fail.
+ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
+ channel = self.make_request(
+ "POST", path, content={}, access_token=self.access_token
+ )
+ self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index d5087e58be..368d600b33 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,40 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
-from urllib.parse import parse_qs, urlparse
+import re
+from typing import Dict
+from urllib.parse import parse_qs, urlencode, urlparse
-from mock import Mock, patch
+from mock import ANY, Mock, patch
-import attr
import pymacaroons
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
+from twisted.web.resource import Resource
-from synapse.handlers.oidc_handler import (
- MappingException,
- OidcError,
- OidcHandler,
- OidcMappingProvider,
-)
+from synapse.api.errors import RedirectException
+from synapse.handlers.oidc_handler import OidcError
+from synapse.handlers.sso import MappingException
+from synapse.rest.client.v1 import login
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.server import HomeServer
from synapse.types import UserID
+from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
-
-@attr.s
-class FakeResponse:
- code = attr.ib()
- body = attr.ib()
- phrase = attr.ib()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
# These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
@@ -75,11 +63,14 @@ COOKIE_NAME = b"oidc_session"
COOKIE_PATH = "/_synapse/oidc"
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
@staticmethod
def parse_config(config):
return
+ def __init__(self, config):
+ pass
+
def get_remote_user_id(self, userinfo):
return userinfo["sub"]
@@ -94,14 +85,12 @@ class TestMappingProviderExtra(TestMappingProvider):
return {"phone": userinfo["phone"]}
-def simple_async_mock(return_value=None, raises=None):
- # AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
- if raises:
- raise raises
- return return_value
-
- return Mock(side_effect=cb)
+class TestMappingProviderFailures(TestMappingProvider):
+ async def map_user_attributes(self, userinfo, token, failures):
+ return {
+ "localpart": userinfo["username"] + (str(failures) if failures else ""),
+ "display_name": None,
+ }
async def get_json(url):
@@ -124,22 +113,16 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
-
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = "Synapse Test"
-
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = {}
- oidc_config["enabled"] = True
- oidc_config["client_id"] = CLIENT_ID
- oidc_config["client_secret"] = CLIENT_SECRET
- oidc_config["issuer"] = ISSUER
- oidc_config["scopes"] = SCOPES
- oidc_config["user_mapping_provider"] = {
- "module": __name__ + ".TestMappingProvider",
+ oidc_config = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
@@ -147,13 +130,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
- hs = self.setup_test_homeserver(
- http_client=self.http_client,
- proxied_http_client=self.http_client,
- config=config,
- )
+ return config
- self.handler = OidcHandler(hs)
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+
+ self.handler = hs.get_oidc_handler()
+ sso_handler = hs.get_sso_handler()
+ # Mock the render error method.
+ self.render_error = Mock(return_value=None)
+ sso_handler.render_error = self.render_error
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler._MAP_USERNAME_RETRIES = 3
return hs
@@ -161,12 +155,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
- args = self.handler._render_error.call_args[0]
+ args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
- self.handler._render_error.reset_mock()
+ self.render_error.reset_mock()
+ return args
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@@ -286,9 +281,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
h._validate_metadata,
)
- # Tests for configs that the userinfo endpoint
+ # Tests for configs that require the userinfo endpoint
self.assertFalse(h._uses_userinfo)
- h._scopes = [] # do not request the openid scope
+ self.assertEqual(h._user_profile_method, "auto")
+ h._user_profile_method = "userinfo_endpoint"
+ self.assertTrue(h._uses_userinfo)
+
+ # Revert the profile method and do not request the "openid" scope.
+ h._user_profile_method = "auto"
+ h._scopes = []
self.assertTrue(h._uses_userinfo)
self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
@@ -350,7 +351,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
- self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request))
@@ -371,25 +371,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
- when the userinfo fetching fails
- when the code exchange fails
"""
+
+ # ensure that we are correctly testing the fallback when "get_extra_attributes"
+ # is not implemented.
+ mapping_provider = self.handler._user_mapping_provider
+ with self.assertRaises(AttributeError):
+ _ = mapping_provider.get_extra_attributes
+
token = {
"type": "bearer",
"id_token": "id_token",
"access_token": "access_token",
}
+ username = "bar"
userinfo = {
"sub": "foo",
- "preferred_username": "bar",
+ "username": username,
}
- user_id = "@foo:domain.org"
- self.handler._render_error = Mock(return_value=None)
+ expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
- )
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
code = "code"
state = "state"
@@ -397,67 +401,56 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
-
- request.args = {}
- request.args[b"code"] = [code.encode("utf-8")]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.requestHeaders = Mock(spec=["getRawHeaders"])
- request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
- request.getClientIP.return_value = ip_address
+ request = _build_callback_request(
+ code, state, session, user_agent=user_agent, ip_address=ip_address
+ )
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, None,
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
self.handler._fetch_userinfo.assert_not_called()
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle mapping errors
- self.handler._map_userinfo_to_user = simple_async_mock(
- raises=MappingException()
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("mapping_error")
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ with patch.object(
+ self.handler,
+ "_remote_id_from_userinfo",
+ new=Mock(side_effect=MappingException()),
+ ):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- self.handler._auth_handler.complete_sso_login.reset_mock()
+ auth_handler.complete_sso_login.reset_mock()
self.handler._exchange_code.reset_mock()
self.handler._parse_id_token.reset_mock()
- self.handler._map_userinfo_to_user.reset_mock()
self.handler._fetch_userinfo.reset_mock()
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, None,
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
self.handler._fetch_userinfo.assert_called_once_with(token)
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@@ -473,7 +466,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
- self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie
@@ -607,66 +599,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
userinfo = {
"sub": "foo",
+ "username": "foo",
"phone": "1234567",
}
- user_id = "@foo:domain.org"
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
- )
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
state = "state"
client_redirect_url = "http://client/redirect"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
-
- request.args = {}
- request.args[b"code"] = [b"code"]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.requestHeaders = Mock(spec=["getRawHeaders"])
- request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
- request.getClientIP.return_value = "10.0.0.1"
+ request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {"phone": "1234567"},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@foo:test", request, client_redirect_url, {"phone": "1234567"},
)
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
userinfo = {
"sub": "test_user",
"username": "test_user",
}
- # The token doesn't matter with the default user mapping provider.
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", ANY, ANY, None,
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user_2:test", ANY, ANY, None,
)
- self.assertEqual(mxid, "@test_user_2:test")
+ auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken
store = self.hs.get_datastore()
@@ -675,30 +656,352 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "Mapping provider does not support de-duplicating Matrix IDs",
)
- self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
@override_config({"oidc_config": {"allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
- user4 = UserID.from_string("@test_user_4:test")
+ user = UserID.from_string("@test_user:test")
+ self.get_success(
+ store.register_user(user_id=user.to_string(), password_hash=None)
+ )
+
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # Map a user via SSO.
+ userinfo = {
+ "sub": "test",
+ "username": "test_user",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None,
+ )
+ auth_handler.complete_sso_login.reset_mock()
+
+ # Subsequent calls should map to the same mxid.
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None,
+ )
+ auth_handler.complete_sso_login.reset_mock()
+
+ # Note that a second SSO user can be mapped to the same Matrix ID. (This
+ # requires a unique sub, but something that maps to the same matrix ID,
+ # in this case we'll just use the same username. A more realistic example
+ # would be subs which are email addresses, and mapping from the localpart
+ # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
+ userinfo = {
+ "sub": "test1",
+ "username": "test_user",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None,
+ )
+ auth_handler.complete_sso_login.reset_mock()
+
+ # Register some non-exact matching cases.
+ user2 = UserID.from_string("@TEST_user_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+ user2_caps = UserID.from_string("@test_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2_caps.to_string(), password_hash=None)
+ )
+
+ # Attempting to login without matching a name exactly is an error.
+ userinfo = {
+ "sub": "test2",
+ "username": "TEST_USER_2",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ args = self.assertRenderedError("mapping_error")
+ self.assertTrue(
+ args[2].startswith(
+ "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
+ )
+ )
+
+ # Logging in when matching a name exactly should work.
+ user2 = UserID.from_string("@TEST_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@TEST_USER_2:test", ANY, ANY, None,
+ )
+
+ def test_map_userinfo_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+ self.get_success(
+ _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
+ )
+ self.assertRenderedError("mapping_error", "localpart is invalid: föö")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderFailures"
+ }
+ }
+ }
+ )
+ def test_map_userinfo_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ store = self.hs.get_datastore()
self.get_success(
- store.register_user(user_id=user4.to_string(), password_hash=None)
+ store.register_user(user_id="@test_user:test", password_hash=None)
)
userinfo = {
- "sub": "test4",
- "username": "test_user_4",
+ "sub": "test",
+ "username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # test_user is already taken, so test_user1 gets registered instead.
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", ANY, ANY, None,
+ )
+ auth_handler.complete_sso_login.reset_mock()
+
+ # Register all of the potential mxids for a particular OIDC username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
)
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error", "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ def test_empty_localpart(self):
+ """Attempts to map onto an empty localpart should be rejected."""
+ userinfo = {
+ "sub": "tester",
+ "username": "",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.username }}"}
+ }
+ }
+ }
+ )
+ def test_null_localpart(self):
+ """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+ userinfo = {
+ "sub": "tester",
+ "username": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+
+class UsernamePickerTestCase(HomeserverTestCase):
+ servlets = [login.register_servlets]
+
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {
+ "config": {"display_name_template": "{{ user.displayname }}"}
+ },
+ }
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ oidc_config.update(config.get("oidc_config", {}))
+ config["oidc_config"] = oidc_config
+
+ # whitelist this client URI so we redirect straight to it rather than
+ # serving a confirmation page
+ config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
+ return d
+
+ def test_username_picker(self):
+ """Test the happy path of a username picker flow."""
+ client_redirect_url = "https://whitelisted.client"
+
+ # first of all, mock up an OIDC callback to the OidcHandler, which should
+ # raise a RedirectException
+ userinfo = {"sub": "tester", "displayname": "Jonny"}
+ f = self.get_failure(
+ _make_callback_with_userinfo(
+ self.hs, userinfo, client_redirect_url=client_redirect_url
+ ),
+ RedirectException,
+ )
+
+ # check the Location and cookies returned by the RedirectException
+ self.assertEqual(f.value.location, b"/_synapse/client/pick_username")
+ cookieheader = f.value.cookies[0]
+ regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
+ m = regex.search(cookieheader)
+ if not m:
+ self.fail("cookie header %s does not match %s" % (cookieheader, regex))
+
+ # introspect the sso handler a bit to check that the username mapping session
+ # looks ok.
+ session_id = m.group(1).decode("ascii")
+ username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+ self.assertIn(
+ session_id, username_mapping_sessions, "session id not found in map"
)
- self.assertEqual(mxid, "@test_user_4:test")
+ session = username_mapping_sessions[session_id]
+ self.assertEqual(session.remote_user_id, "tester")
+ self.assertEqual(session.display_name, "Jonny")
+ self.assertEqual(session.client_redirect_url, client_redirect_url)
+
+ # the expiry time should be about 15 minutes away
+ expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+ self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+ # Now, submit a username to the username picker, which should serve a redirect
+ # back to the client
+ submit_path = f.value.location + b"/submit"
+ content = urlencode({b"username": b"bobby"}).encode("utf8")
+ chan = self.make_request(
+ "POST",
+ path=submit_path,
+ content=content,
+ content_is_form=True,
+ custom_headers=[
+ ("Cookie", cookieheader),
+ # old versions of twisted don't do form-parsing without a valid
+ # content-length header.
+ ("Content-Length", str(len(content))),
+ ],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ # ensure that the returned location starts with the requested redirect URL
+ self.assertEqual(
+ location_headers[0][: len(client_redirect_url)], client_redirect_url
+ )
+
+ # fish the login token out of the returned redirect uri
+ parts = urlparse(location_headers[0])
+ query = parse_qs(parts.query)
+ login_token = query["loginToken"][0]
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ chan = self.make_request(
+ "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+
+
+async def _make_callback_with_userinfo(
+ hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+ """Mock up an OIDC callback with the given userinfo dict
+
+ We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+ and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+ Args:
+ hs: the HomeServer impl to send the callback to.
+ userinfo: the OIDC userinfo dict
+ client_redirect_url: the URL to redirect to on success.
+ """
+ handler = hs.get_oidc_handler()
+ handler._exchange_code = simple_async_mock(return_value={})
+ handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+ state = "state"
+ session = handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+ request = _build_callback_request("code", state, session)
+
+ await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+ code: str,
+ state: str,
+ session: str,
+ user_agent: str = "Browser",
+ ip_address: str = "10.0.0.1",
+):
+ """Builds a fake SynapseRequest to mock the browser callback
+
+ Returns a Mock object which looks like the SynapseRequest we get from a browser
+ after SSO (before we return to the client)
+
+ Args:
+ code: the authorization code which would have been returned by the OIDC
+ provider
+ state: the "state" param which would have been passed around in the
+ query param. Should be the same as was embedded in the session in
+ _build_oidc_session.
+ session: the "session" which would have been passed around in the cookie.
+ user_agent: the user-agent to present
+ ip_address: the IP address to pretend the request came from
+ """
+ request = Mock(
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
+ )
+
+ request.getCookie.return_value = session
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+ request.getClientIP.return_value = ip_address
+ request.get_user_agent.return_value = user_agent
+ return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
new file mode 100644
index 0000000000..f816594ee4
--- /dev/null
+++ b/tests/handlers/test_password_providers.py
@@ -0,0 +1,603 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the password_auth_provider interface"""
+
+from typing import Any, Type, Union
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import devices
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel
+from tests.unittest import override_config
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
+# a mock instance which the dummy auth providers delegate to, so we can see what's going
+# on
+mock_password_provider = Mock()
+
+
+class PasswordOnlyAuthProvider:
+ """A password_provider which only implements `check_password`."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def check_password(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+class CustomAuthProvider:
+ """A password_provider which implements a custom login type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+class PasswordCustomAuthProvider:
+ """A password_provider which implements password login via `check_auth`, as well
+ as a custom type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given password auth providers"""
+ return {
+ "password_providers": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
+class PasswordAuthProviderTests(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def setUp(self):
+ # we use a global mock device, so make sure we are starting with a clean slate
+ mock_password_provider.reset_mock()
+ super().setUp()
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_login(self):
+ # login flows should only have m.login.password
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # login with mxid should work too
+ channel = self._send_password_login("@u:bz", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:bz", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
+ mock_password_provider.reset_mock()
+
+ # try a weird username / pass. Honestly it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with(
+ "@ USER🙂NAME :test", " pASS😢word "
+ )
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth(self):
+ """UI Auth should delegate correctly to the password provider"""
+
+ # create the user, otherwise access doesn't work
+ module_api = self.hs.get_module_api()
+ self.get_success(module_api.register_user("u"))
+
+ # log in twice, to get two devices
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ tok1 = self.login("u", "p")
+ self.login("u", "p", device_id="dev2")
+ mock_password_provider.reset_mock()
+
+ # have the auth provider deny the request to start with
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # make the initial request which returns a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Make another request providing the UI auth flow.
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # Finally, check the request goes through when we allow it
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_login(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 403, channel.result)
+
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@localuser:test", channel.json_body["user_id"])
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # have the auth provider deny the request
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Wrong password
+ channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "xxx"
+ )
+ mock_password_provider.reset_mock()
+
+ # Right password
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_login(self):
+ """localdb_enabled can block login with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_ui_auth(self):
+ """localdb_enabled can block ui auth with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # allow login via the auth provider
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "p")
+ self.login("localuser", "p", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # m.login.password UIA is permitted because the auth provider allows it,
+ # even though the localdb does not.
+ self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
+ session = channel.json_body["session"]
+ mock_password_provider.check_password.assert_not_called()
+
+ # now try deleting with the local password
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_auth_disabled(self):
+ """password auth doesn't work if it's disabled across the board"""
+ # login flows should be empty
+ flows = self._get_login_flows()
+ self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_login(self):
+ # login flows should have the custom flow and m.login.password, since we
+ # haven't disabled local password lookup.
+ # (password must come first, because reasons)
+ flows = self._get_login_flows()
+ self.assertEqual(
+ flows,
+ [{"type": "m.login.password"}, {"type": "test.login_type"}]
+ + ADDITIONAL_LOGIN_FLOWS,
+ )
+
+ # login with missing param should be rejected
+ channel = self._send_login("test.login_type", "u")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+ mock_password_provider.reset_mock()
+
+ # try a weird username. Again, it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@ MALFORMED! :bz"
+ )
+ channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_ui_auth(self):
+ # register the user and log in twice, to get two devices
+ self.register_user("localuser", "localpass")
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ # missing param
+ body = {
+ "auth": {
+ "type": "test.login_type",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
+ # use it...
+ self.assertIn("Missing parameters", channel.json_body["error"])
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # right params, but authing as the wrong user
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ body["auth"]["test_field"] = "foo"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+ mock_password_provider.reset_mock()
+
+ # and finally, succeed
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_callback(self):
+ callback = Mock(return_value=defer.succeed(None))
+
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", callback)
+ )
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+
+ # check the args to the callback
+ callback.assert_called_once()
+ call_args, call_kwargs = callback.call_args
+ # should be one positional arg
+ self.assertEqual(len(call_args), 1)
+ self.assertEqual(call_args[0]["user_id"], "@user:bz")
+ for p in ["user_id", "access_token", "device_id", "home_server"]:
+ self.assertIn(p, call_args[0])
+
+ @override_config(
+ {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
+ )
+ def test_custom_auth_password_disabled(self):
+ """Test login with a custom auth provider where password login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(CustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled(self):
+ """Check the localdb_enabled == enabled == False
+
+ Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+ that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+ cause an exception.
+ """
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_login(self):
+ """log in with a custom auth provider which implements password, but password
+ login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_ui_auth(self):
+ """UI Auth with a custom auth provider which implements password, but password
+ login is disabled"""
+ # register the user and log in twice via the test login type to get two devices,
+ self.register_user("localuser", "localpass")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._send_login("test.login_type", "localuser", test_field="")
+ self.assertEqual(channel.code, 200, channel.result)
+ tok1 = channel.json_body["access_token"]
+
+ channel = self._send_login(
+ "test.login_type", "localuser", test_field="", device_id="dev2"
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected. In particular, "password" should *not*
+ # be present.
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ mock_password_provider.reset_mock()
+
+ # check that auth with password is rejected
+ body = {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "password": "localpass",
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(
+ "Password login has been disabled.", channel.json_body["error"]
+ )
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # successful auth
+ body["auth"]["type"] = "test.login_type"
+ body["auth"]["test_field"] = "x"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "x"}
+ )
+
+ @override_config(
+ {
+ **providers_config(CustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback(self):
+ """Test login with a custom auth provider where the local db is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # password login shouldn't work and should be rejected with a 400
+ # ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def _get_login_flows(self) -> JsonDict:
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["flows"]
+
+ def _send_password_login(self, user: str, password: str) -> FakeChannel:
+ return self._send_login(type="m.login.password", user=user, password=password)
+
+ def _send_login(self, type, user, **params) -> FakeChannel:
+ params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
+ channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ return channel
+
+ def _start_delete_device_session(self, access_token, device_id) -> str:
+ """Make an initial delete device request, and return the UI Auth session ID"""
+ channel = self._delete_device(access_token, device_id)
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ return channel.json_body["session"]
+
+ def _authed_delete_device(
+ self,
+ access_token: str,
+ device_id: str,
+ session: str,
+ user_id: str,
+ password: str,
+ ) -> FakeChannel:
+ """Make a delete device request, authenticating with the given uid/password"""
+ return self._delete_device(
+ access_token,
+ device_id,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": user_id},
+ "password": password,
+ "session": session,
+ },
+ },
+ )
+
+ def _delete_device(
+ self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+ ) -> FakeChannel:
+ """Delete an individual device."""
+ channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=access_token
+ )
+ return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 306dcfe944..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,14 +463,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "server", http_client=None, federation_sender=Mock()
+ "server", federation_http_client=None, federation_sender=Mock()
)
return hs
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
- self.federation_handler = hs.get_handlers().federation_handler
+ self.federation_handler = hs.get_federation_handler()
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
@@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(builder.build(prev_event_ids))
+ event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 8e95e53d9e..919547556b 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,6 @@ from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError, SynapseError
-from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
@@ -28,11 +27,6 @@ from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class ProfileHandlers:
- def __init__(self, hs):
- self.profile_handler = MasterProfileHandler(hs)
-
-
class ProfileTestCase(unittest.TestCase):
""" Tests profile management. """
@@ -50,9 +44,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
- http_client=None,
- handlers=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cb7c0ed51a..bdf3d0a8a2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,6 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
-from synapse.handlers.register import RegistrationHandler
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
@@ -29,11 +28,6 @@ from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers:
- def __init__(self, hs):
- self.registration_handler = RegistrationHandler(hs)
-
-
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
@@ -154,7 +148,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -193,7 +187,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
@@ -205,7 +199,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -237,7 +231,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -266,7 +260,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -304,7 +298,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -347,7 +341,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -384,7 +378,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -413,7 +407,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- event_creation_handler.send_nonmember_event(requester, event, context)
+ event_creation_handler.handle_new_client_event(requester, event, context)
)
# Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
new file mode 100644
index 0000000000..548038214b
--- /dev/null
+++ b/tests/handlers/test_saml.py
@@ -0,0 +1,265 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from mock import Mock
+
+import attr
+
+from synapse.api.errors import RedirectException
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase, override_config
+
+# Check if we have the dependencies to run the tests.
+try:
+ import saml2.config
+ from saml2.sigver import SigverError
+
+ has_saml2 = True
+
+ # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
+ config = saml2.config.SPConfig()
+ try:
+ config.load({"metadata": {}})
+ has_xmlsec1 = True
+ except SigverError:
+ has_xmlsec1 = False
+except ImportError:
+ has_saml2 = False
+ has_xmlsec1 = False
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+
+
+@attr.s
+class FakeAuthnResponse:
+ ava = attr.ib(type=dict)
+ assertions = attr.ib(type=list, factory=list)
+ in_response_to = attr.ib(type=Optional[str], default=None)
+
+
+class TestMappingProvider:
+ def __init__(self, config, module):
+ pass
+
+ @staticmethod
+ def parse_config(config):
+ return
+
+ @staticmethod
+ def get_saml_attributes(config):
+ return {"uid"}, {"displayName"}
+
+ def get_remote_user_id(self, saml_response, client_redirect_url):
+ return saml_response.ava["uid"]
+
+ def saml_response_to_user_attributes(
+ self, saml_response, failures, client_redirect_url
+ ):
+ localpart = saml_response.ava["username"] + (str(failures) if failures else "")
+ return {"mxid_localpart": localpart, "displayname": None}
+
+
+class TestRedirectMappingProvider(TestMappingProvider):
+ def saml_response_to_user_attributes(
+ self, saml_response, failures, client_redirect_url
+ ):
+ raise RedirectException(b"https://custom-saml-redirect/")
+
+
+class SamlHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ saml_config = {
+ "sp_config": {"metadata": {}},
+ # Disable grandfathering.
+ "grandfathered_mxid_source_attribute": None,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+ }
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ saml_config.update(config.get("saml2_config", {}))
+ config["saml2_config"] = saml_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_saml_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ if not has_saml2:
+ skip = "Requires pysaml2"
+ elif not has_xmlsec1:
+ skip = "Requires xmlsec1"
+
+ def test_map_saml_response_to_user(self):
+ """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # send a mocked-up SAML response to the callback
+ saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None
+ )
+
+ @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
+ def test_map_saml_response_to_existing_user(self):
+ """Existing users can log in with SAML account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # Map a user via SSO.
+ saml_response = FakeAuthnResponse(
+ {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
+ )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "", None
+ )
+
+ # Subsequent calls should map to the same mxid.
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "", None
+ )
+
+ def test_map_saml_response_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # mock out the error renderer too
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+ sso_handler.render_error.assert_called_once_with(
+ request, "mapping_error", "localpart is invalid: föö"
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ def test_map_saml_response_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+ # stub out the auth handler and error renderer
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
+ # register a user to occupy the first-choice MXID
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # send the fake SAML response
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+
+ # test_user is already taken, so test_user1 gets registered instead.
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", request, "", None
+ )
+ auth_handler.complete_sso_login.reset_mock()
+
+ # Register all of the potential mxids for a particular SAML username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+ sso_handler.render_error.assert_called_once_with(
+ request,
+ "mapping_error",
+ "Unable to generate a Matrix ID from the SSO response",
+ )
+ auth_handler.complete_sso_login.assert_not_called()
+
+ @override_config(
+ {
+ "saml2_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestRedirectMappingProvider"
+ },
+ }
+ }
+ )
+ def test_map_saml_response_redirect(self):
+ """Test a mapping provider that raises a RedirectException"""
+
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
+ request = _mock_request()
+ e = self.get_failure(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ RedirectException,
+ )
+ self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "get_user_agent"])
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e178d7765b..e62586142e 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -16,7 +16,7 @@
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
@@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
+ requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.auth_blocking._limit_usage_by_mau = True
@@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
- self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
+ self.get_success(
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config)
+ )
# Test that global lock works
self.auth_blocking._hs_disabled = True
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
+ requester = create_requester(user_id2)
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 3fec09ea8a..96e5bdac4a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
-from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -65,40 +65,23 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
- datastores = Mock()
- datastores.main = Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_last_successful_stream_ordering",
- "get_destination_retry_timings",
- "get_devices_by_remote",
- "maybe_store_room_on_invite",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- "get_device_updates_by_remote",
- "get_room_max_stream_ordering",
- ]
- )
-
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
hs = self.setup_test_homeserver(
notifier=Mock(),
- http_client=mock_federation_client,
+ federation_http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
- hs.datastores = datastores
-
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -114,16 +97,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"retry_interval": 0,
"failure_ts": None,
}
- self.datastore.get_destination_retry_timings.return_value = defer.succeed(
- retry_timings_res
+ self.datastore.get_destination_retry_timings = Mock(
+ return_value=defer.succeed(retry_timings_res)
)
- self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
- (0, [])
+ self.datastore.get_device_updates_by_remote = Mock(
+ return_value=make_awaitable((0, []))
)
- self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
- None
+ self.datastore.get_destination_last_successful_stream_ordering = Mock(
+ return_value=make_awaitable(None)
)
def get_received_txn_response(*args):
@@ -145,17 +128,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- def get_users_in_room(room_id):
- return defer.succeed({str(u) for u in self.room_members})
+ async def get_users_in_room(room_id):
+ return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.side_effect = (
- # we deliberately return a non-None stream pos to avoid doing an initial_spam
- lambda: make_awaitable(1)
+ self.datastore.get_user_directory_stream_pos = Mock(
+ side_effect=(
+ # we deliberately return a non-None stream pos to avoid doing an initial_spam
+ lambda: make_awaitable(1)
+ )
)
- self.datastore.get_current_state_deltas.return_value = (0, None)
+ self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
@@ -212,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
@@ -235,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- (request, channel) = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
@@ -248,7 +233,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
),
federation_auth_origin=b"farm",
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
@@ -291,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 87be94111f..9c886d671a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
+ regular_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=regular_user_id, password_hash=None)
+ )
self.get_success(
self.handler.handle_local_profile_change(support_user_id, None)
@@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
- regular_user_id = "@regular:test"
self.get_success(
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
self.assertTrue(profile["display_name"] == display_name)
+ def test_handle_local_profile_change_with_deactivated_user(self):
+ # create user
+ r_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=r_user_id, password_hash=None)
+ )
+
+ # update profile
+ display_name = "Regular User"
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile["display_name"] == display_name)
+
+ # deactivate user
+ self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
+ self.get_success(self.handler.handle_user_deactivated(r_user_id))
+
+ # profile is not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
+ # update profile after deactivation
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is furthermore not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
def test_handle_user_deactivated_support_user(self):
s_user_id = "@support:test"
self.get_success(
@@ -270,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker()
class AllowAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -283,7 +321,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users.
class BlockAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
@@ -534,18 +572,16 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.helper.join(room, user=u2)
# Assert user directory is not empty
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) > 0)
# Disable user directory and check search returns nothing
self.config.user_directory_search_enabled = False
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..4e51839d0f 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,6 +17,7 @@ import logging
from mock import Mock
import treq
+from netaddr import IPSet
from service_identity import VerificationError
from zope.interface import implementer
@@ -35,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server
from synapse.http.federation.well_known_resolver import (
+ WELL_KNOWN_MAX_SIZE,
WellKnownResolver,
_cache_period_from_headers,
)
@@ -103,6 +105,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
+ ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@@ -736,6 +739,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
+ ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
self.reactor,
@@ -1104,6 +1108,32 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
+ def test_well_known_too_large(self):
+ """A well-known query that returns a result which is too large should be rejected."""
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b"Cache-Control": b"max-age=1000"},
+ content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
+ )
+
+ # The result is sucessful, but disabled delegation.
+ r = self.successResultOf(fetch_d)
+ self.assertIsNone(r.delegated_server)
+
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 62d36c2906..453391a5a5 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
+from tests.server import FakeSite, make_request
from tests.unittest import HomeserverTestCase
@@ -43,20 +44,18 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
- self.resource = AdditionalResource(self.hs, handler)
+ resource = AdditionalResource(self.hs, handler)
- request, channel = self.make_request("GET", "/")
- self.render(request)
+ channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
- self.resource = AdditionalResource(self.hs, handler)
+ resource = AdditionalResource(self.hs, handler)
- request, channel = self.make_request("GET", "/")
- self.render(request)
+ channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22abf76515..9a56e1c14a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -15,12 +15,14 @@
import logging
import treq
+from netaddr import IPSet
from twisted.internet import interfaces # noqa: F401
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel
+from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.proxyagent import ProxyAgent
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
@@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ def test_http_request_via_proxy_with_blacklist(self):
+ # The blacklist includes the configured proxy IP.
+ agent = ProxyAgent(
+ BlacklistingReactorWrapper(
+ self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ ),
+ self.reactor,
+ http_proxy=b"proxy.com:8888",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy_with_blacklist(self):
+ # The blacklist includes the configured proxy IP.
+ agent = ProxyAgent(
+ BlacklistingReactorWrapper(
+ self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ ),
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
def _wrap_server_factory_for_tls(factory, sanlist=None):
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index e69de29bb2..a58d51441c 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+
+class LoggerCleanupMixin:
+ def get_logger(self, handler):
+ """
+ Attach a handler to a logger and add clean-ups to remove revert this.
+ """
+ # Create a logger and add the handler to it.
+ logger = logging.getLogger(__name__)
+ logger.addHandler(handler)
+
+ # Ensure the logger actually logs something.
+ logger.setLevel(logging.INFO)
+
+ # Ensure the logger gets cleaned-up appropriately.
+ self.addCleanup(logger.removeHandler, handler)
+ self.addCleanup(logger.setLevel, logging.NOTSET)
+
+ return logger
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
new file mode 100644
index 0000000000..4bc27a1d7d
--- /dev/null
+++ b/tests/logging/test_remote_handler.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import AccumulatingProtocol
+
+from synapse.logging import RemoteHandler
+
+from tests.logging import LoggerCleanupMixin
+from tests.server import FakeTransport, get_clock
+from tests.unittest import TestCase
+
+
+def connect_logging_client(reactor, client_id):
+ # This is essentially tests.server.connect_client, but disabling autoflush on
+ # the client transport. This is necessary to avoid an infinite loop due to
+ # sending of data via the logging transport causing additional logs to be
+ # written.
+ factory = reactor.tcpClients.pop(client_id)[2]
+ client = factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, reactor))
+ client.makeConnection(FakeTransport(server, reactor, autoflush=False))
+
+ return client, server
+
+
+class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
+ def setUp(self):
+ self.reactor, _ = get_clock()
+
+ def test_log_output(self):
+ """
+ The remote handler delivers logs over TCP.
+ """
+ handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor)
+ logger = self.get_logger(handler)
+
+ logger.info("Hello there, %s!", "wally")
+
+ # Trigger the connection
+ client, server = connect_logging_client(self.reactor, 0)
+
+ # Trigger data being sent
+ client.transport.flush()
+
+ # One log message, with a single trailing newline
+ logs = server.data.decode("utf8").splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(server.data.count(b"\n"), 1)
+
+ # Ensure the data passed through properly.
+ self.assertEqual(logs[0], "Hello there, wally!")
+
+ def test_log_backpressure_debug(self):
+ """
+ When backpressure is hit, DEBUG logs will be shed.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 7):
+ logger.info("info %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # Only the 7 infos made it through, the debugs were elided
+ logs = server.data.splitlines()
+ self.assertEqual(len(logs), 7)
+ self.assertNotIn(b"debug", server.data)
+
+ def test_log_backpressure_info(self):
+ """
+ When backpressure is hit, DEBUG and INFO logs will be shed.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 10):
+ logger.warning("warn %s" % (i,))
+
+ # Send a bunch of info messages
+ for i in range(0, 3):
+ logger.info("info %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # The 10 warnings made it through, the debugs and infos were elided
+ logs = server.data.splitlines()
+ self.assertEqual(len(logs), 10)
+ self.assertNotIn(b"debug", server.data)
+ self.assertNotIn(b"info", server.data)
+
+ def test_log_backpressure_cut_middle(self):
+ """
+ When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
+ it will cut the middle messages out.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send a bunch of useful messages
+ for i in range(0, 20):
+ logger.warning("warn %s" % (i,))
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # The first five and last five warnings made it through, the debugs and
+ # infos were elided
+ logs = server.data.decode("utf8").splitlines()
+ self.assertEqual(
+ ["warn %s" % (i,) for i in range(5)]
+ + ["warn %s" % (i,) for i in range(15, 20)],
+ logs,
+ )
+
+ def test_cancel_connection(self):
+ """
+ Gracefully handle the connection being cancelled.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send a message.
+ logger.info("Hello there, %s!", "wally")
+
+ # Do not accept the connection and shutdown. This causes the pending
+ # connection to be cancelled (and should not raise any exceptions).
+ handler.close()
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
deleted file mode 100644
index d36f5f426c..0000000000
--- a/tests/logging/test_structured.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import os.path
-import shutil
-import sys
-import textwrap
-
-from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile
-
-from synapse.config.logger import setup_logging
-from synapse.logging._structured import setup_structured_logging
-from synapse.logging.context import LoggingContext
-
-from tests.unittest import DEBUG, HomeserverTestCase
-
-
-class FakeBeginner:
- def beginLoggingTo(self, observers, **kwargs):
- self.observers = observers
-
-
-class StructuredLoggingTestBase:
- """
- Test base that registers a cleanup handler to reset the stdlib log handler
- to 'unset'.
- """
-
- def prepare(self, reactor, clock, hs):
- def _cleanup():
- logging.getLogger("synapse").setLevel(logging.NOTSET)
-
- self.addCleanup(_cleanup)
-
-
-class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase):
- """
- Tests for Synapse's structured logging support.
- """
-
- def test_output_to_json_round_trip(self):
- """
- Synapse logs can be outputted to JSON and then read back again.
- """
- temp_dir = self.mktemp()
- os.mkdir(temp_dir)
- self.addCleanup(shutil.rmtree, temp_dir)
-
- json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json"))
-
- log_config = {
- "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}}
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Read the log file and check it has the event we sent
- with open(json_log_file, "r") as f:
- logged_events = list(eventsFromJSONLogFile(f))
- self.assertEqual(len(logged_events), 1)
-
- # The event pulled from the file should render fine
- self.assertEqual(
- eventAsText(logged_events[0], includeTimestamp=False),
- "[tests.logging.test_structured#info] Hello there, wally!",
- )
-
- def test_output_to_text(self):
- """
- Synapse logs can be outputted to text.
- """
- temp_dir = self.mktemp()
- os.mkdir(temp_dir)
- self.addCleanup(shutil.rmtree, temp_dir)
-
- log_file = os.path.abspath(os.path.join(temp_dir, "out.log"))
-
- log_config = {"drains": {"file": {"type": "file", "location": log_file}}}
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Read the log file and check it has the event we sent
- with open(log_file, "r") as f:
- logged_events = f.read().strip().split("\n")
- self.assertEqual(len(logged_events), 1)
-
- # The event pulled from the file should render fine
- self.assertTrue(
- logged_events[0].endswith(
- " - tests.logging.test_structured - INFO - None - Hello there, wally!"
- )
- )
-
- def test_collects_logcontext(self):
- """
- Test that log outputs have the attached logging context.
- """
- log_config = {"drains": {}}
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- publisher = setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- logs = []
-
- publisher.addObserver(logs.append)
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
-
- with LoggingContext("testcontext", request="somereq"):
- logger.info("Hello there, {name}!", name="steve")
-
- self.assertEqual(len(logs), 1)
- self.assertEqual(logs[0]["request"], "somereq")
-
-
-class StructuredLoggingConfigurationFileTestCase(
- StructuredLoggingTestBase, HomeserverTestCase
-):
- def make_homeserver(self, reactor, clock):
-
- tempdir = self.mktemp()
- os.mkdir(tempdir)
- log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml"))
- self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log"))
-
- config = self.default_config()
- config["log_config"] = log_config_file
-
- with open(log_config_file, "w") as f:
- f.write(
- textwrap.dedent(
- """\
- structured: true
-
- drains:
- file:
- type: file_json
- location: %s
- """
- % (self.homeserver_log,)
- )
- )
-
- self.addCleanup(self._sys_cleanup)
-
- return self.setup_test_homeserver(config=config)
-
- def _sys_cleanup(self):
- sys.stdout = sys.__stdout__
- sys.stderr = sys.__stderr__
-
- # Do not remove! We need the logging system to be set other than WARNING.
- @DEBUG
- def test_log_output(self):
- """
- When a structured logging config is given, Synapse will use it.
- """
- beginner = FakeBeginner()
- publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner)
-
- # Make a logger and send an event
- logger = Logger(namespace="tests.logging.test_structured", observer=publisher)
-
- with LoggingContext("testcontext", request="somereq"):
- logger.info("Hello there, {name}!", name="steve")
-
- with open(self.homeserver_log, "r") as f:
- logged_events = [
- eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f)
- ]
-
- logs = "\n".join(logged_events)
- self.assertTrue("***** STARTING SERVER *****" in logs)
- self.assertTrue("Hello there, steve!" in logs)
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 4cf81f7128..48a74e2eee 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -14,221 +14,124 @@
# limitations under the License.
import json
-from collections import Counter
+import logging
+from io import StringIO
-from twisted.logger import Logger
+from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+from synapse.logging.context import LoggingContext, LoggingContextFilter
-from synapse.logging._structured import setup_structured_logging
+from tests.logging import LoggerCleanupMixin
+from tests.unittest import TestCase
-from tests.server import connect_client
-from tests.unittest import HomeserverTestCase
-from .test_structured import FakeBeginner, StructuredLoggingTestBase
+class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+ def setUp(self):
+ self.output = StringIO()
+ def get_log_line(self):
+ # One log message, with a single trailing newline.
+ data = self.output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ return json.loads(logs[0])
-class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
- def test_log_output(self):
+ def test_terse_json_output(self):
"""
- The Terse JSON outputter delivers simplified structured logs over TCP.
+ The Terse JSON formatter converts log messages to JSON.
"""
- log_config = {
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- }
- }
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- logger = Logger(
- namespace="tests.logging.test_terse_json", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Trigger the connection
- self.pump()
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(TerseJsonFormatter())
+ logger = self.get_logger(handler)
- _, server = connect_client(self.reactor, 0)
+ logger.info("Hello there, %s!", "wally")
- # Trigger data being sent
- self.pump()
-
- # One log message, with a single trailing newline
- logs = server.data.decode("utf8").splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(server.data.count(b"\n"), 1)
-
- log = json.loads(logs[0])
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
"log",
"time",
"level",
- "log_namespace",
- "request",
- "scope",
- "server_name",
- "name",
+ "namespace",
]
- self.assertEqual(set(log.keys()), set(expected_log_keys))
-
- # It contains the data we expect.
- self.assertEqual(log["name"], "wally")
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
- def test_log_backpressure_debug(self):
+ def test_extra_data(self):
"""
- When backpressure is hit, DEBUG logs will be shed.
+ Additional information can be included in the structured logging.
"""
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(TerseJsonFormatter())
+ logger = self.get_logger(handler)
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
+ logger.info(
+ "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
- # Send some debug messages
- for i in range(0, 3):
- logger.debug("debug %s" % (i,))
-
- # Send a bunch of useful messages
- for i in range(0, 7):
- logger.info("test message %s" % (i,))
-
- # The last debug message pushes it past the maximum buffer
- logger.debug("too much debug")
+ log = self.get_log_line()
- # Allow the reconnection
- _, server = connect_client(self.reactor, 0)
- self.pump()
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "time",
+ "level",
+ "namespace",
+ # The additional keys given via extra.
+ "foo",
+ "int",
+ "bool",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
- # Only the 7 infos made it through, the debugs were elided
- logs = server.data.splitlines()
- self.assertEqual(len(logs), 7)
+ # Check the values of the extra fields.
+ self.assertEqual(log["foo"], "bar")
+ self.assertEqual(log["int"], 3)
+ self.assertIs(log["bool"], True)
- def test_log_backpressure_info(self):
+ def test_json_output(self):
"""
- When backpressure is hit, DEBUG and INFO logs will be shed.
+ The Terse JSON formatter converts log messages to JSON.
"""
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ logger = self.get_logger(handler)
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
- )
-
- # Send some debug messages
- for i in range(0, 3):
- logger.debug("debug %s" % (i,))
-
- # Send a bunch of useful messages
- for i in range(0, 10):
- logger.warn("test warn %s" % (i,))
-
- # Send a bunch of info messages
- for i in range(0, 3):
- logger.info("test message %s" % (i,))
-
- # The last debug message pushes it past the maximum buffer
- logger.debug("too much debug")
+ logger.info("Hello there, %s!", "wally")
- # Allow the reconnection
- client, server = connect_client(self.reactor, 0)
- self.pump()
+ log = self.get_log_line()
- # The 10 warnings made it through, the debugs and infos were elided
- logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
- self.assertEqual(len(logs), 10)
-
- self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
- def test_log_backpressure_cut_middle(self):
+ def test_with_context(self):
"""
- When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
- it will cut the middle messages out.
+ The logging context should be added to the JSON response.
"""
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter())
+ logger = self.get_logger(handler)
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
- )
+ with LoggingContext(request="test"):
+ logger.info("Hello there, %s!", "wally")
- # Send a bunch of useful messages
- for i in range(0, 20):
- logger.warn("test warn", num=i)
+ log = self.get_log_line()
- # Allow the reconnection
- client, server = connect_client(self.reactor, 0)
- self.pump()
-
- # The first five and last five warnings made it through, the debugs and
- # infos were elided
- logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
- self.assertEqual(len(logs), 10)
- self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
- self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs])
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ "request",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertEqual(log["request"], "test")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 04de0b9dbe..27206ca3db 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,16 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
-from synapse.module_api import ModuleApi
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import create_requester
from tests.unittest import HomeserverTestCase
class ModuleApiTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
- self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+ self.module_api = homeserver.get_module_api()
+ self.event_creation_handler = homeserver.get_event_creation_handler()
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -52,3 +63,142 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the displayname was assigned
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
+
+ def test_sending_events_into_room(self):
+ """Tests that a module can send events into a room"""
+ # Mock out create_and_send_nonmember_event to check whether events are being sent
+ self.event_creation_handler.create_and_send_nonmember_event = Mock(
+ spec=[],
+ side_effect=self.event_creation_handler.create_and_send_nonmember_event,
+ )
+
+ # Create a user and room to play with
+ user_id = self.register_user("summer", "monkey")
+ tok = self.login("summer", "monkey")
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ # Create and send a non-state event
+ content = {"body": "I am a puppet", "msgtype": "m.text"}
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.message",
+ "content": content,
+ "sender": user_id,
+ }
+ event = self.get_success(
+ self.module_api.create_and_send_event_into_room(event_dict)
+ ) # type: EventBase
+ self.assertEqual(event.sender, user_id)
+ self.assertEqual(event.type, "m.room.message")
+ self.assertEqual(event.room_id, room_id)
+ self.assertFalse(hasattr(event, "state_key"))
+ self.assertDictEqual(event.content, content)
+
+ expected_requester = create_requester(
+ user_id, authenticated_entity=self.hs.hostname
+ )
+
+ # Check that the event was sent
+ self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+ expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
+ )
+
+ # Create and send a state event
+ content = {
+ "events_default": 0,
+ "users": {user_id: 100},
+ "state_default": 50,
+ "users_default": 0,
+ "events": {"test.event.type": 25},
+ }
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.power_levels",
+ "content": content,
+ "sender": user_id,
+ "state_key": "",
+ }
+ event = self.get_success(
+ self.module_api.create_and_send_event_into_room(event_dict)
+ ) # type: EventBase
+ self.assertEqual(event.sender, user_id)
+ self.assertEqual(event.type, "m.room.power_levels")
+ self.assertEqual(event.room_id, room_id)
+ self.assertEqual(event.state_key, "")
+ self.assertDictEqual(event.content, content)
+
+ # Check that the event was sent
+ self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+ expected_requester,
+ {
+ "type": "m.room.power_levels",
+ "content": content,
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": "",
+ },
+ ratelimit=False,
+ ignore_shadow_ban=True,
+ )
+
+ # Check that we can't send membership events
+ content = {
+ "membership": "leave",
+ }
+ event_dict = {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "content": content,
+ "sender": user_id,
+ "state_key": user_id,
+ }
+ self.get_failure(
+ self.module_api.create_and_send_event_into_room(event_dict), Exception
+ )
+
+ def test_public_rooms(self):
+ """Tests that a room can be added and removed from the public rooms list,
+ as well as have its public rooms directory state queried.
+ """
+ # Create a user and room to play with
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ # The room should not currently be in the public rooms directory
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
+
+ # Let's try adding it to the public rooms directory
+ self.get_success(
+ self.module_api.public_room_list_manager.add_room_to_public_room_list(
+ room_id
+ )
+ )
+
+ # And checking whether it's in there...
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertTrue(is_in_public_rooms)
+
+ # Let's remove it again
+ self.get_success(
+ self.module_api.public_room_list_manager.remove_room_from_public_room_list(
+ room_id
+ )
+ )
+
+ # Should be gone
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 3224568640..961bf09de9 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -131,6 +131,35 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about that message
self._check_for_mail()
+ def test_invite_sends_email(self):
+ # Create a room and invite the user to it
+ room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
+ self.helper.invite(
+ room=room,
+ src=self.others[0].id,
+ tok=self.others[0].token,
+ targ=self.user_id,
+ )
+
+ # We should get emailed about the invite
+ self._check_for_mail()
+
+ def test_invite_to_empty_room_sends_email(self):
+ # Create a room and invite the user to it
+ room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
+ self.helper.invite(
+ room=room,
+ src=self.others[0].id,
+ tok=self.others[0].token,
+ targ=self.user_id,
+ )
+
+ # Then have the original user leave
+ self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+ # We should get emailed about the invite
+ self._check_for_mail()
+
def test_multiple_members_email(self):
# We want to test multiple notifications, so we pause processing of push
# while we send messages.
@@ -158,8 +187,21 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_encrypted_message(self):
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends some messages
+ self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+
def _check_for_mail(self):
- "Check that the user receives an email notification"
+ """Check that the user receives an email notification"""
# Get the stream ordering before it gets sent
pushers = self.get_success(
@@ -167,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened
self.pump(10)
@@ -178,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+ self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
@@ -196,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..60f0820cff 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -12,16 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from mock import Mock
from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
+from synapse.push import PusherConfigException
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import receipts
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class HTTPPusherTests(HomeserverTestCase):
@@ -29,10 +30,16 @@ class HTTPPusherTests(HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
+ receipts.register_servlets,
]
user_id = True
hijack_auth = False
+ def default_config(self):
+ config = super().default_config()
+ config["start_pushers"] = True
+ return config
+
def make_homeserver(self, reactor, clock):
self.push_attempts = []
@@ -45,13 +52,49 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
- config = self.default_config()
- config["start_pushers"] = True
-
- hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
+ hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
return hs
+ def test_invalid_configuration(self):
+ """Invalid push configurations should be rejected."""
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+
+ def test_data(data):
+ self.get_failure(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data=data,
+ ),
+ PusherConfigException,
+ )
+
+ # Data must be provided with a URL.
+ test_data(None)
+ test_data({})
+ test_data({"url": 1})
+ # A bare domain name isn't accepted.
+ test_data({"url": "example.com"})
+ # A URL without a path isn't accepted.
+ test_data({"url": "http://example.com"})
+ # A url with an incorrect path isn't accepted.
+ test_data({"url": "http://example.com/foo"})
+
def test_sends_http(self):
"""
The HTTP pusher will send pushes for each message to a HTTP endpoint
@@ -69,7 +112,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -81,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -101,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened
self.pump()
@@ -112,11 +155,13 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+ self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One push was attempted to be sent -- it'll be the first message
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
)
@@ -131,12 +176,14 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
- last_stream_ordering = pushers[0]["last_stream_ordering"]
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+ last_stream_ordering = pushers[0].last_stream_ordering
# Now it'll try and send the second push message, which will be the second one
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(
self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
)
@@ -151,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
def test_sends_high_priority_for_encrypted(self):
"""
@@ -181,7 +228,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -193,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -229,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Add yet another person — we want to make this room not a 1:1
@@ -267,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_one_to_one_only(self):
@@ -297,7 +348,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -309,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -325,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority — this is a one-to-one room
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Yet another user joins
@@ -344,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -379,7 +434,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -391,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -407,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time with no mention
@@ -416,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -452,7 +511,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -464,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "example.com"},
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
)
)
@@ -484,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Check our push made it with high priority
self.assertEqual(len(self.push_attempts), 1)
- self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
# Send another event, this time as someone without the power of @room
@@ -495,7 +556,169 @@ class HTTPPusherTests(HomeserverTestCase):
# Advance time a bit, so the pusher will register something has happened
self.pump()
self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+ )
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+
+ def test_push_unread_count_group_by_room(self):
+ """
+ The HTTP pusher will group unread count by number of unread rooms.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # This push should still only contain an unread count of 1 (for 1 unread room)
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
+ )
+
+ @override_config({"push": {"group_unread_count_by_room": False}})
+ def test_push_unread_count_message_count(self):
+ """
+ The HTTP pusher will send the total unread message count.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # We're counting every unread message, so there should now be 4 since the
+ # last read receipt
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
+ )
+
+ def _test_push_unread_count(self):
+ """
+ Tests that the correct unread count appears in sent push notifications
+
+ Note that:
+ * Sending messages will cause push notifications to go out to relevant users
+ * Sending a read receipt will cause a "badge update" notification to go out to
+ the user that sent the receipt
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("other_user", "pass")
+ other_access_token = self.login("other_user", "pass")
+
+ # Create a room (as other_user)
+ room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
+
+ # The user to get notified joins
+ self.helper.join(room=room_id, user=user_id, tok=access_token)
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
+ )
+ )
+
+ # Send a message
+ response = self.helper.send(
+ room_id, body="Hello there!", tok=other_access_token
+ )
+ # To get an unread count, the user who is getting notified has to have a read
+ # position in the room. We'll set the read position to this event in a moment
+ first_message_event_id = response["event_id"]
+
+ # Advance time a bit (so the pusher will register something has happened) and
+ # make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
+
+ # Check that the unread count for the room is 0
+ #
+ # The unread count is zero as the user has no read receipt in the room yet
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Now set the user's read receipt position to the first event
+ #
+ # This will actually trigger a new notification to be sent out so that
+ # even if the user does not receive another message, their unread
+ # count goes down
+ channel = self.make_request(
+ "POST",
+ "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
+ {},
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Advance time and make the push succeed
+ self.push_attempts[1][0].callback({})
+ self.pump()
+
+ # Unread count is still zero as we've read the only message in the room
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(
+ self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Send another message
+ self.helper.send(
+ room_id, body="How's the weather today?", tok=other_access_token
+ )
+
+ # Advance time and make the push succeed
+ self.push_attempts[2][0].callback({})
+ self.pump()
+
+ # This push should contain an unread count of 1 as there's now been one
+ # message since our last read receipt
+ self.assertEqual(len(self.push_attempts), 3)
+ self.assertEqual(
+ self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
+ )
+
+ # Since we're grouping by room, sending more messages shouldn't increase the
+ # unread count, as they're all being sent in the same room
+ self.helper.send(room_id, body="Hello?", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[3][0].callback({})
+
+ self.helper.send(room_id, body="Hello??", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[4][0].callback({})
+
+ self.helper.send(room_id, body="HELLO???", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[5][0].callback({})
+
+ self.assertEqual(len(self.push_attempts), 6)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..3379189785 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,23 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.http.site import SynapseRequest, SynapseSite
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -36,7 +37,12 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport, render
+from tests.server import FakeTransport
+
+try:
+ import hiredis
+except ImportError:
+ hiredis = None
logger = logging.getLogger(__name__)
@@ -44,9 +50,10 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
- servlets = [
- streams.register_servlets,
- ]
+ # hiredis is an optional dependency so we don't want to require it for running
+ # the tests.
+ if not hiredis:
+ skip = "Requires hiredis"
def prepare(self, reactor, clock, hs):
# build a replication server
@@ -57,8 +64,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
- http_client=None,
- homeserverToUse=GenericWorkerServer,
+ federation_http_client=None,
+ homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -68,7 +75,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
- self.worker_hs.replication_data_handler = self.test_handler
+ self.worker_hs._replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
@@ -78,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
@@ -197,23 +209,41 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
+ # Fake in memory Redis server that servers can connect to.
+ self._redis_server = FakeRedisPubSubServer()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["localhost"] = "127.0.0.1"
- self._worker_hs_to_resource = {}
+ # A map from a HS instance to the associated HTTP Site to use for
+ # handling inbound HTTP requests to that instance.
+ self._hs_to_site = {self.hs: self.site}
+
+ if self.hs.config.redis.redis_enabled:
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost", 6379, self.connect_any_redis_attempts,
+ )
+
+ self.hs.get_tcp_replication().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
+ #
+ # Register the master replication listener:
self.reactor.add_tcp_client_callback(
- "1.2.3.4", 8765, self._handle_http_replication_attempt
+ "1.2.3.4",
+ 8765,
+ lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
- def create_test_json_resource(self):
- """Overrides `HomeserverTestCase.create_test_json_resource`.
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`.
"""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@@ -236,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
- useful to e.g. pass some mocks for things like `http_client`
+ useful to e.g. pass some mocks for things like `federation_http_client`
Returns:
The new worker HomeServer instance.
@@ -247,34 +277,69 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
- homeserverToUse=GenericWorkerServer,
+ homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
- **kwargs
+ **kwargs,
)
+ # If the instance is in the `instance_map` config then workers may try
+ # and send HTTP requests to it, so we register it with
+ # `_handle_http_replication_attempt` like we do with the master HS.
+ instance_name = worker_hs.get_instance_name()
+ instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+ if instance_loc:
+ # Ensure the host is one that has a fake DNS entry.
+ if instance_loc.host not in self.reactor.lookups:
+ raise Exception(
+ "Host does not have an IP for instance_map[%r].host = %r"
+ % (instance_name, instance_loc.host,)
+ )
+
+ self.reactor.add_tcp_client_callback(
+ self.reactor.lookups[instance_loc.host],
+ instance_loc.port,
+ lambda: self._handle_http_replication_attempt(
+ worker_hs, instance_loc.port
+ ),
+ )
+
store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
- )
- server = self.server_factory.buildProtocol(None)
+ # Set up TCP replication between master and the new worker if we don't
+ # have Redis support enabled.
+ if not worker_hs.config.redis_enabled:
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
# Set up a resource for the worker
- resource = ReplicationRestResource(self.hs)
+ resource = ReplicationRestResource(worker_hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
- self._worker_hs_to_resource[worker_hs] = resource
+ self._hs_to_site[worker_hs] = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="{}-{}".format(
+ worker_hs.config.server.server_name, worker_hs.get_instance_name()
+ ),
+ config=worker_hs.config.server.listeners[0],
+ resource=resource,
+ server_version_string="1",
+ )
+
+ if worker_hs.config.redis.redis_enabled:
+ worker_hs.get_tcp_replication().start_replication(worker_hs)
return worker_hs
@@ -284,9 +349,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
- render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
-
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
@@ -294,9 +356,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self):
- """Handles a connection attempt to the master replication HTTP
- listener.
+ def _handle_http_replication_attempt(self, hs, repl_port):
+ """Handles a connection attempt to the given HS replication HTTP
+ listener on the given port.
"""
# We should have at least one outbound connection attempt, where the
@@ -305,7 +367,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 8765)
+ self.assertEqual(port, repl_port)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +377,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
- channel.site = self.site
+ channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -333,6 +395,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
+ def connect_any_redis_attempts(self):
+ """If redis is enabled we need to deal with workers connecting to a
+ redis server. We don't want to use a real Redis server so we use a
+ fake one.
+ """
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
+
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
+
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
+
+ return client_to_server_transport, server_to_client_transport
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +555,105 @@ class _PullToPushProducer:
pass
self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+ """A fake Redis server for pub/sub.
+ """
+
+ def __init__(self):
+ self._subscribers = set()
+
+ def add_subscriber(self, conn):
+ """A connection has called SUBSCRIBE
+ """
+ self._subscribers.add(conn)
+
+ def remove_subscriber(self, conn):
+ """A connection has called UNSUBSCRIBE
+ """
+ self._subscribers.discard(conn)
+
+ def publish(self, conn, channel, msg) -> int:
+ """A connection want to publish a message to subscribers.
+ """
+ for sub in self._subscribers:
+ sub.send(["message", channel, msg])
+
+ return len(self._subscribers)
+
+ def buildProtocol(self, addr):
+ return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+ """A connection from a client talking to the fake Redis server.
+ """
+
+ def __init__(self, server: FakeRedisPubSubServer):
+ self._server = server
+ self._reader = hiredis.Reader()
+
+ def dataReceived(self, data):
+ self._reader.feed(data)
+
+ # We might get multiple messages in one packet.
+ while True:
+ msg = self._reader.gets()
+
+ if msg is False:
+ # No more messages.
+ return
+
+ if not isinstance(msg, list):
+ # Inbound commands should always be a list
+ raise Exception("Expected redis list")
+
+ self.handle_command(msg[0], *msg[1:])
+
+ def handle_command(self, command, *args):
+ """Received a Redis command from the client.
+ """
+
+ # We currently only support pub/sub.
+ if command == b"PUBLISH":
+ channel, message = args
+ num_subscribers = self._server.publish(self, channel, message)
+ self.send(num_subscribers)
+ elif command == b"SUBSCRIBE":
+ (channel,) = args
+ self._server.add_subscriber(self)
+ self.send(["subscribe", channel, 1])
+ else:
+ raise Exception("Unknown command")
+
+ def send(self, msg):
+ """Send a message back to the client.
+ """
+ raw = self.encode(msg).encode("utf-8")
+
+ self.transport.write(raw)
+ self.transport.flush()
+
+ def encode(self, obj):
+ """Encode an object to its Redis format.
+
+ Supports: strings/bytes, integers and list/tuples.
+ """
+
+ if isinstance(obj, bytes):
+ # We assume bytes are just unicode strings.
+ obj = obj.decode("utf-8")
+
+ if isinstance(obj, str):
+ return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+ if isinstance(obj, int):
+ return ":{val}\r\n".format(val=obj)
+ if isinstance(obj, (list, tuple)):
+ items = "".join(self.encode(a) for a in obj)
+ return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+ raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
+
+ def connectionLost(self, reason):
+ self._server.remove_subscriber(self)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index c9998e88e6..bad0df08cf 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
sender=sender,
type="test_event",
content={"body": body},
- **kwargs
+ **kwargs,
)
)
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..f35a5235e1
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+ """Test the authentication of HTTP calls between workers."""
+
+ servlets = [register.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ # This isn't a real configuration option but is used to provide the main
+ # homeserver and worker homeserver different options.
+ main_replication_secret = config.pop("main_replication_secret", None)
+ if main_replication_secret:
+ config["worker_replication_secret"] = main_replication_secret
+ return self.setup_test_homeserver(config=config)
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+
+ return config
+
+ def _test_register(self) -> FakeChannel:
+ """Run the actual test:
+
+ 1. Create a worker homeserver.
+ 2. Start registration by providing a user/password.
+ 3. Complete registration by providing dummy auth (this hits the main synapse).
+ 4. Return the final request.
+
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
+
+ channel_1 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.assertEqual(channel_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ return make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+
+ def test_no_auth(self):
+ """With no authentication the request should finish.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ @override_config({"main_replication_secret": "my-secret"})
+ def test_missing_auth(self):
+ """If the main process expects a secret that is not provided, an error results.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config(
+ {
+ "main_replication_secret": "my-secret",
+ "worker_replication_secret": "wrong-secret",
+ }
+ )
+ def test_unauthorized(self):
+ """If the main process receives the wrong secret, an error results.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config({"worker_replication_secret": "my-secret"})
+ def test_authorized(self):
+ """The request should finish when the worker provides the authentication header.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 86c03fd89c..4608b65a0c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@
# limitations under the License.
import logging
-from synapse.api.constants import LoginType
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel
+from tests.server import make_request
logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Base class for tests of the replication streams"""
+ """Test using one or more client readers for registration."""
servlets = [register.register_servlets]
- def prepare(self, reactor, clock, hs):
- self.recaptcha_checker = DummyRecaptchaChecker(hs)
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
@@ -46,24 +38,29 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
- request_1, channel_1 = self.make_request(
+ channel_1 = make_request(
+ self.reactor,
+ site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs, request_1)
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
- ) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs, request_2)
- self.assertEqual(request_2.code, 200)
+ channel_2 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
@@ -74,23 +71,29 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
- request_1, channel_1 = self.make_request(
+ site_1 = self._hs_to_site[worker_hs_1]
+ channel_1 = make_request(
+ self.reactor,
+ site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs_1, request_1)
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
- ) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs_2, request_2)
- self.assertEqual(request_2.code, 200)
+ site_2 = self._hs_to_site[worker_hs_2]
+ channel_2 = make_request(
+ self.reactor,
+ site_2,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 23be1167a3..1853667558 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+ hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 1d7edee5ba..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.federation_sender",
{"send_federation": True},
- http_client=mock_client,
+ federation_http_client=mock_client,
)
user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user3", "pass")
@@ -207,7 +207,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
- federation = self.hs.get_handlers().federation_handler
+ federation = self.hs.get_federation_handler()
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
@@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
}
builder = factory.for_room_version(room_version, event_dict)
- join_event = self.get_success(builder.build(prev_event_ids))
+ join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
new file mode 100644
index 0000000000..d1feca961f
--- /dev/null
+++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,279 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from binascii import unhexlify
+from typing import Tuple
+
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+from twisted.web.server import Request
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.server import HomeServer
+
+from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+
+logger = logging.getLogger(__name__)
+
+test_server_connection_factory = None
+
+
+class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks running multiple media repos work correctly.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ self.reactor.lookups["example.com"] = "1.2.3.4"
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+ return conf
+
+ def _get_media_req(
+ self, hs: HomeServer, target: str, media_id: str
+ ) -> Tuple[FakeChannel, Request]:
+ """Request some remote media from the given HS by calling the download
+ API.
+
+ This then triggers an outbound request from the HS to the target.
+
+ Returns:
+ The channel for the *client* request and the *outbound* request for
+ the media which the caller should respond to.
+ """
+ resource = hs.get_media_repository_resource().children[b"download"]
+ channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ "/{}/{}".format(target, media_id),
+ shorthand=False,
+ access_token=self.access_token,
+ await_result=False,
+ )
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+
+ # build the test server
+ server_tls_protocol = _build_test_server(get_connection_factory())
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_server = server_tls_protocol.wrappedProtocol
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(
+ request.path,
+ "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+ )
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
+ )
+
+ return channel, request
+
+ def test_basic(self):
+ """Test basic fetching of remote media from a single worker.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+
+ channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
+
+ request.setResponseCode(200)
+ request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request.write(b"Hello!")
+ request.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.result["body"], b"Hello!")
+
+ def test_download_simple_file_race(self):
+ """Test that fetching remote media from two different processes at the
+ same time works.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_media()
+
+ # Make two requests without responding to the outbound media requests.
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+ # Respond to the first outbound media request and check that the client
+ # request is successful
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request1.write(b"Hello!")
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], b"Hello!")
+
+ # Now respond to the second with the same content.
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request2.write(b"Hello!")
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], b"Hello!")
+
+ # We expect only one new file to have been persisted.
+ self.assertEqual(start_count + 1, self._count_remote_media())
+
+ def test_download_image_race(self):
+ """Test that fetching remote *images* from two different processes at
+ the same time works.
+
+ This checks that races generating thumbnails are handled correctly.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_thumbnails()
+
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
+
+ png_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request1.write(png_data)
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], png_data)
+
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request2.write(png_data)
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], png_data)
+
+ # We expect only three new thumbnails to have been persisted.
+ self.assertEqual(start_count + 3, self._count_remote_thumbnails())
+
+ def _count_remote_media(self) -> int:
+ """Count the number of files in our remote media directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_content"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+ def _count_remote_thumbnails(self) -> int:
+ """Count the number of files in our remote thumbnails directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+
+def get_connection_factory():
+ # this needs to happen once, but not until we are ready to run the first test
+ global test_server_connection_factory
+ if test_server_connection_factory is None:
+ test_server_connection_factory = TestServerTLSConnectionFactory(
+ sanlist=[b"DNS:example.com"]
+ )
+ return test_server_connection_factory
+
+
+def _build_test_server(connection_creator):
+ """Construct a test server
+
+ This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+
+ Args:
+ connection_creator (IOpenSSLServerConnectionCreator): thing to build
+ SSL connections
+ sanlist (list[bytes]): list of the SAN entries for the cert returned
+ by the server
+
+ Returns:
+ TLSMemoryBIOProtocol
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_factory = TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=server_factory
+ )
+
+ return server_tls_factory.buildProtocol(None)
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_dict["token_id"]
+ token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "https://push.example.com/push"},
+ data={"url": "https://push.example.com/_matrix/push/v1/notify"},
)
)
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.pusher",
{"start_pushers": True},
- proxied_http_client=http_client_mock,
+ proxied_blacklisted_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock1,
+ proxied_blacklisted_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock2,
+ proxied_blacklisted_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..8d494ebc03
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import patch
+
+from synapse.api.room_versions import RoomVersion
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks event persisting sharding works
+ """
+
+ # Event persister sharding requires postgres (due to needing
+ # `MutliWriterIdGenerator`).
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ self.room_creator = self.hs.get_room_creation_handler()
+ self.store = hs.get_datastore()
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["redis"] = {"enabled": "true"}
+ conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+ conf["instance_map"] = {
+ "worker1": {"host": "testserv", "port": 1001},
+ "worker2": {"host": "testserv", "port": 1002},
+ }
+ return conf
+
+ def _create_room(self, room_id: str, user_id: str, tok: str):
+ """Create a room with given room_id
+ """
+
+ # We control the room ID generation by patching out the
+ # `_generate_room_id` method
+ async def generate_room(
+ creator_id: str, is_public: bool, room_version: RoomVersion
+ ):
+ await self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id=creator_id,
+ is_public=is_public,
+ room_version=room_version,
+ )
+ return room_id
+
+ with patch(
+ "synapse.handlers.room.RoomCreationHandler._generate_room_id"
+ ) as mock:
+ mock.side_effect = generate_room
+ self.helper.create_room_as(user_id, tok=tok)
+
+ def test_basic(self):
+ """Simple test to ensure that multiple rooms can be created and joined,
+ and that different rooms get handled by different instances.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
+ persisted_on_1 = False
+ persisted_on_2 = False
+
+ store = self.hs.get_datastore()
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Keep making new rooms until we see rooms being persisted on both
+ # workers.
+ for _ in range(10):
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ rseponse = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = rseponse["event_id"]
+
+ # The event position includes which instance persisted the event.
+ pos = self.get_success(store.get_position_for_event(event_id))
+
+ persisted_on_1 |= pos.instance_name == "worker1"
+ persisted_on_2 |= pos.instance_name == "worker2"
+
+ if persisted_on_1 and persisted_on_2:
+ break
+
+ self.assertTrue(persisted_on_1)
+ self.assertTrue(persisted_on_2)
+
+ def test_vector_clock_token(self):
+ """Tests that using a stream token with a vector clock component works
+ correctly with basic /sync and /messages usage.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ worker_hs2 = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
+ sync_hs = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "sync"},
+ )
+ sync_hs_site = self._hs_to_site[sync_hs]
+
+ # Specially selected room IDs that get persisted on different workers.
+ room_id1 = "!foo:test"
+ room_id2 = "!baz:test"
+
+ self.assertEqual(
+ self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1"
+ )
+ self.assertEqual(
+ self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2"
+ )
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ store = self.hs.get_datastore()
+
+ # Create two room on the different workers.
+ self._create_room(room_id1, user_id, access_token)
+ self._create_room(room_id2, user_id, access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room_id1, user=self.other_user_id, tok=self.other_access_token
+ )
+ self.helper.join(
+ room=room_id2, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # Do an initial sync so that we're up to date.
+ channel = make_request(
+ self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
+ )
+ next_batch = channel.json_body["next_batch"]
+
+ # We now gut wrench into the events stream MultiWriterIdGenerator on
+ # worker2 to mimic it getting stuck persisting an event. This ensures
+ # that when we send an event on worker1 we end up in a state where
+ # worker2 events stream position lags that on worker1, resulting in a
+ # RoomStreamToken with a non-empty instance map component.
+ #
+ # Worker2's event stream position will not advance until we call
+ # __aexit__ again.
+ actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
+ self.get_success(actx.__aenter__())
+
+ response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
+ first_event_in_room1 = response["event_id"]
+
+ # Assert that the current stream token has an instance map component, as
+ # we are trying to test vector clock tokens.
+ room_stream_token = store.get_room_max_token()
+ self.assertNotEqual(len(room_stream_token.instance_map), 0)
+
+ # Check that syncing still gets the new event, despite the gap in the
+ # stream IDs.
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
+ )
+
+ # We should only see the new event and nothing else
+ self.assertIn(room_id1, channel.json_body["rooms"]["join"])
+ self.assertNotIn(room_id2, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"]
+ self.assertListEqual(
+ [first_event_in_room1], [event["event_id"] for event in events]
+ )
+
+ # Get the next batch and makes sure its a vector clock style token.
+ vector_clock_token = channel.json_body["next_batch"]
+ self.assertTrue(vector_clock_token.startswith("m"))
+
+ # Now that we've got a vector clock token we finish the fake persisting
+ # an event we started above.
+ self.get_success(actx.__aexit__(None, None, None))
+
+ # Now try and send an event to the other rooom so that we can test that
+ # the vector clock style token works as a `since` token.
+ response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
+ first_event_in_room2 = response["event_id"]
+
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(vector_clock_token),
+ access_token=access_token,
+ )
+
+ self.assertNotIn(room_id1, channel.json_body["rooms"]["join"])
+ self.assertIn(room_id2, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"]
+ self.assertListEqual(
+ [first_event_in_room2], [event["event_id"] for event in events]
+ )
+
+ next_batch = channel.json_body["next_batch"]
+
+ # We also want to test that the vector clock style token works with
+ # pagination. We do this by sending a couple of new events into the room
+ # and syncing again to get a prev_batch token for each room, then
+ # paginating from there back to the vector clock token.
+ self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
+ self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
+
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
+ )
+
+ prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][
+ "prev_batch"
+ ]
+ prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][
+ "prev_batch"
+ ]
+
+ # Paginating back in the first room should not produce any results, as
+ # no events have happened in it. This tests that we are correctly
+ # filtering results based on the vector clock portion.
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=b".format(
+ room_id1, prev_batch1, vector_clock_token
+ ),
+ access_token=access_token,
+ )
+ self.assertListEqual([], channel.json_body["chunk"])
+
+ # Paginating back on the second room should produce the first event
+ # again. This tests that pagination isn't completely broken.
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=b".format(
+ room_id2, prev_batch2, vector_clock_token
+ ),
+ access_token=access_token,
+ )
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertEqual(
+ channel.json_body["chunk"][0]["event_id"], first_event_in_room2
+ )
+
+ # Paginating forwards should give the same results
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=f".format(
+ room_id1, vector_clock_token, prev_batch1
+ ),
+ access_token=access_token,
+ )
+ self.assertListEqual([], channel.json_body["chunk"])
+
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=f".format(
+ room_id2, vector_clock_token, prev_batch2,
+ ),
+ access_token=access_token,
+ )
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertEqual(
+ channel.json_body["chunk"][0]["event_id"], first_event_in_room2
+ )
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0f1144fe1e..0504cd187e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -30,19 +30,19 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
+from tests.server import FakeSite, make_request
class VersionTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/server_version"
- def create_test_json_resource(self):
+ def create_test_resource(self):
resource = JsonResource(self.hs)
VersionServlet(self.hs).register(resource)
return resource
def test_version_string(self):
- request, channel = self.make_request("GET", self.url, shorthand=False)
- self.render(request)
+ channel = self.make_request("GET", self.url, shorthand=False)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -68,14 +68,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
def test_delete_group(self):
# Create a new group
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/create_group".encode("ascii"),
access_token=self.admin_user_tok,
content={"localpart": "test"},
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
group_id = channel.json_body["group_id"]
@@ -85,17 +84,15 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Invite/join another user
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
url = "/groups/%s/self/accept_invite" % (group_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check other user knows they're in the group
@@ -103,15 +100,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
# Now delete the group
- url = "/admin/delete_group/" + group_id
- request, channel = self.make_request(
+ url = "/_synapse/admin/v1/delete_group/" + group_id
+ channel = self.make_request(
"POST",
url.encode("ascii"),
access_token=self.admin_user_tok,
content={"localpart": "test"},
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check group returns 404
@@ -127,11 +123,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"""
url = "/groups/%s/profile" % (group_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -139,11 +134,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/joined_groups".encode("ascii"), access_token=access_token
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
@@ -216,17 +210,20 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
}
config["media_storage_providers"] = [provider_config]
- hs = self.setup_test_homeserver(config=config, http_client=client)
+ hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
- request, channel = self.make_request(
- "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Should be quarantined
self.assertEqual(
@@ -244,10 +241,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt quarantine media APIs as non-admin
url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
- self.render(request)
# Expect a forbidden error
self.assertEqual(
@@ -258,10 +254,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# And the roomID/userID endpoint
url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
- self.render(request)
# Expect a forbidden error
self.assertEqual(
@@ -287,14 +282,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
- request, channel = self.make_request(
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
"GET",
server_name_and_media_id,
shorthand=False,
access_token=non_admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Should be successful
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -304,8 +299,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
urllib.parse.quote(server_name),
urllib.parse.quote(media_id),
)
- request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
- self.render(request)
+ channel = self.make_request("POST", url, access_token=admin_user_tok,)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -357,8 +351,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
room_id
)
- request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
- self.render(request)
+ channel = self.make_request("POST", url, access_token=admin_user_tok,)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
self.assertEqual(
@@ -402,10 +395,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
- self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -445,10 +437,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
- self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -462,14 +453,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
- request, channel = self.make_request(
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Shouldn't be quarantined
self.assertEqual(
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 92c9058887..248c4442c3 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -50,20 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
Try to get a device of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("PUT", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("PUT", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- request, channel = self.make_request("DELETE", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("DELETE", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -72,26 +69,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", self.url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -105,26 +99,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -138,26 +123,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
% self.other_user_device_id
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -170,25 +146,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- request, channel = self.make_request(
- "PUT", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
- request, channel = self.make_request(
- "DELETE", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
# Delete unknown device returns status 200
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -212,22 +179,18 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
}
body = json.dumps(update)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -244,18 +207,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
- "PUT", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -266,21 +223,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
# Set new display_name
body = json.dumps({"display_name": "new displayname"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
@@ -289,10 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
Tests that a normal lookup for a device is successfully
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -313,10 +263,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(1, number_devices)
# Delete device
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -346,8 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
Try to list devices of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -358,10 +306,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -371,10 +316,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -385,14 +327,24 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ def test_user_has_no_devices(self):
+ """
+ Tests that a normal lookup for devices is successfully
+ if user has no devices
+ """
+
+ # Get devices
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["devices"]))
+
def test_get_devices(self):
"""
Tests that a normal lookup for devices is successfully
@@ -403,12 +355,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.login("user", "pass")
# Get devices
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
# Check that all fields are available
@@ -443,8 +393,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
Try to delete devices of an user without authentication.
"""
- request, channel = self.make_request("POST", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -455,10 +404,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "POST", self.url, access_token=other_user_token,
- )
- self.render(request)
+ channel = self.make_request("POST", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -468,10 +414,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
- request, channel = self.make_request(
- "POST", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -482,10 +425,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
- request, channel = self.make_request(
- "POST", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -495,13 +435,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
Tests that a remove of a device that does not exist returns 200.
"""
body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
# Delete unknown devices returns status 200
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -527,13 +466,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
# Delete devices
body = json.dumps({"devices": device_ids})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index bf79086f78..aa389df12f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -70,15 +70,21 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/event_reports"
+ def test_no_auth(self):
+ """
+ Try to get an event report without authentication.
+ """
+ channel = self.make_request("GET", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.other_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -88,10 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events
"""
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -104,10 +107,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with limit
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -120,10 +122,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a defined starting point (from)
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -136,10 +137,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a defined starting point and limit
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -152,12 +152,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of room
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?room_id=%s" % self.room_id1,
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
@@ -173,12 +172,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of user
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?user_id=%s" % self.other_user,
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
@@ -194,12 +192,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing list of reported events with a filter of user and room
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 5)
@@ -217,10 +214,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
# fetch the most recent first, largest timestamp
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?dir=b", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -234,10 +230,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
report += 1
# fetch the oldest first, smallest timestamp
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?dir=f", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -255,10 +250,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing that a invalid search order returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -266,13 +260,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
def test_limit_is_negative(self):
"""
- Testing that a negative list parameter returns a 400
+ Testing that a negative limit parameter returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -282,10 +275,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Testing that a negative from parameter returns a 400
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -297,10 +289,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of results is the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -309,10 +300,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does not appear
# Number of max results is larger than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -321,10 +311,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# `next_token` does appear
# Number of max results is smaller than the number of entries
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -334,10 +323,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
# Check
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -350,17 +338,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
json.dumps({"score": -100, "reason": "this makes me sad"}),
access_token=user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
- """Checks that all attributes are present in a event report
+ """Checks that all attributes are present in an event report
"""
for c in content:
self.assertIn("id", c)
@@ -368,15 +355,163 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", c)
self.assertIn("event_id", c)
self.assertIn("user_id", c)
- self.assertIn("reason", c)
- self.assertIn("content", c)
self.assertIn("sender", c)
- self.assertIn("room_alias", c)
- self.assertIn("event_json", c)
- self.assertIn("score", c["content"])
- self.assertIn("reason", c["content"])
- self.assertIn("auth_events", c["event_json"])
- self.assertIn("type", c["event_json"])
- self.assertIn("room_id", c["event_json"])
- self.assertIn("sender", c["event_json"])
- self.assertIn("content", c["event_json"])
+ self.assertIn("canonical_alias", c)
+ self.assertIn("name", c)
+ self.assertIn("score", c)
+ self.assertIn("reason", c)
+
+
+class EventReportDetailTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ report_event.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id1 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.other_user_tok,
+ )
+
+ # first created event report gets `id`=2
+ self.url = "/_synapse/admin/v1/event_reports/2"
+
+ def test_no_auth(self):
+ """
+ Try to get event report without authentication.
+ """
+ channel = self.make_request("GET", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_default_success(self):
+ """
+ Testing get a reported event
+ """
+
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self._check_fields(channel.json_body)
+
+ def test_invalid_report_id(self):
+ """
+ Testing that an invalid `report_id` returns a 400.
+ """
+
+ # `report_id` is negative
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/-123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is a non-numerical string
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/abcdef",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is undefined
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ def test_report_id_not_found(self):
+ """
+ Testing that a not existing `report_id` returns a 404.
+ """
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual("Event report not found", channel.json_body["error"])
+
+ def _create_event_and_report(self, room_id, user_tok):
+ """Create and report events
+ """
+ resp = self.helper.send(room_id, tok=user_tok)
+ event_id = resp["event_id"]
+
+ channel = self.make_request(
+ "POST",
+ "rooms/%s/report/%s" % (room_id, event_id),
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
+ access_token=user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in a event report
+ """
+ self.assertIn("id", content)
+ self.assertIn("received_ts", content)
+ self.assertIn("room_id", content)
+ self.assertIn("event_id", content)
+ self.assertIn("user_id", content)
+ self.assertIn("sender", content)
+ self.assertIn("canonical_alias", content)
+ self.assertIn("name", content)
+ self.assertIn("event_json", content)
+ self.assertIn("score", content)
+ self.assertIn("reason", content)
+ self.assertIn("auth_events", content["event_json"])
+ self.assertIn("type", content["event_json"])
+ self.assertIn("room_id", content["event_json"])
+ self.assertIn("sender", content["event_json"])
+ self.assertIn("content", content["event_json"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
new file mode 100644
index 0000000000..c2b998cdae
--- /dev/null
+++ b/tests/rest/admin/test_media.py
@@ -0,0 +1,535 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from binascii import unhexlify
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, profile, room
+from synapse.rest.media.v1.filepath import MediaFilePaths
+
+from tests import unittest
+from tests.server import FakeSite, make_request
+
+
+class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_media_repo,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+ self.media_repo = hs.get_media_repository_resource()
+ self.server_name = hs.hostname
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.filepaths = MediaFilePaths(hs.config.media_store_path)
+
+ def test_no_auth(self):
+ """
+ Try to delete media without authentication.
+ """
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
+
+ channel = self.make_request("DELETE", url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
+
+ channel = self.make_request("DELETE", url, access_token=self.other_user_token,)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_media_does_not_exist(self):
+ """
+ Tests that a lookup for a media that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
+
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_media_is_not_local(self):
+ """
+ Tests that a lookup for a media that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
+
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only delete local media", channel.json_body["error"])
+
+ def test_delete_media(self):
+ """
+ Tests that delete a media is successfully
+ """
+
+ download_resource = self.media_repo.children[b"download"]
+ upload_resource = self.media_repo.children[b"upload"]
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ )
+ # Extract media ID from the response
+ server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
+ server_name, media_id = server_and_media_id.split("/")
+
+ self.assertEqual(server_name, self.server_name)
+
+ # Attempt to access media
+ channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=self.admin_user_tok,
+ )
+
+ # Should be successful
+ self.assertEqual(
+ 200,
+ channel.code,
+ msg=(
+ "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+ ),
+ )
+
+ # Test if the file exists
+ local_path = self.filepaths.local_media_filepath(media_id)
+ self.assertTrue(os.path.exists(local_path))
+
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
+
+ # Delete media
+ channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ media_id, channel.json_body["deleted_media"][0],
+ )
+
+ # Attempt to access media
+ channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(
+ 404,
+ channel.code,
+ msg=(
+ "Expected to receive a 404 on accessing deleted media: %s"
+ % server_and_media_id
+ ),
+ )
+
+ # Test if the file is deleted
+ self.assertFalse(os.path.exists(local_path))
+
+
+class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_media_repo,
+ login.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+ self.media_repo = hs.get_media_repository_resource()
+ self.server_name = hs.hostname
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+
+ def test_no_auth(self):
+ """
+ Try to delete media without authentication.
+ """
+
+ channel = self.make_request("POST", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ channel = self.make_request(
+ "POST", self.url, access_token=self.other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_media_is_not_local(self):
+ """
+ Tests that a lookup for media that is not local returns a 400
+ """
+ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
+
+ channel = self.make_request(
+ "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only delete local media", channel.json_body["error"])
+
+ def test_missing_parameter(self):
+ """
+ If the parameter `before_ts` is missing, an error is returned.
+ """
+ channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "Missing integer query parameter b'before_ts'", channel.json_body["error"]
+ )
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ channel = self.make_request(
+ "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "Query parameter before_ts must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=1234&size_gt=-1234",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "Query parameter size_gt must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=1234&keep_profiles=not_bool",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(
+ "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']",
+ channel.json_body["error"],
+ )
+
+ def test_delete_media_never_accessed(self):
+ """
+ Tests that media deleted if it is older than `before_ts` and never accessed
+ `last_access_ts` is `NULL` and `created_ts` < `before_ts`
+ """
+
+ # upload and do not access
+ server_and_media_id = self._create_media()
+ self.pump(1.0)
+
+ # test that the file exists
+ media_id = server_and_media_id.split("/")[1]
+ local_path = self.filepaths.local_media_filepath(media_id)
+ self.assertTrue(os.path.exists(local_path))
+
+ # timestamp after upload/create
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ media_id, channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_date(self):
+ """
+ Tests that media is not deleted if it is newer than `before_ts`
+ """
+
+ # timestamp before upload
+ now_ms = self.clock.time_msec()
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ # timestamp after upload
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_size(self):
+ """
+ Tests that media is not deleted if its size is smaller than or equal
+ to `size_gt`
+ """
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_user_avatar(self):
+ """
+ Tests that we do not delete media if is used as a user avatar
+ Tests parameter `keep_profiles`
+ """
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ # set media as avatar
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.admin_user,),
+ content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_room_avatar(self):
+ """
+ Tests that we do not delete media if it is used as a room avatar
+ Tests parameter `keep_profiles`
+ """
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ # set media as room avatar
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ channel = self.make_request(
+ "PUT",
+ "/rooms/%s/state/m.room.avatar" % (room_id,),
+ content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def _create_media(self):
+ """
+ Create a media and return media_id and server_and_media_id
+ """
+ upload_resource = self.media_repo.children[b"upload"]
+ # file size is 67 Byte
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ )
+ # Extract media ID from the response
+ server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
+ server_name = server_and_media_id.split("/")[0]
+
+ # Check that new media is a local and not remote
+ self.assertEqual(server_name, self.server_name)
+
+ return server_and_media_id
+
+ def _access_media(self, server_and_media_id, expect_success=True):
+ """
+ Try to access a media and check the result
+ """
+ download_resource = self.media_repo.children[b"download"]
+
+ media_id = server_and_media_id.split("/")[1]
+ local_path = self.filepaths.local_media_filepath(media_id)
+
+ channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=self.admin_user_tok,
+ )
+
+ if expect_success:
+ self.assertEqual(
+ 200,
+ channel.code,
+ msg=(
+ "Expected to receive a 200 on accessing media: %s"
+ % server_and_media_id
+ ),
+ )
+ # Test that the file exists
+ self.assertTrue(os.path.exists(local_path))
+ else:
+ self.assertEqual(
+ 404,
+ channel.code,
+ msg=(
+ "Expected to receive a 404 on accessing deleted media: %s"
+ % (server_and_media_id)
+ ),
+ )
+ # Test that the file is deleted
+ self.assertFalse(os.path.exists(local_path))
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 6dfc709dc5..fa620f97f3 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -20,6 +20,7 @@ from typing import List, Optional
from mock import Mock
import synapse.rest.admin
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
from synapse.rest.client.v1 import directory, events, login, room
@@ -78,14 +79,13 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
)
# Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
+ url = "/_synapse/admin/v1/shutdown_room/" + room_id
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -104,24 +104,22 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Enable world readable
url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
+ url = "/_synapse/admin/v1/shutdown_room/" + room_id
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -133,19 +131,17 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
"""
url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -189,10 +185,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
If the user is not a server admin, an error 403 is returned.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", self.url, json.dumps({}), access_token=self.other_user_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -203,10 +198,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, json.dumps({}), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -217,10 +211,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/rooms/invalidroom/delete"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, json.dumps({}), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -233,13 +226,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"new_room_user_id": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertIn("new_room_id", channel.json_body)
@@ -253,13 +245,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"new_room_user_id": "@not:exist.bla"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -272,13 +263,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"block": "NotBool"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
@@ -289,13 +279,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"purge": "NotBool"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
@@ -316,13 +305,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": True, "purge": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(None, channel.json_body["new_room_id"])
@@ -350,13 +338,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": False, "purge": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(None, channel.json_body["new_room_id"])
@@ -385,13 +372,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": False, "purge": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(None, channel.json_body["new_room_id"])
@@ -433,13 +419,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
@@ -464,13 +449,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Enable world readable
url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Test that room is not purged
@@ -482,13 +466,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
@@ -531,40 +514,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def _is_purged(self, room_id):
"""Test that the following tables have been purged of all rows related to the room.
"""
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
+ for table in PURGE_TABLES:
count = self.get_success(
self.store.db_pool.simple_select_one_onecol(
table=table,
@@ -581,19 +531,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -622,50 +570,17 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url.encode("ascii"),
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Test that the following tables have been purged of all rows related to the room.
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
+ for table in PURGE_TABLES:
count = self.get_success(
self.store.db_pool.simple_select_one_onecol(
table=table,
@@ -709,10 +624,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
# Check request completed successfully
self.assertEqual(200, int(channel.code), msg=channel.json_body)
@@ -791,10 +705,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
limit,
"name",
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -832,10 +745,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_ids, returned_room_ids)
url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def test_correct_room_attributes(self):
@@ -853,13 +765,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Create a new alias to this room
url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Set this new alias as the canonical alias for this room
@@ -884,10 +795,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check that rooms were returned
@@ -926,13 +836,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/directory/room/%s" % (
urllib.parse.quote(test_alias),
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url.encode("ascii"),
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -967,10 +876,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
if reverse:
url += "&dir=b"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
@@ -1104,10 +1012,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
expected_http_code: The expected http code for the request
"""
url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
if expected_http_code != 200:
@@ -1144,6 +1051,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(room_id_2, "else")
_search_test(room_id_2, "se")
+ # Test case insensitive
+ _search_test(room_id_1, "SOMETHING")
+ _search_test(room_id_1, "THING")
+
+ _search_test(room_id_2, "ELSE")
+ _search_test(room_id_2, "SE")
+
_search_test(None, "foo")
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
@@ -1166,10 +1080,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body)
@@ -1179,6 +1092,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("joined_local_devices", channel.json_body)
self.assertIn("version", channel.json_body)
self.assertIn("creator", channel.json_body)
self.assertIn("encryption", channel.json_body)
@@ -1191,6 +1105,39 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ def test_single_room_devices(self):
+ """Test that `joined_local_devices` can be requested correctly"""
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["joined_local_devices"])
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(2, channel.json_body["joined_local_devices"])
+
+ # leave room
+ self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
+ self.helper.leave(room_id_1, user_1, tok=user_tok_1)
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["joined_local_devices"])
+
def test_room_members(self):
"""Test that room members can be requested correctly"""
# Create two test rooms
@@ -1214,10 +1161,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id_2, user_3, tok=user_tok_3)
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
@@ -1226,10 +1172,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 3)
url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
@@ -1267,13 +1212,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1284,13 +1228,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"unknown_parameter": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -1301,13 +1244,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": "@unknown:test"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1318,13 +1260,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": "@not:exist.bla"})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -1339,13 +1280,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/!unknown:test"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("No known servers", channel.json_body["error"])
@@ -1357,13 +1297,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/invalidroom"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -1377,23 +1316,21 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"""
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
@@ -1408,13 +1345,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1439,10 +1375,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Validate if server admin is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1451,22 +1386,20 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1481,22 +1414,192 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/join/{}".format(private_room_id)
body = json.dumps({"user_id": self.second_user_id})
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+
+class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format(
+ self.public_room_id
+ )
+
+ def test_public_room(self):
+ """Test that getting admin in a public room works.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room and ban a user.
+ self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+ self.helper.change_membership(
+ room_id,
+ self.admin_user,
+ "@test:test",
+ Membership.BAN,
+ tok=self.admin_user_tok,
+ )
+
+ def test_private_room(self):
+ """Test that getting admin in a private room works and we get invited.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False,
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room (we should have received an
+ # invite) and can ban a user.
+ self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+ self.helper.change_membership(
+ room_id,
+ self.admin_user,
+ "@test:test",
+ Membership.BAN,
+ tok=self.admin_user_tok,
+ )
+
+ def test_other_user(self):
+ """Test that giving admin in a public room works to a non-admin user works.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={"user_id": self.second_user_id},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Now we test that we can join the room and ban a user.
+ self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
+ self.helper.change_membership(
+ room_id,
+ self.second_user_id,
+ "@test:test",
+ Membership.BAN,
+ tok=self.second_tok,
+ )
+
+ def test_not_enough_power(self):
+ """Test that we get a sensible error if there are no local room admins.
+ """
+ room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+
+ # The creator drops admin rights in the room.
+ pl = self.helper.get_state(
+ room_id, EventTypes.PowerLevels, tok=self.creator_tok
+ )
+ pl["users"][self.creator] = 0
+ self.helper.send_state(
+ room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ # We expect this to fail with a 400 as there are no room admins.
+ #
+ # (Note we assert the error message to ensure that it's not denied for
+ # some other reason)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ channel.json_body["error"],
+ "No local admin user in room with power to update power levels.",
+ )
+
+
+PURGE_TABLES = [
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+]
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
new file mode 100644
index 0000000000..73f8a8ec99
--- /dev/null
+++ b/tests/rest/admin/test_statistics.py
@@ -0,0 +1,452 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from binascii import unhexlify
+from typing import Any, Dict, List, Optional
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login
+
+from tests import unittest
+
+
+class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.media_repo = hs.get_media_repository_resource()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/statistics/users/media"
+
+ def test_no_auth(self):
+ """
+ Try to list users without authentication.
+ """
+ channel = self.make_request("GET", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+ channel = self.make_request(
+ "GET", self.url, json.dumps({}), access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ # unkown order_by
+ channel = self.make_request(
+ "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from
+ channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative limit
+ channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from_ts
+ channel = self.make_request(
+ "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative until_ts
+ channel = self.make_request(
+ "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # until_ts smaller from_ts
+ channel = self.make_request(
+ "GET",
+ self.url + "?from_ts=10&until_ts=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # empty search term
+ channel = self.make_request(
+ "GET", self.url + "?search_term=", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid search order
+ channel = self.make_request(
+ "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_limit(self):
+ """
+ Testing list of media with limit
+ """
+ self._create_users_with_media(10, 2)
+
+ channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["users"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["users"])
+
+ def test_from(self):
+ """
+ Testing list of media with a defined starting point (from)
+ """
+ self._create_users_with_media(20, 2)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["users"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["users"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of media with a defined starting point and limit
+ """
+ self._create_users_with_media(20, 2)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["users"]), 10)
+ self._check_fields(channel.json_body["users"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_users = 20
+ self._create_users_with_media(number_users, 3)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Set `from` to value of `next_token` for request remaining entries
+ # Check `next_token` does not appear
+ channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def test_no_media(self):
+ """
+ Tests that a normal lookup for statistics is successfully
+ if users have no media created
+ """
+
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["users"]))
+
+ def test_order_by(self):
+ """
+ Testing order list with parameter `order_by`
+ """
+
+ # create users
+ self.register_user("user_a", "pass", displayname="UserZ")
+ userA_tok = self.login("user_a", "pass")
+ self._create_media(userA_tok, 1)
+
+ self.register_user("user_b", "pass", displayname="UserY")
+ userB_tok = self.login("user_b", "pass")
+ self._create_media(userB_tok, 3)
+
+ self.register_user("user_c", "pass", displayname="UserX")
+ userC_tok = self.login("user_c", "pass")
+ self._create_media(userC_tok, 2)
+
+ # order by user_id
+ self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"])
+ self._order_test(
+ "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f",
+ )
+ self._order_test(
+ "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b",
+ )
+
+ # order by displayname
+ self._order_test(
+ "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"]
+ )
+ self._order_test(
+ "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f",
+ )
+ self._order_test(
+ "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b",
+ )
+
+ # order by media_length
+ self._order_test(
+ "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"],
+ )
+ self._order_test(
+ "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+ )
+ self._order_test(
+ "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+ )
+
+ # order by media_count
+ self._order_test(
+ "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"],
+ )
+ self._order_test(
+ "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+ )
+ self._order_test(
+ "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+ )
+
+ def test_from_until_ts(self):
+ """
+ Testing filter by time with parameters `from_ts` and `until_ts`
+ """
+ # create media earlier than `ts1` to ensure that `from_ts` is working
+ self._create_media(self.other_user_tok, 3)
+ self.pump(1)
+ ts1 = self.clock.time_msec()
+
+ # list all media when filter is not set
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
+
+ # filter media starting at `ts1` after creating first media
+ # result is 0
+ channel = self.make_request(
+ "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 0)
+
+ self._create_media(self.other_user_tok, 3)
+ self.pump(1)
+ ts2 = self.clock.time_msec()
+ # create media after `ts2` to ensure that `until_ts` is working
+ self._create_media(self.other_user_tok, 3)
+
+ # filter media between `ts1` and `ts2`
+ channel = self.make_request(
+ "GET",
+ self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
+
+ # filter media until `ts2` and earlier
+ channel = self.make_request(
+ "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
+
+ def test_search_term(self):
+ self._create_users_with_media(20, 1)
+
+ # check without filter get all users
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+
+ # filter user 1 and 10-19 by `user_id`
+ channel = self.make_request(
+ "GET",
+ self.url + "?search_term=foo_user_1",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 11)
+
+ # filter on this user in `displayname`
+ channel = self.make_request(
+ "GET",
+ self.url + "?search_term=bar_user_10",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
+ self.assertEqual(channel.json_body["total"], 1)
+
+ # filter and get empty result
+ channel = self.make_request(
+ "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 0)
+
+ def _create_users_with_media(self, number_users: int, media_per_user: int):
+ """
+ Create a number of users with a number of media
+ Args:
+ number_users: Number of users to be created
+ media_per_user: Number of media to be created for each user
+ """
+ for i in range(number_users):
+ self.register_user("foo_user_%s" % i, "pass", displayname="bar_user_%s" % i)
+ user_tok = self.login("foo_user_%s" % i, "pass")
+ self._create_media(user_tok, media_per_user)
+
+ def _create_media(self, user_token: str, number_media: int):
+ """
+ Create a number of media for a specific user
+ Args:
+ user_token: Access token of the user
+ number_media: Number of media to be created for the user
+ """
+ upload_resource = self.media_repo.children[b"upload"]
+ for i in range(number_media):
+ # file size is 67 Byte
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ # Upload some media into the room
+ self.helper.upload_media(
+ upload_resource, image_data, tok=user_token, expect_code=200
+ )
+
+ def _check_fields(self, content: List[Dict[str, Any]]):
+ """Checks that all attributes are present in content
+ Args:
+ content: List that is checked for content
+ """
+ for c in content:
+ self.assertIn("user_id", c)
+ self.assertIn("displayname", c)
+ self.assertIn("media_count", c)
+ self.assertIn("media_length", c)
+
+ def _order_test(
+ self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None
+ ):
+ """Request the list of users in a certain order. Assert that order is what
+ we expect
+ Args:
+ order_type: The type of ordering to give the server
+ expected_user_list: The list of user_ids in the order we expect to get
+ back from the server
+ dir: The direction of ordering to give the server
+ """
+
+ url = self.url + "?order_by=%s" % (order_type,)
+ if dir is not None and dir in ("b", "f"):
+ url += "&dir=%s" % (dir,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], len(expected_user_list))
+
+ returned_order = [row["user_id"] for row in channel.json_body["users"]]
+ self.assertListEqual(expected_user_list, returned_order)
+ self._check_fields(channel.json_body["users"])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 98d0623734..9b2e4765f6 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -17,14 +17,16 @@ import hashlib
import hmac
import json
import urllib.parse
+from binascii import unhexlify
+from typing import Optional
from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v1 import login, logout, profile, room
+from synapse.rest.client.v2_alpha import devices, sync
from tests import unittest
from tests.test_utils import make_awaitable
@@ -33,11 +35,14 @@ from tests.unittest import override_config
class UserRegisterTestCase(unittest.HomeserverTestCase):
- servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ profile.register_servlets,
+ ]
def make_homeserver(self, reactor, clock):
- self.url = "/_matrix/client/r0/admin/register"
+ self.url = "/_synapse/admin/v1/register"
self.registration_handler = Mock()
self.identity_handler = Mock()
@@ -66,8 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
self.hs.config.registration_shared_secret = None
- request, channel = self.make_request("POST", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -84,8 +88,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs.get_secrets = Mock(return_value=secrets)
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
@@ -94,16 +97,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
# 59 seconds
self.reactor.advance(59)
body = json.dumps({"nonce": nonce})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
@@ -111,8 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# 61 seconds
self.reactor.advance(2)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -121,8 +121,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -138,8 +137,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -149,8 +147,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -169,8 +166,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -179,8 +175,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
A valid unrecognised nonce.
"""
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -196,15 +191,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -217,8 +210,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"""
def nonce():
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
return channel.json_body["nonce"]
#
@@ -227,8 +219,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -239,32 +230,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce()})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -275,32 +262,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -318,12 +301,113 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"user_type": "invalid",
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid user type", channel.json_body["error"])
+ def test_displayname(self):
+ """
+ Test that displayname of new user is set
+ """
+
+ # set no displayname
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
+ )
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob1:test", channel.json_body["user_id"])
+
+ channel = self.make_request("GET", "/profile/@bob1:test/displayname")
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("bob1", channel.json_body["displayname"])
+
+ # displayname is None
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob2",
+ "displayname": None,
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ )
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob2:test", channel.json_body["user_id"])
+
+ channel = self.make_request("GET", "/profile/@bob2:test/displayname")
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("bob2", channel.json_body["displayname"])
+
+ # displayname is empty
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob3",
+ "displayname": "",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ )
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob3:test", channel.json_body["user_id"])
+
+ channel = self.make_request("GET", "/profile/@bob3:test/displayname")
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+
+ # set displayname
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob4",
+ "displayname": "Bob's Name",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ )
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob4:test", channel.json_body["user_id"])
+
+ channel = self.make_request("GET", "/profile/@bob4:test/displayname")
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Bob's Name", channel.json_body["displayname"])
+
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
)
@@ -346,8 +430,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
# Register new user with admin API
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -366,8 +449,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"mac": want_mac,
}
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -385,35 +467,125 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.register_user("user1", "pass1", admin=False)
- self.register_user("user2", "pass2", admin=False)
+ self.user1 = self.register_user(
+ "user1", "pass1", admin=False, displayname="Name 1"
+ )
+ self.user2 = self.register_user(
+ "user2", "pass2", admin=False, displayname="Name 2"
+ )
def test_no_auth(self):
"""
Try to list users without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user1", "pass1")
+
+ channel = self.make_request("GET", self.url, access_token=other_user_token)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self):
"""
List all users, including deactivated users.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
self.url + "?deactivated=true",
b"{}",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
+ # Check that all fields are available
+ for u in channel.json_body["users"]:
+ self.assertIn("name", u)
+ self.assertIn("is_guest", u)
+ self.assertIn("admin", u)
+ self.assertIn("user_type", u)
+ self.assertIn("deactivated", u)
+ self.assertIn("displayname", u)
+ self.assertIn("avatar_url", u)
+
+ def test_search_term(self):
+ """Test that searching for a users works correctly"""
+
+ def _search_test(
+ expected_user_id: Optional[str],
+ search_term: str,
+ search_field: Optional[str] = "name",
+ expected_http_code: Optional[int] = 200,
+ ):
+ """Search for a user and check that the returned user's id is a match
+
+ Args:
+ expected_user_id: The user_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for user names with
+ search_field: Field which is to request: `name` or `user_id`
+ expected_http_code: The expected http code for the request
+ """
+ url = self.url + "?%s=%s" % (search_field, search_term,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that users were returned
+ self.assertTrue("users" in channel.json_body)
+ users = channel.json_body["users"]
+
+ # Check that the expected number of users were returned
+ expected_user_count = 1 if expected_user_id else 0
+ self.assertEqual(len(users), expected_user_count)
+ self.assertEqual(channel.json_body["total"], expected_user_count)
+
+ if expected_user_id:
+ # Check that the first returned user id is correct
+ u = users[0]
+ self.assertEqual(expected_user_id, u["name"])
+
+ # Perform search tests
+ _search_test(self.user1, "er1")
+ _search_test(self.user1, "me 1")
+
+ _search_test(self.user2, "er2")
+ _search_test(self.user2, "me 2")
+
+ _search_test(self.user1, "er1", "user_id")
+ _search_test(self.user2, "er2", "user_id")
+
+ # Test case insensitive
+ _search_test(self.user1, "ER1")
+ _search_test(self.user1, "NAME 1")
+
+ _search_test(self.user2, "ER2")
+ _search_test(self.user2, "NAME 2")
+
+ _search_test(self.user1, "ER1", "user_id")
+ _search_test(self.user2, "ER2", "user_id")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+
+ _search_test(None, "foo", "user_id")
+ _search_test(None, "bar", "user_id")
+
class UserRestTestCase(unittest.HomeserverTestCase):
@@ -429,7 +601,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.other_user = self.register_user("user", "pass")
+ self.other_user = self.register_user("user", "pass", displayname="User")
self.other_user_token = self.login("user", "pass")
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
@@ -441,18 +613,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v2/users/@bob:test"
- request, channel = self.make_request(
- "GET", url, access_token=self.other_user_token,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("You are not a server admin", channel.json_body["error"])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, access_token=self.other_user_token, content=b"{}",
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("You are not a server admin", channel.json_body["error"])
@@ -462,12 +630,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_synapse/admin/v2/users/@unknown_person:test",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
@@ -485,17 +652,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- "avatar_url": None,
+ "avatar_url": "mxc://fibble/wibble",
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -503,12 +669,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -518,6 +682,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(True, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
"""
@@ -532,16 +697,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": "mxc://fibble/wibble",
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -549,12 +714,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -564,6 +727,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(False, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -579,10 +743,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Sync to set admin user to active
# before limit of monthly active users is reached
- request, channel = self.make_request(
- "GET", "/sync", access_token=self.admin_user_tok
- )
- self.render(request)
+ channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
if channel.code != 200:
raise HttpResponseException(
@@ -605,13 +766,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123", "admin": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -645,13 +805,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123", "admin": False})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
# Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -683,13 +842,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -701,7 +859,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
- self.assertEqual("@bob:test", pushers[0]["user_name"])
+ self.assertEqual("@bob:test", pushers[0].user_name)
@override_config(
{
@@ -728,13 +886,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -755,13 +912,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Change password
body = json.dumps({"password": "hahaha"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -773,23 +929,21 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Modify user
body = json.dumps({"displayname": "foobar"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -805,13 +959,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
{"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -819,10 +972,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -837,13 +989,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Deactivate user
body = json.dumps({"deactivated": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -851,28 +1002,74 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# the user is deactivated, the threepid will be deleted
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
+ @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+ def test_change_name_deactivate_user_user_directory(self):
+ """
+ Test change profile information of a deactivated user and
+ check that it does not appear in user directory
+ """
+
+ # is in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertTrue(profile["display_name"] == "User")
+
+ # Deactivate user
+ body = json.dumps({"deactivated": True})
+
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+
+ # is not in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertTrue(profile is None)
+
+ # Set new displayname user
+ body = json.dumps({"displayname": "Foobar"})
+
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual("Foobar", channel.json_body["displayname"])
+
+ # is not in user directory
+ profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+ self.assertTrue(profile is None)
+
def test_reactivate_user(self):
"""
Test reactivating another user.
"""
# Deactivate the user.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._is_erased("@user:test", False)
d = self.store.mark_user_erased("@user:test")
@@ -880,17 +1077,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password).
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Reactivate the user.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
@@ -898,14 +1094,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
encoding="utf_8"
),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -920,23 +1114,21 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set a user as an admin
body = json.dumps({"admin": True})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -952,23 +1144,19 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Create user
body = json.dumps({"password": "abc123"})
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
# Get user
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -978,21 +1166,17 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Change password (and use a str for deactivate instead of a bool)
body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Check user is not deactivated
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1016,7 +1200,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
- sync.register_servlets,
room.register_servlets,
]
@@ -1035,8 +1218,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
Try to list rooms of an user without authentication.
"""
- request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
+ channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1047,10 +1229,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
other_user_token = self.login("user", "pass")
- request, channel = self.make_request(
- "GET", self.url, access_token=other_user_token,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1060,10 +1239,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1074,14 +1250,23 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
- request, channel = self.make_request(
- "GET", url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ def test_no_memberships(self):
+ """
+ Tests that a normal lookup for rooms is successfully
+ if user has no memberships
+ """
+ # Get rooms
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["joined_rooms"]))
+
def test_get_rooms(self):
"""
Tests that a normal lookup for rooms is successfully
@@ -1093,11 +1278,675 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.helper.create_room_as(self.other_user, tok=other_user_tok)
# Get rooms
- request, channel = self.make_request(
- "GET", self.url, access_token=self.admin_user_tok,
- )
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
+
+
+class PushersRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list pushers of an user without authentication.
+ """
+ channel = self.make_request("GET", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
+
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_pushers(self):
+ """
+ Tests that a normal lookup for pushers is successfully
+ """
+
+ # Get pushers
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ # Register the pusher
+ other_user_token = self.login("user", "pass")
+ user_tuple = self.get_success(
+ self.store.get_user_by_access_token(other_user_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=self.other_user,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "https://example.com/_matrix/push/v1/notify"},
+ )
+ )
+
+ # Get pushers
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+
+ for p in channel.json_body["pushers"]:
+ self.assertIn("pushkey", p)
+ self.assertIn("kind", p)
+ self.assertIn("app_id", p)
+ self.assertIn("app_display_name", p)
+ self.assertIn("device_display_name", p)
+ self.assertIn("profile_tag", p)
+ self.assertIn("lang", p)
+ self.assertIn("url", p["data"])
+
+
+class UserMediaRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.media_repo = hs.get_media_repository_resource()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/media" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list media of an user without authentication.
+ """
+ channel = self.make_request("GET", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request("GET", self.url, access_token=other_user_token,)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/media"
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
+
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_limit(self):
+ """
+ Testing list of media with limit
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["media"])
+
+ def test_from(self):
+ """
+ Testing list of media with a defined starting point (from)
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["media"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of media with a defined starting point and limit
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["media"]), 10)
+ self._check_fields(channel.json_body["media"])
+
+ def test_limit_is_negative(self):
+ """
+ Testing that a negative limit parameter returns a 400
+ """
+
+ channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_from_is_negative(self):
+ """
+ Testing that a negative from parameter returns a 400
+ """
+
+ channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), number_media)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), number_media)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def test_user_has_no_media(self):
+ """
+ Tests that a normal lookup for media is successfully
+ if user has no media created
+ """
+
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["media"]))
+
+ def test_get_media(self):
+ """
+ Tests that a normal lookup for media is successfully
+ """
+
+ number_media = 5
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_media, channel.json_body["total"])
+ self.assertEqual(number_media, len(channel.json_body["media"]))
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["media"])
+
+ def _create_media(self, user_token, number_media):
+ """
+ Create a number of media for a specific user
+ """
+ upload_resource = self.media_repo.children[b"upload"]
+ for i in range(number_media):
+ # file size is 67 Byte
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ # Upload some media into the room
+ self.helper.upload_media(
+ upload_resource, image_data, tok=user_token, expect_code=200
+ )
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in content
+ """
+ for m in content:
+ self.assertIn("media_id", m)
+ self.assertIn("media_type", m)
+ self.assertIn("media_length", m)
+ self.assertIn("upload_name", m)
+ self.assertIn("created_ts", m)
+ self.assertIn("last_access_ts", m)
+ self.assertIn("quarantined_by", m)
+ self.assertIn("safe_from_quarantine", m)
+
+
+class UserTokenRestTestCase(unittest.HomeserverTestCase):
+ """Test for /_synapse/admin/v1/users/<user>/login
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ devices.register_servlets,
+ logout.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def _get_token(self) -> str:
+ channel = self.make_request(
+ "POST", self.url, b"{}", access_token=self.admin_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ return channel.json_body["access_token"]
+
+ def test_no_auth(self):
+ """Try to login as a user without authentication.
+ """
+ channel = self.make_request("POST", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_not_admin(self):
+ """Try to login as a user as a non-admin user.
+ """
+ channel = self.make_request(
+ "POST", self.url, b"{}", access_token=self.other_user_tok
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_send_event(self):
+ """Test that sending event as a user works.
+ """
+ # Create a room.
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
+
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that sending works, and generates the event as the right user.
+ resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
+ event_id = resp["event_id"]
+ event = self.get_success(self.store.get_event(event_id))
+ self.assertEqual(event.sender, self.other_user)
+
+ def test_devices(self):
+ """Tests that logging in as a user doesn't create a new device for them.
+ """
+ # Login in as the user
+ self._get_token()
+
+ # Check that we don't see a new device in our devices list
+ channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # We should only see the one device (from the login in `prepare`)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+
+ def test_logout(self):
+ """Test that calling `/logout` with the token works.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout with the puppet token
+ channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should no longer work
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens should still work
+ channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_user_logout_all(self):
+ """Tests that the target user calling `/logout/all` does *not* expire
+ the token.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout all with the real user token
+ channel = self.make_request(
+ "POST", "logout/all", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should still work
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens shouldn't
+ channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_admin_logout_all(self):
+ """Tests that the admin user calling `/logout/all` does expire the
+ token.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout all with the admin user token
+ channel = self.make_request(
+ "POST", "logout/all", b"{}", access_token=self.admin_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should no longer work
+ channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens should still work
+ channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://example.org/",
+ "user_consent": {
+ "version": "1.0",
+ "policy_name": "My Cool Privacy Policy",
+ "template_dir": "/",
+ "require_at_registration": True,
+ "block_events_error": "You should accept the policy",
+ },
+ "form_secret": "123secret",
+ }
+ )
+ def test_consent(self):
+ """Test that sending a message is not subject to the privacy policies.
+ """
+ # Have the admin user accept the terms.
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
+
+ # First, cheekily accept the terms and create a room
+ self.get_success(self.store.user_set_consent_version(self.other_user, "1.0"))
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
+ self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok)
+
+ # Now unaccept it and check that we can't send an event
+ self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
+ self.helper.send_event(
+ room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Sending an event on their behalf should work fine
+ self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
+ )
+ def test_mau_limit(self):
+ # Create a room as the admin user. This will bump the monthly active users to 1.
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Trying to join as the other user should fail due to reaching MAU limit.
+ self.helper.join(
+ room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
+ )
+
+ # Logging in as the other user and joining a room should work, even
+ # though the MAU limit would stop the user doing so.
+ puppet_token = self._get_token()
+ self.helper.join(room_id, user=self.other_user, tok=puppet_token)
+
+
+class WhoisRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user)
+ self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get information of an user without authentication.
+ """
+ channel = self.make_request("GET", self.url1, b"{}")
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ channel = self.make_request("GET", self.url2, b"{}")
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.register_user("user2", "pass")
+ other_user2_token = self.login("user2", "pass")
+
+ channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+ url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
+
+ channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only whois a local user", channel.json_body["error"])
+
+ channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only whois a local user", channel.json_body["error"])
+
+ def test_get_whois_admin(self):
+ """
+ The lookup should succeed for an admin.
+ """
+ channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ def test_get_whois_user(self):
+ """
+ The lookup should succeed for a normal user looking up their own information.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request("GET", self.url1, access_token=other_user_token,)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ channel = self.make_request("GET", self.url2, access_token=other_user_token,)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 6803b372ac..c74693e9b2 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource
from tests import unittest
-from tests.server import render
+from tests.server import FakeSite, make_request
class ConsentResourceTestCase(unittest.HomeserverTestCase):
@@ -61,8 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
- request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
- render(request, resource, self.reactor)
+ channel = make_request(
+ self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+ )
self.assertEqual(channel.code, 200)
def test_accept_consent(self):
@@ -81,10 +82,14 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user"
)
- request, channel = self.make_request(
- "GET", consent_uri, access_token=access_token, shorthand=False
+ channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ consent_uri,
+ access_token=access_token,
+ shorthand=False,
)
- render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
# Get the version from the body, and whether we've consented
@@ -92,21 +97,26 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed
- request, channel = self.make_request(
+ channel = make_request(
+ self.reactor,
+ FakeSite(resource),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
shorthand=False,
)
- render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
# Fetch the consent page, to get the consent version -- it should have
# changed
- request, channel = self.make_request(
- "GET", consent_uri, access_token=access_token, shorthand=False
+ channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ consent_uri,
+ access_token=access_token,
+ shorthand=False,
)
- render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
# Get the version from the body, and check that it's the version we
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 5e9c07ebf3..56937dcd2e 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -93,8 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
- request, channel = self.make_request("GET", url)
- self.render(request)
+ channel = self.make_request("GET", url)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..c0a9fc6925 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -43,10 +43,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
- request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
- )
- self.render(request)
+ channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
@@ -57,8 +54,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
}
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", request_url, request_data, access_token=tok
)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index d2bcf256fa..f0707646bb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -69,18 +69,12 @@ class RedactionsTestCase(HomeserverTestCase):
"""
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
- request, channel = self.make_request(
- "POST", path, content={}, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request("POST", path, content={}, access_token=access_token)
self.assertEqual(int(channel.result["code"]), expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token, room_id):
- request, channel = self.make_request(
- "GET", "sync", access_token=self.mod_access_token
- )
- self.render(request)
+ channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200")
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 7d3773ff78..31dc832fd5 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -325,8 +325,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ channel = self.make_request("GET", url, access_token=self.token)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index dfe4bf7762..e689c3fbea 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -78,7 +78,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
- identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler = self.hs.get_identity_handler()
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)
@@ -89,13 +89,12 @@ class RoomTestCase(_ShadowBannedBase):
)
# Inviting the user completes successfully.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# This should have raised an error earlier, but double check this wasn't called.
@@ -104,13 +103,12 @@ class RoomTestCase(_ShadowBannedBase):
def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]
@@ -160,13 +158,12 @@ class RoomTestCase(_ShadowBannedBase):
self.banned_user_id, tok=self.banned_access_token
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# A new room_id should be returned.
self.assertIn("replacement_room", channel.json_body)
@@ -186,13 +183,12 @@ class RoomTestCase(_ShadowBannedBase):
self.banned_user_id, tok=self.banned_access_token
)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
{"typing": True, "timeout": 30000},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code)
# There should be no typing events.
@@ -202,13 +198,12 @@ class RoomTestCase(_ShadowBannedBase):
# The other user can join and send typing events.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (room_id, self.other_user_id),
{"typing": True, "timeout": 30000},
access_token=self.other_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code)
# These appear in the room.
@@ -249,21 +244,19 @@ class ProfileTestCase(_ShadowBannedBase):
)
# The update should succeed.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
{"displayname": new_display_name},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(channel.json_body, {})
# The user's display name should be updated.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["displayname"], new_display_name)
@@ -289,14 +282,13 @@ class ProfileTestCase(_ShadowBannedBase):
)
# The update should succeed.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (room_id, self.banned_user_id),
{"membership": "join", "displayname": new_display_name},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertIn("event_id", channel.json_body)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
new file mode 100644
index 0000000000..227fffab58
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+from typing import Dict
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, StateMap
+
+from tests import unittest
+
+thread_local = threading.local()
+
+
+class ThirdPartyRulesTestModule:
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ # keep a record of the "current" rules module, so that the test can patch
+ # it if desired.
+ thread_local.rules_module = self
+ self.module_api = module_api
+
+ async def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return True
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+
+def current_rules_module() -> ThirdPartyRulesTestModule:
+ return thread_local.rules_module
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["third_party_event_rules"] = {
+ "module": __name__ + ".ThirdPartyRulesTestModule",
+ "config": {},
+ }
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ # Create a user and room to play with during the tests
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_third_party_rules(self):
+ """Tests that a forbidden event is forbidden from being sent, but an allowed one
+ can be sent.
+ """
+ # patch the rules module with a Mock which will return False for some event
+ # types
+ async def check(ev, state):
+ return ev.type != "foo.bar.forbidden"
+
+ callback = Mock(spec=[], side_effect=check)
+ current_rules_module().check_event_allowed = callback
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ callback.assert_called_once()
+
+ # there should be various state events in the state arg: do some basic checks
+ state_arg = callback.call_args[0][1]
+ for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+ self.assertIn(k, state_arg)
+ ev = state_arg[k]
+ self.assertEqual(ev.type, k[0])
+ self.assertEqual(ev.state_key, k[1])
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_cannot_modify_event(self):
+ """cannot accidentally modify an event before it is persisted"""
+
+ # first patch the event checker so that it will try to modify the event
+ async def check(ev: EventBase, state):
+ ev.content = {"x": "y"}
+ return True
+
+ current_rules_module().check_event_allowed = check
+
+ # now send the event
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"500", channel.result)
+
+ def test_modify_event(self):
+ """The module can return a modified version of the event"""
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ d = ev.get_dict()
+ d["content"] = {"x": "y"}
+ return d
+
+ current_rules_module().check_event_allowed = check
+
+ # now send the event
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ event_id = channel.json_body["event_id"]
+
+ # ... and check that it got modified
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "y")
+
+ def test_send_event(self):
+ """Tests that the module can send an event into a room via the module api"""
+ content = {
+ "msgtype": "m.text",
+ "body": "Hello!",
+ }
+ event_dict = {
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "content": content,
+ "sender": self.user_id,
+ }
+ event = self.get_success(
+ current_rules_module().module_api.create_and_send_event_into_room(
+ event_dict
+ )
+ ) # type: EventBase
+
+ self.assertEquals(event.sender, self.user_id)
+ self.assertEquals(event.room_id, self.room_id)
+ self.assertEquals(event.type, "m.room.message")
+ self.assertEquals(event.content, content)
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
deleted file mode 100644
index 8c24add530..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-
-from tests import unittest
-
-
-class ThirdPartyRulesTestModule:
- def __init__(self, config):
- pass
-
- def check_event_allowed(self, event, context):
- if event.type == "foo.bar.forbidden":
- return False
- else:
- return True
-
- @staticmethod
- def parse_config(config):
- return config
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
- config["third_party_event_rules"] = {
- "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
- "config": {},
- }
-
- self.hs = self.setup_test_homeserver(config=config)
- return self.hs
-
- def test_third_party_rules(self):
- """Tests that a forbidden event is forbidden from being sent, but an allowed one
- can be sent.
- """
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- room_id = self.helper.create_room_as(user_id, tok=tok)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"200", channel.result)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 633b7dbda0..edd1d184f8 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -21,6 +21,7 @@ from synapse.types import RoomAlias
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.unittest import override_config
class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -67,10 +68,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.ensure_user_joined_room()
self.set_alias_via_directory(400, alias_length=256)
- def test_state_event_in_room(self):
+ @override_config({"default_room_version": 5})
+ def test_state_event_user_in_v5_room(self):
+ """Test that a regular user can add alias events before room v6"""
self.ensure_user_joined_room()
self.set_alias_via_state_event(200)
+ @override_config({"default_room_version": 6})
+ def test_state_event_v6_room(self):
+ """Test that a regular user can *not* add alias events from room v6"""
+ self.ensure_user_joined_room()
+ self.set_alias_via_state_event(403)
+
def test_directory_in_room(self):
self.ensure_user_joined_room()
self.set_alias_via_directory(200)
@@ -82,10 +91,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# that we can make sure that the check is done on the whole alias.
data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
def test_room_creation(self):
@@ -96,10 +104,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# as cautious as possible here.
data = {"room_alias_name": random_string(5)}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
def set_alias_via_state_event(self, expected_code, alias_length=5):
@@ -111,10 +118,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
data = {"aliases": [self.random_alias(alias_length)]}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def set_alias_via_directory(self, expected_code, alias_length=5):
@@ -122,10 +128,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def random_alias(self, length):
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index f75520877f..0a5ca317ea 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -42,7 +42,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
- hs.get_handlers().federation_handler = Mock()
+ hs.get_federation_handler = Mock()
return hs
@@ -63,17 +63,15 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# implementation is now part of the r0 implementation, the newer
# behaviour is used instead to be consistent with the r0 spec.
# see issue #2602
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s" % ("invalid" + self.token,)
)
- self.render(request)
self.assertEquals(channel.code, 401, msg=channel.result)
# valid token, expect content
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
- self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("start" in channel.json_body)
@@ -89,10 +87,9 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
)
# valid token, expect content
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
- self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)
# We may get a presence event for ourselves down
@@ -152,8 +149,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/events/" + event_id, access_token=self.token,
)
- self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 5d987a30c7..18932d7518 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -63,8 +63,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -83,8 +82,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -110,8 +108,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -130,8 +127,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -157,8 +153,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -177,8 +172,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.result["code"], b"403", channel.result)
@@ -187,8 +181,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.register_user("kermit", "monkey")
# we shouldn't be able to make requests without an access token
- request, channel = self.make_request(b"GET", TEST_URL)
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL)
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
@@ -198,28 +191,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEquals(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -233,10 +219,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -244,20 +227,16 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# ... but if we delete that device, it will be a proper logout
self._delete_device(access_token_2, "kermit", "monkey", device_id)
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], False)
def _delete_device(self, access_token, user_id, password, device_id):
"""Perform the UI-Auth to delete a device"""
- request, channel = self.make_request(
+ channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 401, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
@@ -275,13 +254,12 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"session": channel.json_body["session"],
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"DELETE",
"devices/" + device_id,
access_token=access_token,
content={"auth": auth},
)
- self.render(request)
self.assertEquals(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"})
@@ -292,29 +270,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard logout this session
- request, channel = self.make_request(
- b"POST", "/logout", access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"session_lifetime": "24h"})
@@ -325,29 +294,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
- request, channel = self.make_request(
- b"GET", TEST_URL, access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions
- request, channel = self.make_request(
- b"POST", "/logout/all", access_token=access_token
- )
- self.render(request)
+ channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -422,8 +382,7 @@ class CASTestCase(unittest.HomeserverTestCase):
cas_ticket_url = urllib.parse.urlunparse(url_parts)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
- self.render(request)
+ channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
self.assertEqual(channel.code, 200)
@@ -467,8 +426,7 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
- self.render(request)
+ channel = self.make_request("GET", cas_ticket_url)
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
@@ -494,8 +452,7 @@ class CASTestCase(unittest.HomeserverTestCase):
)
# Get Synapse to call the fake CAS and serve the template.
- request, channel = self.make_request("GET", cas_ticket_url)
- self.render(request)
+ channel = self.make_request("GET", cas_ticket_url)
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
@@ -518,15 +475,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
- def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
+ def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(token, secret, self.jwt_algorithm)
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
params = json.dumps(
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid_registered(self):
@@ -658,8 +618,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_token(self):
params = json.dumps({"type": "org.matrix.login.jwt"})
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -725,15 +684,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = "RS256"
return self.hs
- def jwt_encode(self, token, secret=jwt_privatekey):
- return jwt.encode(token, secret, "RS256").decode("ascii")
+ def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(token, secret, "RS256")
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
params = json.dumps(
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
return channel
def test_login_jwt_valid(self):
@@ -761,12 +723,11 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
]
def register_as_user(self, username):
- request, channel = self.make_request(
+ self.make_request(
b"POST",
"/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
{"username": username},
)
- self.render(request)
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
@@ -811,11 +772,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_user_bot(self):
@@ -827,11 +787,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": self.service.sender},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_wrong_user(self):
@@ -843,11 +802,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": "fibble_wibble"},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_wrong_as(self):
@@ -859,11 +817,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_no_token(self):
@@ -876,7 +833,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"type": login.LoginRestServlet.APPSERVICE_TYPE,
"identifier": {"type": "m.id.user", "user": AS_USER},
}
- request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 3c66255dac..94a5154834 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -33,13 +33,16 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ presence_handler = Mock()
+ presence_handler.set_state.return_value = defer.succeed(None)
+
hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
+ presence_handler=presence_handler,
)
- hs.presence_handler = Mock()
- hs.presence_handler.set_state.return_value = defer.succeed(None)
-
return hs
def test_put_presence(self):
@@ -50,13 +53,12 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.hs.config.use_presence = True
body = {"presence": "here", "status_msg": "beep boop"}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
- self.render(request)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
def test_put_presence_disabled(self):
"""
@@ -66,10 +68,9 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.hs.config.use_presence = False
body = {"presence": "here", "status_msg": "beep boop"}
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
- self.render(request)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index ace0a3c08d..e59fa70baa 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
- http_client=None,
+ federation_http_client=None,
resource_for_client=self.mock_resource,
federation=Mock(),
federation_client=Mock(),
@@ -189,13 +189,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.owner_tok = self.login("owner", "pass")
def test_set_displayname(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test"}),
access_token=self.owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
res = self.get_displayname()
@@ -203,23 +202,19 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400"""
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test" * 100}),
access_token=self.owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
res = self.get_displayname()
self.assertEqual(res, "owner")
def get_displayname(self):
- request, channel = self.make_request(
- "GET", "/profile/%s/displayname" % (self.owner,)
- )
- self.render(request)
+ channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
@@ -281,10 +276,9 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
)
def request_profile(self, expected_code, url_suffix="", access_token=None):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.profile_url + url_suffix, access_token=access_token
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def ensure_requester_left_room(self):
@@ -324,24 +318,21 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
"""Tests that a user can lookup their own profile without having to be in a room
if 'require_auth_for_profile_requests' is set to true in the server's config.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/profile/" + self.requester, access_token=self.requester_tok
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/profile/" + self.requester + "/displayname",
access_token=self.requester_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/profile/" + self.requester + "/avatar_url",
access_token=self.requester_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 081052f6a6..2bc512d75e 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -45,17 +45,15 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
@@ -76,49 +74,43 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# disable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": False},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check rule disabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
@@ -138,45 +130,40 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# disable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": False},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check rule disabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
# re-enable the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": True},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check rule enabled
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
@@ -195,39 +182,34 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -239,10 +221,9 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -254,13 +235,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/enabled",
{"enabled": True},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -272,13 +252,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/.m.muahahah/enabled",
{"enabled": True},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -297,17 +276,15 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET actions for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
@@ -328,27 +305,24 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# change the rule actions
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/actions",
{"actions": ["dont_notify"]},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET actions for that new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["actions"], ["dont_notify"])
@@ -367,32 +341,28 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
}
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# PUT a new rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# DELETE the rule
- request, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -404,10 +374,9 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -419,13 +388,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/best.friend/actions",
{"actions": ["dont_notify"]},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -437,12 +405,11 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
token = self.login("user", "pass")
# enable & check 404 for never-heard-of rule
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/pushrules/global/override/.m.muahahah/actions",
{"actions": ["dont_notify"]},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0d809d25d5..6105eac47c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -26,12 +26,14 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
+from synapse.rest import admin
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.test_utils import make_awaitable
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -44,10 +46,13 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock(),
+ "red", federation_http_client=None, federation_client=Mock(),
)
- self.hs.get_federation_handler = Mock(return_value=Mock())
+ self.hs.get_federation_handler = Mock()
+ self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
+ return_value=make_awaitable(None)
+ )
async def _insert_client_ip(*args, **kwargs):
return None
@@ -79,19 +84,17 @@ class RoomPermissionsTestCase(RoomBase):
self.created_rmid_msg_path = (
"rooms/%s/send/m.room.message/a1" % (self.created_rmid)
).encode("ascii")
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# set topic for public room
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# auth as user_id now
@@ -109,37 +112,32 @@ class RoomPermissionsTestCase(RoomBase):
)
# send message in uncreated room, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
+ channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self):
@@ -147,36 +145,30 @@ class RoomPermissionsTestCase(RoomBase):
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
- request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", topic_path)
- self.render(request)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
- request, channel = self.make_request("GET", topic_path)
- self.render(request)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
@@ -184,46 +176,39 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
- request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
- request, channel = self.make_request("GET", topic_path)
- self.render(request)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
+ channel = self.make_request("PUT", topic_path, topic_content)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", topic_path)
- self.render(request)
+ channel = self.make_request("GET", topic_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
- request, channel = self.make_request("GET", path)
- self.render(request)
+ channel = self.make_request("GET", path)
self.assertEquals(expect_code, channel.code)
def test_membership_basic_room_perms(self):
@@ -395,19 +380,16 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.render(request)
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEquals(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self):
- request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
- self.render(request)
+ channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self):
room_id = self.helper.create_room_as("@some_other_guy:red")
- request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.render(request)
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self):
@@ -416,20 +398,17 @@ class RoomsMemberListTestCase(RoomBase):
room_path = "/rooms/%s/members" % room_id
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
- request, channel = self.make_request("GET", room_path)
- self.render(request)
+ channel = self.make_request("GET", room_path)
self.assertEquals(403, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
- request, channel = self.make_request("GET", room_path)
- self.render(request)
+ channel = self.make_request("GET", room_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
- request, channel = self.make_request("GET", room_path)
- self.render(request)
+ channel = self.make_request("GET", room_path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -440,56 +419,45 @@ class RoomsCreateTestCase(RoomBase):
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
- request, channel = self.make_request("POST", "/createRoom", "{}")
+ channel = self.make_request("POST", "/createRoom", "{}")
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id
- request, channel = self.make_request(
- "POST", "/createRoom", b'{"visibility":"private"}'
- )
- self.render(request)
+ channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id
- request, channel = self.make_request(
- "POST", "/createRoom", b'{"custom":"stuff"}'
- )
- self.render(request)
+ channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
- request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
- self.render(request)
+ channel = self.make_request("POST", "/createRoom", b'{"visibili')
self.assertEquals(400, channel.code)
- request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
- self.render(request)
+ channel = self.make_request("POST", "/createRoom", b'["hello"]')
self.assertEquals(400, channel.code)
def test_post_room_invitees_invalid_mxid(self):
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
# Note the trailing space in the MXID here!
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
- self.render(request)
self.assertEquals(400, channel.code)
@@ -505,66 +473,54 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self):
# missing keys or invalid json
- request, channel = self.make_request("PUT", self.path, "{}")
- self.render(request)
+ channel = self.make_request("PUT", self.path, "{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
- self.render(request)
+ channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, '{"nao')
- self.render(request)
+ channel = self.make_request("PUT", self.path, '{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, "text only")
- self.render(request)
+ channel = self.make_request("PUT", self.path, "text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, "")
- self.render(request)
+ channel = self.make_request("PUT", self.path, "")
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
- request, channel = self.make_request("PUT", self.path, content)
- self.render(request)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self):
# nothing should be there
- request, channel = self.make_request("GET", self.path)
- self.render(request)
+ channel = self.make_request("GET", self.path)
self.assertEquals(404, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
- request, channel = self.make_request("PUT", self.path, content)
- self.render(request)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = self.make_request("GET", self.path)
- self.render(request)
+ channel = self.make_request("GET", self.path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
- request, channel = self.make_request("PUT", self.path, content)
- self.render(request)
+ channel = self.make_request("PUT", self.path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = self.make_request("GET", self.path)
- self.render(request)
+ channel = self.make_request("GET", self.path)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -580,30 +536,22 @@ class RoomMemberStateTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, "{}")
- self.render(request)
+ channel = self.make_request("PUT", path, "{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
- self.render(request)
+ channel = self.make_request("PUT", path, '{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, '{"nao')
- self.render(request)
+ channel = self.make_request("PUT", path, '{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
- "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
- )
- self.render(request)
+ channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, "text only")
- self.render(request)
+ channel = self.make_request("PUT", path, "text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, "")
- self.render(request)
+ channel = self.make_request("PUT", path, "")
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types
@@ -612,8 +560,7 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN,
Membership.LEAVE,
)
- request, channel = self.make_request("PUT", path, content.encode("ascii"))
- self.render(request)
+ channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self):
@@ -624,12 +571,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
- request, channel = self.make_request("PUT", path, content.encode("ascii"))
- self.render(request)
+ channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
- self.render(request)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
@@ -644,12 +589,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
- request, channel = self.make_request("PUT", path, content)
- self.render(request)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
- self.render(request)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@@ -665,12 +608,10 @@ class RoomMemberStateTestCase(RoomBase):
Membership.INVITE,
"Join us!",
)
- request, channel = self.make_request("PUT", path, content)
- self.render(request)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("GET", path, None)
- self.render(request)
+ channel = self.make_request("GET", path, None)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@@ -679,6 +620,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
servlets = [
+ admin.register_servlets,
profile.register_servlets,
room.register_servlets,
]
@@ -720,8 +662,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
- request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
- self.render(request)
+ channel = self.make_request("PUT", path, {"displayname": "John Doe"})
self.assertEquals(channel.code, 200, channel.json_body)
# Check that all the rooms have been sent a profile update into.
@@ -731,8 +672,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.user_id,
)
- request, channel = self.make_request("GET", path)
- self.render(request)
+ channel = self.make_request("GET", path)
self.assertEquals(channel.code, 200)
self.assertIn("displayname", channel.json_body)
@@ -756,10 +696,23 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
for i in range(4):
- request, channel = self.make_request("POST", path % room_id, {})
- self.render(request)
+ channel = self.make_request("POST", path % room_id, {})
self.assertEquals(channel.code, 200)
+ @unittest.override_config(
+ {
+ "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
+ "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
+ "autocreate_auto_join_rooms": True,
+ },
+ )
+ def test_autojoin_rooms(self):
+ user_id = self.register_user("testuser", "password")
+
+ # Check that the new user successfully joined the four rooms
+ rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 4)
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -772,51 +725,40 @@ class RoomMessagesTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, b"{}")
- self.render(request)
+ channel = self.make_request("PUT", path, b"{}")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
- self.render(request)
+ channel = self.make_request("PUT", path, b'{"_name":"bo"}')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'{"nao')
- self.render(request)
+ channel = self.make_request("PUT", path, b'{"nao')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request(
- "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
- )
- self.render(request)
+ channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b"text only")
- self.render(request)
+ channel = self.make_request("PUT", path, b"text only")
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b"")
- self.render(request)
+ channel = self.make_request("PUT", path, b"")
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
- request, channel = self.make_request("PUT", path, content)
- self.render(request)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
- request, channel = self.make_request("PUT", path, content)
- self.render(request)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
- request, channel = self.make_request("PUT", path, content)
- self.render(request)
+ channel = self.make_request("PUT", path, content)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -830,10 +772,7 @@ class RoomInitialSyncTestCase(RoomBase):
self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self):
- request, channel = self.make_request(
- "GET", "/rooms/%s/initialSync" % self.room_id
- )
- self.render(request)
+ channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"])
@@ -874,10 +813,9 @@ class RoomMessageListTestCase(RoomBase):
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body["start"])
@@ -886,10 +824,9 @@ class RoomMessageListTestCase(RoomBase):
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body["start"])
@@ -920,7 +857,7 @@ class RoomMessageListTestCase(RoomBase):
self.helper.send(self.room_id, "message 3")
# Check that we get the first and second message when querying /messages.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -929,7 +866,6 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
@@ -949,7 +885,7 @@ class RoomMessageListTestCase(RoomBase):
# Check that we only get the second message through /message now that the first
# has been purged.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -958,7 +894,6 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
@@ -967,7 +902,7 @@ class RoomMessageListTestCase(RoomBase):
# Check that we get no event, but also no error, when querying /messages with
# the token that was pointing at the first event, because we don't have it
# anymore.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (
@@ -976,7 +911,6 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
@@ -1027,7 +961,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
@@ -1036,7 +970,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
}
},
)
- self.render(request)
# Check we get the results we expect -- one search result, of the sent
# messages
@@ -1057,7 +990,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
@@ -1070,7 +1003,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
}
},
)
- self.render(request)
# Check we get the results we expect -- one search result, of the sent
# messages
@@ -1106,16 +1038,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
return self.hs
def test_restricted_no_auth(self):
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401, channel.result)
def test_restricted_auth(self):
self.register_user("user", "pass")
tok = self.login("user", "pass")
- request, channel = self.make_request("GET", self.url, access_token=tok)
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=tok)
self.assertEqual(channel.code, 200, channel.result)
@@ -1143,13 +1073,12 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.displayname = "test user"
data = {"displayname": self.displayname}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -1157,23 +1086,21 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
def test_per_room_profile_forbidden(self):
data = {"membership": "join", "displayname": "other test user"}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (self.room_id, self.user_id),
request_data,
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
res_displayname = channel.json_body["content"]["displayname"]
@@ -1202,13 +1129,12 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
def test_join_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/join".format(self.room_id),
content={"reason": reason},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1217,13 +1143,12 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
content={"reason": reason},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1232,13 +1157,12 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1247,39 +1171,36 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
def test_unban_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
def test_invite_reason(self):
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1293,26 +1214,24 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
)
reason = "hello"
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
content={"reason": reason},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
def _check_for_reason(self, reason):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
self.room_id, self.second_user_id
),
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
event_content = channel.json_body
@@ -1355,13 +1274,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1386,13 +1304,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by the absence of a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1422,13 +1339,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""
event_id = self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/context/%s?filter=%s"
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1451,12 +1367,11 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
)
- self.render(request)
events = channel.json_body["chunk"]
@@ -1469,12 +1384,11 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
)
- self.render(request)
events = channel.json_body["chunk"]
@@ -1493,7 +1407,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (
@@ -1503,7 +1417,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
json.dumps(self.FILTER_LABELS_NOT_LABELS),
),
)
- self.render(request)
events = channel.json_body["chunk"]
@@ -1525,10 +1438,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
- self.render(request)
results = channel.json_body["search_categories"]["room_events"]["results"]
@@ -1561,10 +1473,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
- self.render(request)
results = channel.json_body["search_categories"]["room_events"]["results"]
@@ -1609,10 +1520,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self._send_labelled_messages_in_room()
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
- self.render(request)
results = channel.json_body["search_categories"]["room_events"]["results"]
@@ -1731,13 +1641,12 @@ class ContextTestCase(unittest.HomeserverTestCase):
# Check that we can still see the messages before the erasure request.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
'/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
% (self.room_id, event_id),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1796,13 +1705,12 @@ class ContextTestCase(unittest.HomeserverTestCase):
# Check that a user that joined the room after the erasure request can't see
# the messages anymore.
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
'/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
% (self.room_id, event_id),
access_token=invited_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1887,13 +1795,12 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases"
% (self.room_id,),
access_token=access_token,
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
@@ -1909,10 +1816,9 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -1940,20 +1846,18 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
data = {"room_id": self.room_id}
request_data = json.dumps(data)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
access_token=self.room_owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
@@ -1961,13 +1865,12 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
json.dumps(content),
access_token=self.room_owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 94d2bf2eb1..38c51525a3 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -39,12 +39,12 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock(),
+ "red", federation_http_client=None, federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
- hs.get_handlers().federation_handler = Mock()
+ hs.get_federation_handler = Mock()
async def get_user_by_access_token(token=None, allow_guest=False):
return {
@@ -94,12 +94,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user="@jim:red")
def test_set_typing(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
@@ -118,21 +117,19 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
)
def test_set_not_typing(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": false}',
)
- self.render(request)
self.assertEquals(200, channel.code)
def test_typing_timeout(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
@@ -141,12 +138,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 2)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 3)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index afaf9f7b85..dbc27893b5 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,16 +17,23 @@
# limitations under the License.
import json
+import re
import time
+import urllib.parse
from typing import Any, Dict, Optional
+from mock import patch
+
import attr
from twisted.web.resource import Resource
+from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.types import JsonDict
-from tests.server import make_request, render
+from tests.server import FakeSite, make_request
+from tests.test_utils import FakeResponse
@attr.s
@@ -36,25 +43,51 @@ class RestHelper:
"""
hs = attr.ib()
- resource = attr.ib()
+ site = attr.ib(type=Site)
auth_user_id = attr.ib()
def create_room_as(
- self, room_creator=None, is_public=True, tok=None, expect_code=200,
- ):
+ self,
+ room_creator: str = None,
+ is_public: bool = True,
+ room_version: str = None,
+ tok: str = None,
+ expect_code: int = 200,
+ ) -> str:
+ """
+ Create a room.
+
+ Args:
+ room_creator: The user ID to create the room with.
+ is_public: If True, the `visibility` parameter will be set to the
+ default (public). Otherwise, the `visibility` parameter will be set
+ to "private".
+ room_version: The room version to create the room as. Defaults to Synapse's
+ default room version.
+ tok: The access token to use in the request.
+ expect_code: The expected HTTP response code.
+
+ Returns:
+ The ID of the newly created room.
+ """
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
content = {}
if not is_public:
content["visibility"] = "private"
+ if room_version:
+ content["room_version"] = room_version
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request(
- self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ path,
+ json.dumps(content).encode("utf8"),
)
- render(request, self.resource, self.hs.get_reactor())
assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
@@ -124,12 +157,14 @@ class RestHelper:
data = {"membership": membership}
data.update(extra_data)
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "PUT",
+ path,
+ json.dumps(data).encode("utf8"),
)
- render(request, self.resource, self.hs.get_reactor())
-
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
@@ -157,10 +192,13 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "PUT",
+ path,
+ json.dumps(content).encode("utf8"),
)
- render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@@ -210,9 +248,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- request, channel = make_request(self.hs.get_reactor(), method, path, content)
-
- render(request, self.resource, self.hs.get_reactor())
+ channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@@ -295,14 +331,15 @@ class RestHelper:
"""
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
- request, channel = make_request(
- self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
- )
- request.requestHeaders.addRawHeader(
- b"Content-Length", str(image_length).encode("UTF-8")
+ channel = make_request(
+ self.hs.get_reactor(),
+ FakeSite(resource),
+ "POST",
+ path,
+ content=image_data,
+ access_token=tok,
+ custom_headers=[(b"Content-Length", str(image_length))],
)
- request.render(resource)
- self.hs.get_reactor().pump([100])
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -311,3 +348,111 @@ class RestHelper:
)
return channel.json_body
+
+ def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ """Log in (as a new user) via OIDC
+
+ Returns the result of the final token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+ client_redirect_url = "https://x"
+
+ # first hit the redirect url (which will issue a cookie and state)
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+ )
+ # that will redirect to the OIDC IdP, but we skip that and go straight
+ # back to synapse's OIDC callback resource. However, we do need the "state"
+ # param that synapse passes to the IdP via query params, and the cookie that
+ # synapse passes to the client.
+ assert channel.code == 302
+ oauth_uri = channel.headers.getRawHeaders("Location")[0]
+ params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
+ redirect_uri = "%s?%s" % (
+ urllib.parse.urlparse(params["redirect_uri"][0]).path,
+ urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+ )
+ cookies = {}
+ for h in channel.headers.getRawHeaders("Set-Cookie"):
+ parts = h.split(";")
+ k, v = parts[0].split("=", maxsplit=1)
+ cookies[k] = v
+
+ # before we hit the callback uri, stub out some methods in the http client so
+ # that we don't have to handle full HTTPS requests.
+
+ # (expected url, json response) pairs, in the order we expect them.
+ expected_requests = [
+ # first we get a hit to the token endpoint, which we tell to return
+ # a dummy OIDC access token
+ ("https://issuer.test/token", {"access_token": "TEST"}),
+ # and then one to the user_info endpoint, which returns our remote user id.
+ ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+ ]
+
+ async def mock_req(method: str, uri: str, data=None, headers=None):
+ (expected_uri, resp_obj) = expected_requests.pop(0)
+ assert uri == expected_uri
+ resp = FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ )
+ return resp
+
+ with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ # now hit the callback URI with the right params and a made-up code
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ redirect_uri,
+ custom_headers=[
+ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+ ],
+ )
+
+ # expect a confirmation page
+ assert channel.code == 200
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.result["body"].decode("utf-8"),
+ )
+ assert m
+ login_token = m.group(1)
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token and device id.
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ assert channel.code == 200
+ return channel.json_body
+
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "discover": False,
+ "issuer": "https://issuer.test",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": "https://z",
+ "token_endpoint": "https://issuer.test/token",
+ "userinfo_endpoint": "https://issuer.test/userinfo",
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index ae2cd67f35..cb87b80e33 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -19,7 +19,6 @@ import os
import re
from email.parser import Parser
from typing import Optional
-from urllib.parse import urlencode
import pkg_resources
@@ -31,6 +30,7 @@ from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
+from tests.server import FakeSite, make_request
from tests.unittest import override_config
@@ -240,12 +240,11 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
def _request_token(self, email, client_secret):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
return channel.json_body["sid"]
@@ -255,31 +254,29 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
# Load the password reset confirmation page
- request, channel = self.make_request("GET", path, shorthand=False)
- request.render(self.submit_token_resource)
- self.pump()
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.submit_token_resource),
+ "GET",
+ path,
+ shorthand=False,
+ )
+
self.assertEquals(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
- # Send arguments as url-encoded form data, matching the template's behaviour
- form_args = []
- for key, value_list in request.args.items():
- for value in value_list:
- arg = (key, value)
- form_args.append(arg)
-
# Confirm the password reset
- request, channel = self.make_request(
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.submit_token_resource),
"POST",
path,
- content=urlencode(form_args).encode("utf8"),
+ content=b"",
shorthand=False,
content_is_form=True,
)
- request.render(self.submit_token_resource)
- self.pump()
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -305,7 +302,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def _reset_password(
self, new_password, session_id, client_secret, expected_code=200
):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/password",
{
@@ -319,7 +316,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
},
},
)
- self.render(request)
self.assertEquals(expected_code, channel.code, channel.result)
@@ -348,11 +344,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
# Check that this access token has been invalidated.
- request, channel = self.make_request("GET", "account/whoami")
- self.render(request)
- self.assertEqual(request.code, 401)
+ channel = self.make_request("GET", "account/whoami")
+ self.assertEqual(channel.code, 401)
- @unittest.INFO
def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastore()
@@ -405,11 +399,10 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
"erase": False,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@@ -529,7 +522,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self._validate_token(link)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -543,15 +536,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -570,20 +561,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/delete",
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -604,22 +593,20 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/delete",
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -634,7 +621,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEquals(len(self.email_attempts), 1)
# Attempt to add email without clicking the link
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -648,15 +635,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -669,7 +654,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
session_id = "weasle"
# Attempt to add email without even requesting an email
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -683,15 +668,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -793,10 +776,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if next_link:
body["next_link"] = next_link
- request, channel = self.make_request(
- "POST", b"account/3pid/email/requestToken", body,
- )
- self.render(request)
+ channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
self.assertEquals(expect_code, channel.code, channel.result)
return channel.json_body.get("sid")
@@ -804,12 +784,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _request_token_invalid_email(
self, email, expected_errcode, expected_error, client_secret="foobar",
):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"])
@@ -818,8 +797,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
- request, channel = self.make_request("GET", path, shorthand=False)
- self.render(request)
+ channel = self.make_request("GET", path, shorthand=False)
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -853,7 +831,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self._validate_token(link)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"/_matrix/client/unstable/account/3pid/add",
{
@@ -868,14 +846,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 293ccfba2b..ac66a4e0b7 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,19 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Union
+
+from typing import Union
from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.oidc import OIDCResource
+from synapse.types import JsonDict, UserID
from tests import unittest
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
@@ -38,11 +40,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
return succeed(True)
-class DummyPasswordChecker(UserInteractiveAuthChecker):
- def check_auth(self, authdict, clientip):
- return succeed(authdict["identifier"]["user"])
-
-
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -69,12 +66,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
"""Make a register request."""
- request, channel = self.make_request(
- "POST", "register", body
- ) # type: SynapseRequest, FakeChannel
- self.render(request)
+ channel = self.make_request("POST", "register", body)
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
def recaptcha(
@@ -84,27 +78,24 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
if post_session is None:
post_session = session
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
- ) # type: SynapseRequest, FakeChannel
- self.render(request)
- self.assertEqual(request.code, 200)
+ )
+ self.assertEqual(channel.code, 200)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"auth/m.login.recaptcha/fallback/web?session="
+ post_session
+ "&g-recaptcha-response=a",
)
- self.render(request)
- self.assertEqual(request.code, expected_post_response)
+ self.assertEqual(channel.code, expected_post_response)
# The recaptcha handler is called with the response given
attempts = self.recaptcha_checker.recaptcha_attempts
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")
- @unittest.INFO
def test_fallback_captcha(self):
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
@@ -165,36 +156,44 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
- def prepare(self, reactor, clock, hs):
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
+ def default_config(self):
+ config = super().default_config()
- self.user_pass = "pass"
- self.user = self.register_user("test", self.user_pass)
- self.user_tok = self.login("test", self.user_pass)
+ # we enable OIDC as a way of testing SSO flows
+ oidc_config = {}
+ oidc_config.update(TEST_OIDC_CONFIG)
+ oidc_config["allow_existing_users"] = True
- def get_device_ids(self) -> List[str]:
- # Get the list of devices so one can be deleted.
- request, channel = self.make_request(
- "GET", "devices", access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
- self.render(request)
+ config["oidc_config"] = oidc_config
+ config["public_baseurl"] = "https://synapse.test"
+ return config
- # Get the ID of the device.
- self.assertEqual(request.code, 200)
- return [d["device_id"] for d in channel.json_body["devices"]]
+ def create_resource_dict(self):
+ resource_dict = super().create_resource_dict()
+ # mount the OIDC resource at /_synapse/oidc
+ resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+ return resource_dict
+
+ def prepare(self, reactor, clock, hs):
+ self.user_pass = "pass"
+ self.user = self.register_user("test", self.user_pass)
+ self.device_id = "dev1"
+ self.user_tok = self.login("test", self.user_pass, self.device_id)
def delete_device(
- self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ self,
+ access_token: str,
+ device: str,
+ expected_response: int,
+ body: Union[bytes, JsonDict] = b"",
) -> FakeChannel:
"""Delete an individual device."""
- request, channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=self.user_tok
- ) # type: SynapseRequest, FakeChannel
- self.render(request)
+ channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=access_token,
+ )
# Ensure the response is sane.
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
@@ -202,13 +201,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""Delete 1 or more devices."""
# Note that this uses the delete_devices endpoint so that we can modify
# the payload half-way through some tests.
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "delete_devices", body, access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
- self.render(request)
+ )
# Ensure the response is sane.
- self.assertEqual(request.code, expected_response)
+ self.assertEqual(channel.code, expected_response)
return channel
@@ -216,11 +214,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""
Test user interactive authentication outside of registration.
"""
- device_id = self.get_device_ids()[0]
-
# Attempt to delete this device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -229,7 +225,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow.
self.delete_device(
- device_id,
+ self.user_tok,
+ self.device_id,
200,
{
"auth": {
@@ -241,6 +238,31 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
+ def test_grandfathered_identifier(self):
+ """Check behaviour without "identifier" dict
+
+ Synapse used to require clients to submit a "user" field for m.login.password
+ UIA - check that still works.
+ """
+
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+ session = channel.json_body["session"]
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ self.user_tok,
+ self.device_id,
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user,
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
def test_can_change_body(self):
"""
The client dict can be modified during the user interactive authentication session.
@@ -252,14 +274,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
session ID should be rejected.
"""
# Create a second login.
- self.login("test", self.user_pass)
-
- device_ids = self.get_device_ids()
- self.assertEqual(len(device_ids), 2)
+ self.login("test", self.user_pass, "dev2")
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+ channel = self.delete_devices(401, {"devices": [self.device_id]})
# Grab the session
session = channel.json_body["session"]
@@ -271,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
self.delete_devices(
200,
{
- "devices": [device_ids[1]],
+ "devices": ["dev2"],
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": self.user},
@@ -286,14 +305,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
The initial requested URI cannot be modified during the user interactive authentication session.
"""
# Create a second login.
- self.login("test", self.user_pass)
-
- device_ids = self.get_device_ids()
- self.assertEqual(len(device_ids), 2)
+ self.login("test", self.user_pass, "dev2")
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_ids[0], 401)
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -302,8 +318,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error.
+ #
+ # This makes use of the fact that the device ID is embedded into the URL.
self.delete_device(
- device_ids[1],
+ self.user_tok,
+ "dev2",
403,
{
"auth": {
@@ -314,3 +333,83 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
},
)
+
+ @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}})
+ def test_can_reuse_session(self):
+ """
+ The session can be reused if configured.
+
+ Compare to test_cannot_change_uri.
+ """
+ # Create a second and third login.
+ self.login("test", self.user_pass, "dev2")
+ self.login("test", self.user_pass, "dev3")
+
+ # Attempt to delete a device. This works since the user just logged in.
+ self.delete_device(self.user_tok, "dev2", 200)
+
+ # Move the clock forward past the validation timeout.
+ self.reactor.advance(6)
+
+ # Deleting another devices throws the user into UI auth.
+ channel = self.delete_device(self.user_tok, "dev3", 401)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ self.user_tok,
+ "dev3",
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
+ # Make another request, but try to delete the first device. This works
+ # due to re-using the previous session.
+ #
+ # Note that *no auth* information is provided, not even a session iD!
+ self.delete_device(self.user_tok, self.device_id, 200)
+
+ def test_does_not_offer_password_for_sso_user(self):
+ login_resp = self.helper.login_via_oidc("username")
+ user_tok = login_resp["access_token"]
+ device_id = login_resp["device_id"]
+
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ channel = self.delete_device(user_tok, device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+ def test_does_not_offer_sso_for_password_user(self):
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+ def test_offers_both_flows_for_upgraded_user(self):
+ """A user that had a password and then logged in with SSO should get both flows
+ """
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+ flows = channel.json_body["flows"]
+ # we have no particular expectations of ordering here
+ self.assertIn({"stages": ["m.login.password"]}, flows)
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ self.assertEqual(len(flows), 2)
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index b9e01c9418..e808339fb3 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -36,8 +36,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
return hs
def test_check_auth_required(self):
- request, channel = self.make_request("GET", self.url)
- self.render(request)
+ channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
@@ -45,8 +44,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.register_user("user", "pass")
access_token = self.login("user", "pass")
- request, channel = self.make_request("GET", self.url, access_token=access_token)
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -64,8 +62,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
user = self.register_user(localpart, password)
access_token = self.login(user, password)
- request, channel = self.make_request("GET", self.url, access_token=access_token)
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -73,8 +70,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
# Test case where password is handled outside of Synapse
self.assertTrue(capabilities["m.change_password"]["enabled"])
self.get_success(self.store.user_set_password_hash(user, None))
- request, channel = self.make_request("GET", self.url, access_token=access_token)
- self.render(request)
+ channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index de00350580..f761c44936 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -36,12 +36,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastore()
def test_add_filter(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
@@ -50,12 +49,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
- self.render(request)
self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -63,12 +61,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.render(request)
self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403")
@@ -82,19 +79,17 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.reactor.advance(1)
filter_id = filter_id.result
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"404")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -102,18 +97,16 @@ class FilterTestCase(unittest.HomeserverTestCase):
# Currently invalid params do not have an appropriate errcode
# in errors.py
def test_get_filter_invalid_id(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error
def test_get_filter_no_id(self):
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index c57072f50c..fba34def30 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -70,10 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_get_policy(self):
"""Tests if the /password_policy endpoint returns the configured policy."""
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/password_policy"
- )
- self.render(request)
+ channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(
@@ -90,8 +87,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_too_short(self):
request_data = json.dumps({"username": "kermit", "password": "shorty"})
- request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -100,8 +96,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_digit(self):
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
- request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -110,8 +105,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_symbol(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
- request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -120,8 +114,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_uppercase(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -130,8 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_lowercase(self):
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
+ channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -140,8 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_compliant(self):
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
- request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
+ channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
# responded with a list of registration flows.
@@ -167,13 +158,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
},
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
request_data,
access_token=tok,
)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..27db4f551e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -55,15 +55,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
)
self.hs.get_datastore().services_cache.append(appservice)
request_data = json.dumps({"username": "as_user_kermit"})
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
@@ -72,25 +72,22 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
request_data = json.dumps({"username": "kermit"})
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -105,8 +102,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
det_data = {
"user_id": user_id,
@@ -121,18 +117,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+ self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -141,8 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
@@ -151,8 +145,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_guest(self):
for i in range(0, 6):
url = self.url + b"?kind=guest"
- request, channel = self.make_request(b"POST", url, b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", url, b"{}")
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -162,8 +155,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -177,8 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -188,14 +179,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
- request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_advertised_flows(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -218,8 +207,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -251,8 +239,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
)
def test_advertised_flows_no_msisdn_email_required(self):
- request, channel = self.make_request(b"POST", self.url, b"{}")
- self.render(request)
+ channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -292,12 +279,11 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
b"register/email/requestToken",
{"client_secret": "foobar", "email": email, "send_attempt": 1},
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertIsNotNone(channel.json_body.get("sid"))
@@ -332,15 +318,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
@@ -359,19 +343,15 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {"user_id": user_id}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
- self.render(request)
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_manual_expire(self):
@@ -381,23 +361,19 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
- self.render(request)
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
@@ -410,30 +386,25 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
- request, channel = self.make_request(
- b"POST", url, request_data, access_token=admin_tok
- )
- self.render(request)
+ channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Try to log the user out
- request, channel = self.make_request(b"POST", "/logout", access_token=tok)
- self.render(request)
+ channel = self.make_request(b"POST", "/logout", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
- request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
- self.render(request)
+ channel = self.make_request(b"POST", "/logout/all", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -507,8 +478,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# retrieve the token from the DB.
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- request, channel = self.make_request(b"GET", url)
- self.render(request)
+ channel = self.make_request(b"GET", url)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
@@ -528,16 +498,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# our access token should be denied from now, otherwise they should
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
- request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
+ channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self):
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
- request, channel = self.make_request(b"GET", url)
- self.render(request)
+ channel = self.make_request(b"GET", url)
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
@@ -558,12 +526,11 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.email_attempts = []
(user_id, tok) = self.create_user()
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST",
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -583,11 +550,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"erase": False,
}
)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
@@ -598,7 +564,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -616,7 +582,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -635,12 +601,11 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.email_attempts = []
# Test that we're still able to manually trigger a mail to be sent.
- request, channel = self.make_request(
+ channel = self.make_request(
b"POST",
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -676,7 +641,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
- now_ms = self.hs.clock.time_msec()
+ now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 99c9f4e928..bd574077e7 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -60,12 +60,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assert_dict(
@@ -108,13 +107,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# We expect to get back a single pagination result, which is the full
@@ -154,13 +152,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
@@ -213,13 +210,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
@@ -283,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if prev_token:
from_token = "&from=" + prev_token
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s"
"/aggregations/%s/%s/m.reaction/%s?limit=1%s"
@@ -296,7 +292,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
@@ -330,13 +325,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
@@ -363,22 +357,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Now lets redact one of the 'a' reactions
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
access_token=self.user_token,
content={},
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
@@ -390,13 +382,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"""Test that aggregations must be annotations.
"""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
% (self.room, self.parent_id, RelationTypes.REPLACE),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(400, channel.code, channel.json_body)
def test_aggregation_get_event(self):
@@ -423,12 +414,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
reply_2 = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
@@ -460,12 +450,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
edit_event_id = channel.json_body["event_id"]
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["content"], new_body)
@@ -518,12 +507,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.json_body)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["content"], new_body)
@@ -561,37 +549,34 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Check the relation is returned
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
self.assertEquals(len(channel.json_body["chunk"]), 1)
# Redact the original event
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
access_token=self.user_token,
content="{}",
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# Try to check for remaining m.replace relations
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# Check that no relations are returned
@@ -613,7 +598,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
# Redact the original
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (
@@ -624,17 +609,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
content="{}",
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# Check that aggregations returns zero
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
% (self.room, original_event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
@@ -673,14 +656,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
original_id = parent_id if parent_id else self.parent_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
json.dumps(content).encode("utf-8"),
access_token=access_token,
)
- self.render(request)
return channel
def _create_user(self, localpart):
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 5ae72fd008..116ace1812 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import shared_rooms
from tests import unittest
+from tests.server import FakeChannel
class UserSharedRoomsTest(unittest.HomeserverTestCase):
@@ -40,15 +41,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
- def _get_shared_rooms(self, token, other_user):
- request, channel = self.make_request(
+ def _get_shared_rooms(self, token, other_user) -> FakeChannel:
+ return self.make_request(
"GET",
"/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
% other_user,
access_token=token,
)
- self.render(request)
- return request, channel
def test_shared_room_list_public(self):
"""
@@ -64,7 +63,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
@@ -83,7 +82,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
@@ -105,7 +104,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room_public, user=u2, tok=u2_token)
self.helper.join(room_private, user=u1, tok=u1_token)
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 2)
self.assertTrue(room_public in channel.json_body["joined"])
@@ -126,13 +125,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
- request, channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
self.helper.leave(room, user=u1, tok=u1_token)
- request, channel = self._get_shared_rooms(u2_token, u1)
+ channel = self._get_shared_rooms(u2_token, u1)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index a31e44c97e..512e36c236 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -35,8 +35,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
]
def test_sync_argless(self):
- request, channel = self.make_request("GET", "/sync")
- self.render(request)
+ channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -56,8 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"""
self.hs.config.use_presence = False
- request, channel = self.make_request("GET", "/sync")
- self.render(request)
+ channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -196,10 +194,9 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
tok=tok,
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/sync?filter=%s" % sync_filter, access_token=tok
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
@@ -248,44 +245,35 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token)
# Start typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
- request, channel = self.make_request(
- "GET", "/sync?access_token=%s" % (access_token,)
- )
- self.render(request)
+ channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
# Stop typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": false}',
)
- self.render(request)
self.assertEquals(200, channel.code)
# Start typing.
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
# Should return immediately
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
- self.render(request)
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -297,10 +285,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
# invalidate the stream token.
self.helper.send(room, body="There!", tok=other_access_token)
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
- self.render(request)
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -308,10 +293,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
# ahead, and therefore it's saying the typing (that we've actually
# already seen) is new, since it's got a token above our new, now-reset
# stream token.
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
- self.render(request)
+ channel = self.make_request("GET", sync_url % (access_token, next_batch))
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -320,10 +302,8 @@ class SyncTypingTests(unittest.HomeserverTestCase):
typing._reset()
# Now it SHOULD fail as it never completes!
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
- self.assertRaises(TimedOutException, self.render, request)
+ with self.assertRaises(TimedOutException):
+ self.make_request("GET", sync_url % (access_token, next_batch))
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
@@ -395,13 +375,12 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
body,
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the unread counter is back to 0.
@@ -463,10 +442,9 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
def _check_unread_count(self, expected_count: True):
"""Syncs and compares the unread count with the expected value."""
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", self.url % self.next_batch, access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6850c666be..5e90d656f7 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -32,16 +32,16 @@ from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.server import FakeChannel, wait_until_result
+from tests.server import FakeChannel
from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
- return self.setup_test_homeserver(http_client=self.http_client)
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
- def create_test_json_resource(self):
+ def create_test_resource(self):
return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
)
@@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
% (server_name.encode("utf-8"), key_id.encode("utf-8")),
b"1.1",
)
- wait_until_result(self.reactor, req)
+ channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp
@@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
}
]
self.hs2 = self.setup_test_homeserver(
- http_client=self.http_client2, config=config
+ federation_http_client=self.http_client2, config=config
)
# wire up outbound POST /key/v2/query requests from hs2 so that they
@@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.requestReceived(
b"POST", path.encode("utf-8"), b"1.1",
)
- wait_until_result(self.reactor, req)
+ channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 5f897d49cf..ae2b32b131 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
+from tests.server import FakeSite, make_request
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -213,7 +214,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
}
config["media_storage_providers"] = [provider_config]
- hs = self.setup_test_homeserver(config=config, http_client=client)
+ hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
@@ -227,8 +228,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition):
- request, channel = self.make_request("GET", self.media_id, shorthand=False)
- request.render(self.download_resource)
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
+ "GET",
+ self.media_id,
+ shorthand=False,
+ await_result=False,
+ )
self.pump()
# We've made one fetch, to example.com, using the media URL, and asking
@@ -317,10 +324,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
- request, channel = self.make_request(
- "GET", self.media_id + params, shorthand=False
+ channel = make_request(
+ self.reactor,
+ FakeSite(self.thumbnail_resource),
+ "GET",
+ self.media_id + params,
+ shorthand=False,
+ await_result=False,
)
- request.render(self.thumbnail_resource)
self.pump()
headers = {
@@ -348,7 +359,19 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_NOT_FOUND",
- "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
- % method,
+ "error": "Not found [b'example.com', b'12345']",
},
)
+
+ def test_x_robots_tag_header(self):
+ """
+ Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+ to not index, archive, or follow links in media.
+ """
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"X-Robots-Tag"),
+ [b"noindex, nofollow, noarchive, noimageindex"],
+ )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index c00a7b9114..83d728b4a4 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,41 +18,15 @@ import re
from mock import patch
-import attr
-
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
from tests import unittest
from tests.server import FakeTransport
-@attr.s
-class FakeResponse:
- version = attr.ib()
- code = attr.ib()
- phrase = attr.ib()
- headers = attr.ib()
- body = attr.ib()
- absoluteURI = attr.ib()
-
- @property
- def request(self):
- @attr.s
- class FakeTransport:
- absoluteURI = self.absoluteURI
-
- return FakeTransport()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True
@@ -133,13 +107,18 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver()
+ def create_test_resource(self):
+ return self.hs.get_media_repository_resource()
+
def test_cache_returns_correct_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -159,11 +138,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
# Check the cache returns the correct response
- request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://matrix.org", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# Check the cache response has the same content
self.assertEqual(channel.code, 200)
@@ -177,11 +154,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertNotIn("http://matrix.org", self.preview_url._cache)
# Check the database cache returns the correct response
- request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://matrix.org", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# Check the cache response has the same content
self.assertEqual(channel.code, 200)
@@ -200,10 +175,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -233,10 +210,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -266,10 +245,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -297,10 +278,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -325,11 +308,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -348,11 +329,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
@@ -367,11 +346,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Blacklisted IP addresses, accessed directly, are not spidered.
"""
- request, channel = self.make_request(
- "GET", "url_preview?url=http://192.168.1.1", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://192.168.1.1", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -388,11 +365,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Blacklisted IP ranges, accessed directly, are not spidered.
"""
- request, channel = self.make_request(
- "GET", "url_preview?url=http://1.1.1.2", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://1.1.1.2", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 403)
self.assertEqual(
@@ -410,10 +385,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -445,11 +422,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
(IPv4Address, "10.1.2.3"),
]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
@@ -467,11 +442,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -490,11 +463,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
@@ -509,11 +480,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
OPTIONS returns the OPTIONS.
"""
- request, channel = self.make_request(
- "OPTIONS", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "OPTIONS", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
@@ -524,10 +493,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
# Build and make a request to the server
- request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
# Extract Synapse's tcp client
@@ -596,12 +567,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
- "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -661,12 +632,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}
end_content = json.dumps(result).encode("utf-8")
- request, channel = self.make_request(
+ channel = self.make_request(
"GET",
- "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 2d021f6565..32acd93dc1 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -20,15 +20,12 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase):
- def setUp(self):
- super().setUp()
-
+ def create_test_resource(self):
# replace the JsonResource with a HealthResource.
- self.resource = HealthResource()
+ return HealthResource()
def test_health(self):
- request, channel = self.make_request("GET", "/health", shorthand=False)
- self.render(request)
+ channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index dcd65c2a50..14de0921be 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -20,22 +20,19 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
- def setUp(self):
- super().setUp()
-
+ def create_test_resource(self):
# replace the JsonResource with a WellKnownResource
- self.resource = WellKnownResource(self.hs)
+ return WellKnownResource(self.hs)
def test_well_known(self):
self.hs.config.public_baseurl = "https://tesths"
self.hs.config.default_identity_server = "https://testis"
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -47,9 +44,8 @@ class WellKnownTests(unittest.HomeserverTestCase):
def test_well_known_no_public_baseurl(self):
self.hs.config.public_baseurl = None
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.render(request)
- self.assertEqual(request.code, 404)
+ self.assertEqual(channel.code, 404)
diff --git a/tests/server.py b/tests/server.py
index b404ad4e2a..7d1ad362c4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,8 +1,11 @@
import json
import logging
+from collections import deque
from io import SEEK_END, BytesIO
+from typing import Callable, Iterable, Optional, Tuple, Union
import attr
+from typing_extensions import Deque
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -16,8 +19,8 @@ from twisted.internet.interfaces import (
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
-from twisted.web.http import unquote
from twisted.web.http_headers import Headers
+from twisted.web.resource import IResource
from twisted.web.server import Site
from synapse.http.site import SynapseRequest
@@ -43,7 +46,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
- result = attr.ib(default=attr.Factory(dict))
+ result = attr.ib(type=dict, default=attr.Factory(dict))
_producer = None
@property
@@ -114,6 +117,25 @@ class FakeChannel:
def transport(self):
return self
+ def await_result(self, timeout: int = 100) -> None:
+ """
+ Wait until the request is finished.
+ """
+ self._reactor.run()
+ x = 0
+
+ while not self.result.get("done"):
+ # If there's a producer, tell it to resume producing so we get content
+ if self._producer:
+ self._producer.resumeProducing()
+
+ x += 1
+
+ if x > timeout:
+ raise TimedOutException("Timed out waiting for request to finish.")
+
+ self._reactor.advance(0.1)
+
class FakeSite:
"""
@@ -125,9 +147,21 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
+ def __init__(self, resource: IResource):
+ """
+
+ Args:
+ resource: the resource to be used for rendering all requests
+ """
+ self._resource = resource
+
+ def getResourceFor(self, request):
+ return self._resource
+
def make_request(
reactor,
+ site: Site,
method,
path,
content=b"",
@@ -136,12 +170,19 @@ def make_request(
shorthand=True,
federation_auth_origin=None,
content_is_form=False,
-):
+ await_result: bool = True,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+) -> FakeChannel:
"""
- Make a web request using the given method and path, feed it the
- content, and return the Request and the Channel underneath.
+ Make a web request using the given method, path and content, and render it
+
+ Returns the fake Channel object which records the response to the request.
Args:
+ site: The twisted Site to use to render the request
+
method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such).
@@ -154,8 +195,14 @@ def make_request(
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
+ custom_headers: (name, value) pairs to add as request headers
+
+ await_result: whether to wait for the request to complete rendering. If true,
+ will pump the reactor until the the renderer tells the channel the request
+ is finished.
+
Returns:
- Tuple[synapse.http.site.SynapseRequest, channel]
+ channel
"""
if not isinstance(method, bytes):
method = method.encode("ascii")
@@ -169,24 +216,24 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
+ if path.startswith(b"/"):
+ path = path[1:]
path = b"/_matrix/client/r0/" + path
- path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
+ if isinstance(content, dict):
+ content = json.dumps(content).encode("utf8")
if isinstance(content, str):
content = content.encode("utf8")
- site = FakeSite()
channel = FakeChannel(site, reactor)
req = request(channel)
- req.process = lambda: b""
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END)
- req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
req.requestHeaders.addRawHeader(
@@ -208,35 +255,17 @@ def make_request(
# Assume the body is JSON
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
- req.requestReceived(method, path, b"1.1")
+ if custom_headers:
+ for k, v in custom_headers:
+ req.requestHeaders.addRawHeader(k, v)
- return req, channel
-
-
-def wait_until_result(clock, request, timeout=100):
- """
- Wait until the request is finished.
- """
- clock.run()
- x = 0
-
- while not request.finished:
-
- # If there's a producer, tell it to resume producing so we get content
- if request._channel._producer:
- request._channel._producer.resumeProducing()
-
- x += 1
-
- if x > timeout:
- raise TimedOutException("Timed out waiting for request to finish.")
-
- clock.advance(0.1)
+ req.parseCookies()
+ req.requestReceived(method, path, b"1.1")
+ if await_result:
+ channel.await_result()
-def render(request, resource, clock):
- request.render(resource)
- wait_until_result(clock, request)
+ return channel
@implementer(IReactorPluggableNameResolver)
@@ -251,6 +280,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
+ self._thread_callbacks = deque() # type: Deque[Callable[[], None]]()
@implementer(IResolverSimple)
class FakeResolver:
@@ -272,10 +302,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
Make the callback fire in the next reactor iteration.
"""
- d = Deferred()
- d.addCallback(lambda x: callback(*args, **kwargs))
- self.callLater(0, d.callback, True)
- return d
+ cb = lambda: callback(*args, **kwargs)
+ # it's not safe to call callLater() here, so we append the callback to a
+ # separate queue.
+ self._thread_callbacks.append(cb)
def getThreadPool(self):
return self.threadpool
@@ -303,6 +333,30 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
+ def advance(self, amount):
+ # first advance our reactor's time, and run any "callLater" callbacks that
+ # makes ready
+ super().advance(amount)
+
+ # now run any "callFromThread" callbacks
+ while True:
+ try:
+ callback = self._thread_callbacks.popleft()
+ except IndexError:
+ break
+ callback()
+
+ # check for more "callLater" callbacks added by the thread callback
+ # This isn't required in a regular reactor, but it ends up meaning that
+ # our database queries can complete in a single call to `advance` [1] which
+ # simplifies tests.
+ #
+ # [1]: we replace the threadpool backing the db connection pool with a
+ # mock ThreadPool which doesn't really use threads; but we still use
+ # reactor.callFromThread to feed results back from the db functions to the
+ # main thread.
+ super().advance(0)
+
class ThreadPool:
"""
@@ -339,8 +393,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
"""
server = _sth(cleanup_func, *args, **kwargs)
- database = server.config.database.get_single_database()
-
# Make the thread pool synchronous.
clock = server.get_clock()
@@ -354,7 +406,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runWithConnection,
func,
*args,
- **kwargs
+ **kwargs,
)
def runInteraction(interaction, *args, **kwargs):
@@ -364,7 +416,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runInteraction,
interaction,
*args,
- **kwargs
+ **kwargs,
)
pool.runWithConnection = runWithConnection
@@ -372,6 +424,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
+ # We've just changed the Databases to run DB transactions on the same
+ # thread, so we need to disable the dedicated thread behaviour.
+ server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
return server
@@ -541,12 +597,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol
reactor
factory: The connecting factory to build.
"""
- factory = reactor.tcpClients[client_id][2]
+ factory = reactor.tcpClients.pop(client_id)[2]
client = factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, reactor))
client.makeConnection(FakeTransport(server, reactor))
- reactor.tcpClients.pop(client_id)
-
return client, server
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 872039c8f1..4dd5a36178 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -70,29 +70,26 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
the notice URL + an authentication code.
"""
# Initial sync, to get the user consent room invite
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# Get the Room ID to join
room_id = list(channel.json_body["rooms"]["invite"].keys())[0]
# Join the room
- request, channel = self.make_request(
+ channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/" + room_id + "/join",
access_token=self.access_token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# Sync again, to get the message in the room
- request, channel = self.make_request(
+ channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# Get the message
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 6382b19dc3..fea54464af 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -305,8 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.register_user("user", "password")
tok = self.login("user", "password")
- request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
- self.render(request)
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
invites = channel.json_body["rooms"]["invite"]
self.assertEqual(len(invites), 0, invites)
@@ -319,8 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
# Sync again to retrieve the events in the room, so we can check whether this
# room has a notice in it.
- request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
- self.render(request)
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
# Scan the events in the room to search for a message from the server notices
# user.
@@ -355,10 +353,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
tok = self.login(localpart, "password")
# Sync with the user's token to mark the user as active.
- request, channel = self.make_request(
- "GET", "/sync?timeout=0", access_token=tok,
- )
- self.render(request)
+ channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,)
# Also retrieves the list of invites for this user. We don't care about that
# one except if we're processing the last user, which should have received an
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..77c72834f2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+ _get_auth_chain_difference,
+ lexicographical_topological_sort,
+ resolve_events_with_store,
+)
from synapse.types import EventID
from tests import unittest
@@ -84,7 +88,7 @@ class FakeEvent:
event_dict = {
"auth_events": [(a, {}) for a in auth_events],
"prev_events": [(p, {}) for p in prev_events],
- "event_id": self.node_id,
+ "event_id": self.event_id,
"sender": self.sender,
"type": self.type,
"content": self.content,
@@ -377,6 +381,61 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
+ def test_mainline_sort(self):
+ """Tests that the mainline ordering works correctly.
+ """
+
+ events = [
+ FakeEvent(
+ id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {ALICE: 100, BOB: 50},
+ "events": {EventTypes.PowerLevels: 100},
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ ]
+
+ edges = [
+ ["END", "T3", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T4", "PB", "PA1"],
+ ]
+
+ # We expect T3 to be picked as the other topics are pointing at older
+ # power levels. Note that without mainline ordering we'd pick T4 due to
+ # it being sent *after* T3.
+ expected_state_ids = ["T3", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
def do_check(self, events, edges, expected_state_ids):
"""Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids`
@@ -587,6 +646,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.assert_dict(self.expected_combined_state, state)
+class AuthChainDifferenceTestCase(unittest.TestCase):
+ """We test that `_get_auth_chain_difference` correctly handles unpersisted
+ events.
+ """
+
+ def test_simple(self):
+ # Test getting the auth difference for a simple chain with a single
+ # unpersisted event:
+ #
+ # Unpersisted | Persisted
+ # |
+ # C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c}
+
+ state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {c.event_id})
+
+ def test_multiple_unpersisted_chain(self):
+ # Test getting the auth difference for a simple chain with multiple
+ # unpersisted events:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D -> C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, c.event_id})
+
+ def test_unpersisted_events_different_sets(self):
+ # Test getting the auth difference for with multiple unpersisted events
+ # in different branches:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D --> C -|-> B -> A
+ # E ----^ -|---^
+ # |
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ e = FakeEvent(
+ id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id, b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, e.event_id})
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
@@ -647,7 +834,7 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, auth_sets):
+ def get_auth_chain_difference(self, room_id, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index f5afed017c..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,308 +15,9 @@
# limitations under the License.
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import Cache, cached
-
from tests import unittest
-class CacheTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
- self.cache = Cache("test")
-
- def test_empty(self):
- failed = False
- try:
- self.cache.get("foo")
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- def test_hit(self):
- self.cache.prefill("foo", 123)
-
- self.assertEquals(self.cache.get("foo"), 123)
-
- def test_invalidate(self):
- self.cache.prefill(("foo",), 123)
- self.cache.invalidate(("foo",))
-
- failed = False
- try:
- self.cache.get(("foo",))
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- def test_eviction(self):
- cache = Cache("test", max_entries=2)
-
- cache.prefill(1, "one")
- cache.prefill(2, "two")
- cache.prefill(3, "three") # 1 will be evicted
-
- failed = False
- try:
- cache.get(1)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- cache.get(2)
- cache.get(3)
-
- def test_eviction_lru(self):
- cache = Cache("test", max_entries=2)
-
- cache.prefill(1, "one")
- cache.prefill(2, "two")
-
- # Now access 1 again, thus causing 2 to be least-recently used
- cache.get(1)
-
- cache.prefill(3, "three")
-
- failed = False
- try:
- cache.get(2)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- cache.get(1)
- cache.get(3)
-
-
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
- @defer.inlineCallbacks
- def test_passthrough(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- a = A()
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals((yield a.func("bar")), "bar")
-
- @defer.inlineCallbacks
- def test_hit(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals(callcount[0], 1)
-
- @defer.inlineCallbacks
- def test_invalidate(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- a.func.invalidate(("foo",))
-
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
-
- def test_invalidate_missing(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- A().func.invalidate(("what",))
-
- @defer.inlineCallbacks
- def test_max_entries(self):
- callcount = [0]
-
- class A:
- @cached(max_entries=10)
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
-
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertEquals(callcount[0], 12)
-
- # There must have been at least 2 evictions, meaning if we calculate
- # all 12 values again, we must get called at least 2 more times
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertTrue(
- callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
- )
-
- def test_prefill(self):
- callcount = [0]
-
- d = defer.succeed(123)
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return d
-
- a = A()
-
- a.func.prefill(("foo",), ObservableDeferred(d))
-
- self.assertEquals(a.func("foo").result, d.result)
- self.assertEquals(callcount[0], 0)
-
- @defer.inlineCallbacks
- def test_invalidate_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func.invalidate(("foo",))
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 1)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- @defer.inlineCallbacks
- def test_eviction_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached(max_entries=2)
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
- yield a.func2("foo2")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func("foo3")
-
- self.assertEquals(callcount[0], 3)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 4)
- self.assertEquals(callcount2[0], 3)
-
- @defer.inlineCallbacks
- def test_double_get(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
- yield a.func2("foo")
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 2)
-
- a.func.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 3)
-
-
class UpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.storage = hs.get_datastore()
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..1ce29af5fd 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
# must be done after inserts
database = hs.get_datastores().databases[0]
self.store = ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
db_config = hs.config.get_single_database()
self.store = TestTransactionStore(
- database, make_conn(db_config, self.engine), hs
+ database, make_conn(db_config, self.engine, "test"), hs
)
def _add_service(self, url, as_token, id):
@@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
@@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -410,6 +410,62 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
+class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
+
+ def prepare(self, hs, reactor, clock):
+ self.service = Mock(id="foo")
+ self.store = self.hs.get_datastore()
+ self.get_success(self.store.set_appservice_state(self.service, "up"))
+
+ def test_get_type_stream_id_for_appservice_no_value(self):
+ value = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
+ )
+ self.assertEquals(value, 0)
+
+ value = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "presence")
+ )
+ self.assertEquals(value, 0)
+
+ def test_get_type_stream_id_for_appservice_invalid_type(self):
+ self.get_failure(
+ self.store.get_type_stream_id_for_appservice(self.service, "foobar"),
+ ValueError,
+ )
+
+ def test_set_type_stream_id_for_appservice(self):
+ read_receipt_value = 1024
+ self.get_success(
+ self.store.set_type_stream_id_for_appservice(
+ self.service, "read_receipt", read_receipt_value
+ )
+ )
+ result = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
+ )
+ self.assertEqual(result, read_receipt_value)
+
+ self.get_success(
+ self.store.set_type_stream_id_for_appservice(
+ self.service, "presence", read_receipt_value
+ )
+ )
+ result = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "presence")
+ )
+ self.assertEqual(result, read_receipt_value)
+
+ def test_set_type_stream_id_for_appservice_invalid_type(self):
+ self.get_failure(
+ self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
+ ValueError,
+ )
+
+
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -448,7 +504,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
@defer.inlineCallbacks
@@ -467,7 +523,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
@@ -491,7 +549,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..c13a57dad1 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import login, room
from synapse.storage import prepare_database
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, False, None, None)
+ self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, False, None, None)
+ self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
- @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
- def test_send_dummy_event_without_consent(self):
- self._create_extremity_rich_graph()
- self._enable_consent_checking()
-
- # Pump the reactor repeatedly so that the background updates have a
- # chance to run. Attempt to add dummy event with user that has not consented
- # Check that dummy event send fails.
- self.pump(10 * 60)
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
-
- # Create new user, and add consent
- user2 = self.register_user("user2", "password")
- token2 = self.login("user2", "password")
- self.get_success(
- self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
- )
- self.helper.join(self.room_id, user2, tok=token2)
-
- # Background updates should now cause a dummy event to be added to the graph
- self.pump(10 * 60)
-
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
-
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
def test_expiry_logic(self):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..a69117c5a9 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -408,18 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
# Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds())
- request, channel = self.make_request(
+ headers1 = {b"User-Agent": b"Mozzila pizza"}
+ headers1.update(headers)
+
+ make_request(
+ self.reactor,
+ self.site,
"GET",
- "/_matrix/client/r0/admin/users/" + self.user_id,
+ "/_synapse/admin/v1/users/" + self.user_id,
access_token=access_token,
- **make_request_args
+ custom_headers=headers1.items(),
+ **make_request_args,
)
- request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
-
- # Add the optional headers
- for h, v in headers.items():
- request.requestHeaders.addRawHeader(h, v)
- self.render(request)
# Advance so the save loop occurs
self.reactor.advance(100)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ecb00f4e02..dabc1c5f09 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -80,6 +80,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_count_devices_by_users(self):
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device1", "display_name 1")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device2", "display_name 2")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id2", "device3", "display_name 3")
+ )
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ self.assertEqual(2, res)
+
+ res = yield defer.ensureDeferred(
+ self.store.count_devices_by_users(["user_id", "user_id2"])
+ )
+ self.assertEqual(3, res)
+
+ @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 35dafbb904..3d7760d5d9 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -26,7 +26,7 @@ room_key = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastore()
return hs
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..482506d731 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -202,34 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Now actually test that various combinations give the right result:
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
- self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
- difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
self.assertSetEqual(difference, set())
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, generate_latest
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
new file mode 100644
index 0000000000..71210ce606
--- /dev/null
+++ b/tests/storage/test_events.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class ExtremPruneTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.state = self.hs.get_state_handler()
+ self.persistence = self.hs.get_storage().persistence
+ self.store = self.hs.get_datastore()
+
+ self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ "user", room_version=RoomVersions.V6.identifier, tok=self.token
+ )
+
+ body = self.helper.send(self.room_id, body="Test", tok=self.token)
+ local_message_event_id = body["event_id"]
+
+ # Fudge a remote event and persist it. This will be the extremity before
+ # the gap.
+ self.remote_event_1 = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "state_key": "@user:other",
+ "content": {},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 5,
+ "prev_events": [local_message_event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ self.persist_event(self.remote_event_1)
+
+ # Check that the current extremities is the remote event.
+ self.assert_extremities([self.remote_event_1.event_id])
+
+ def persist_event(self, event, state=None):
+ """Persist the event, with optional state
+ """
+ context = self.get_success(
+ self.state.compute_event_context(event, old_state=state)
+ )
+ self.get_success(self.persistence.persist_event(event, context))
+
+ def assert_extremities(self, expected_extremities):
+ """Assert the current extremities for the room
+ """
+ extremities = self.get_success(
+ self.store.get_prev_events_for_room(self.room_id)
+ )
+ self.assertCountEqual(extremities, expected_extremities)
+
+ def test_prune_gap(self):
+ """Test that we drop extremities after a gap when we see an event from
+ the same domain.
+ """
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 50,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_do_not_prune_gap_if_state_different(self):
+ """Test that we don't prune extremities after a gap if the resolved
+ state is different.
+ """
+
+ # Fudge a second event which points to an event we don't have.
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "state_key": "@user:other",
+ "content": {},
+ "room_id": self.room_id,
+ "sender": "@user:other",
+ "depth": 10,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ # Now we persist it with state with a dropped history visibility
+ # setting. The state resolution across the old and new event will then
+ # include it, and so the resolved state won't match the new state.
+ state_before_gap = dict(
+ self.get_success(self.state.get_current_state(self.room_id))
+ )
+ state_before_gap.pop(("m.room.history_visibility", ""))
+
+ context = self.get_success(
+ self.state.compute_event_context(
+ remote_event_2, old_state=state_before_gap.values()
+ )
+ )
+
+ self.get_success(self.persistence.persist_event(remote_event_2, context))
+
+ # Check that we haven't dropped the old extremity.
+ self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+ def test_prune_gap_if_old(self):
+ """Test that we drop extremities after a gap when the previous extremity
+ is "old"
+ """
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_do_not_prune_gap_if_other_server(self):
+ """Test that we do not drop extremities after a gap when we see an event
+ from a different domain.
+ """
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+ def test_prune_gap_if_dummy_remote(self):
+ """Test that we drop extremities after a gap when the previous extremity
+ is a local dummy event and only points to remote events.
+ """
+
+ body = self.helper.send_event(
+ self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+ )
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id])
+
+ def test_prune_gap_if_dummy_local(self):
+ """Test that we don't drop extremities after a gap when the previous
+ extremity is a local dummy event and points to local events.
+ """
+
+ body = self.helper.send(self.room_id, body="Test", tok=self.token)
+
+ body = self.helper.send_event(
+ self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+ )
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Advance the clock for many days to make the old extremity "old". We
+ # also set the depth to "lots".
+ self.reactor.advance(7 * 24 * 60 * 60)
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([remote_event_2.event_id, local_message_event_id])
+
+ def test_do_not_prune_gap_if_not_dummy(self):
+ """Test that we do not drop extremities after a gap when the previous extremity
+ is not a dummy event.
+ """
+
+ body = self.helper.send(self.room_id, body="test", tok=self.token)
+ local_message_event_id = body["event_id"]
+ self.assert_extremities([local_message_event_id])
+
+ # Fudge a second event which points to an event we don't have. This is a
+ # state event so that the state changes (otherwise we won't prune the
+ # extremity as they'll have the same state group).
+ remote_event_2 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": "@user:other2",
+ "content": {"membership": Membership.JOIN},
+ "room_id": self.room_id,
+ "sender": "@user:other2",
+ "depth": 10000,
+ "prev_events": ["$some_unknown_message"],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+ self.persist_event(remote_event_2, state=state_before_gap.values())
+
+ # Check the new extremity is just the new remote event.
+ self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ # The first ID gen will notice that it can advance its token to 7 as it
+ # has no in progress writes...
+ self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ # ... but the second ID gen doesn't know that.
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(
- first_id_gen.get_positions(), {"first": 3, "second": 7}
+ first_id_gen.get_positions(), {"first": 7, "second": 7}
)
self.get_success(_get_next_async())
@@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator("first", writers=["first", "second"])
- self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+ self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async():
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
@@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("second", 5)
# Initial config has two writers
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
@@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async2())
- self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..e9e3bca3bf 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase):
self.assertEquals(1, total)
self.assertEquals(self.displayname, users.pop()["displayname"])
+
+ users, total = yield defer.ensureDeferred(
+ self.store.get_users_paginate(0, 10, name="BC", guests=False)
+ )
+
+ self.assertEquals(1, total)
+ self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index cc1f3c53c5..a06ad2c03e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 1ea35d60c1..a6303bf0ee 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
-
from canonicaljson import json
from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["redaction_retention_period"] = "30d"
- return self.setup_test_homeserver(
- resource_for_federation=Mock(), http_client=None, config=config
- )
+ return config
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
@@ -236,9 +231,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id
@defer.inlineCallbacks
- def build(self, prev_event_ids):
+ def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids)
+ self._base_builder.build(prev_event_ids, auth_event_ids)
)
built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1])
)
- self.assertDictContainsSubset(
- {"name": self.user_id, "device_id": self.device_id}, result
- )
-
- self.assertTrue("token_id" in result)
+ self.assertEqual(result.user_id, self.user_id)
+ self.assertEqual(result.device_id, self.device_id)
+ self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
- self.assertEqual(self.user_id, user["name"])
+ self.assertEqual(self.user_id, user.user_id)
# now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,12 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
-
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import event_injection
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- resource_for_federation=Mock(), http_client=None
- )
- return hs
-
def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event
@@ -187,7 +179,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 738e912468..a6f63f4aaf 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver
ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
+# The localpart isn't 'Bela' on purpose so we can test looking up display names.
+BELA = "@somenickname:a"
class UserDirectoryStoreTestCase(unittest.TestCase):
@@ -41,6 +43,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
)
yield defer.ensureDeferred(
+ self.store.update_profile_in_user_dir(BELA, "Bela", None)
+ )
+ yield defer.ensureDeferred(
self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
)
@@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
)
finally:
self.hs.config.user_directory_search_all_users = False
+
+ @defer.inlineCallbacks
+ def test_search_user_dir_stop_words(self):
+ """Tests that a user can look up another user by searching for the start if its
+ display name even if that name happens to be a common English word that would
+ usually be ignored in full text searches.
+ """
+ self.hs.config.user_directory_search_all_users = True
+ try:
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
+ finally:
+ self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 27a7fc9ed7..fc9aab32d0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed
from synapse.api.errors import FederationError
from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
@@ -37,13 +37,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup,
- http_client=self.http_client,
+ federation_http_client=self.http_client,
clock=self.hs_clock,
reactor=self.reactor,
)
user_id = UserID("us", "test")
- our_user = Requester(user_id, None, False, False, None, None)
+ our_user = create_requester(user_id)
room_creator = self.homeserver.get_room_creation_handler()
self.room_id = self.get_success(
room_creator.create_room(
@@ -75,7 +75,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- self.handler = self.homeserver.get_handlers().federation_handler
+ self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
)
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- with LoggingContext(request="lying_event"):
+ with LoggingContext():
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 654a6fa42d..51660b51d5 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
@@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_as_ignores_mau(self):
+ """Test that application services can still create users when the MAU
+ limit has been reached. This only works when application service
+ user ip tracking is disabled.
+ """
+
+ # Create and sync so that the MAU counts get updated
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # check we're testing what we think we are: there should be two active users
+ self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
+ # We've created and activated two users, we shouldn't be able to
+ # register new users
+ with self.assertRaises(SynapseError) as cm:
+ self.create_user("kermit3")
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ # Cheekily add an application service that we use to register a new user
+ # with.
+ as_token = "foobartoken"
+ self.store.services_cache.append(
+ ApplicationService(
+ token=as_token,
+ hostname=self.hs.hostname,
+ id="SomeASID",
+ sender="@as_sender:test",
+ namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
+ )
+ )
+
+ self.create_user("as_kermit4", token=as_token)
+
def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
@@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
- def create_user(self, localpart):
+ def create_user(self, localpart, token=None):
request_data = json.dumps(
{
"username": localpart,
@@ -201,8 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
}
)
- request, channel = self.make_request("POST", "/register", request_data)
- self.render(request)
+ channel = self.make_request(
+ "POST", "/register", request_data, access_token=token,
+ )
if channel.code != 200:
raise HttpResponseException(
@@ -214,8 +255,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
def do_sync_for_user(self, token):
- request, channel = self.make_request("GET", "/sync", access_token=token)
- self.render(request)
+ channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
raise HttpResponseException(
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index f5f63d8ed6..759e4cd048 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -15,7 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest
@@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
Caches produce metrics reflecting their state when scraped.
"""
CACHE_NAME = "cache_metrics_test_fgjkbdfg"
- cache = Cache(CACHE_NAME, max_entries=777)
+ cache = DeferredCache(CACHE_NAME, max_entries=777)
items = {
x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 7657bddea5..e7aed092c2 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -17,7 +17,7 @@ import resource
import mock
-from synapse.app.homeserver import phone_stats_home
+from synapse.app.phone_stats_home import phone_stats_home
from tests.unittest import HomeserverTestCase
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..a883d707df 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø lies in Northern Norway. The municipality has a population of"
" (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase):
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase):
]
desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEquals(
+ self.assertEqual(
desc,
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self):
html = """
@@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self):
html = """
@@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(
+ self.assertEqual(
og,
{
"og:title": "Foo",
@@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self):
html = """
@@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self):
html = """
@@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self):
html = """
@@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+ self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+
+ def test_empty(self):
+ html = ""
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {})
diff --git a/tests/test_server.py b/tests/test_server.py
index 655c918a15..815da18e65 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -26,9 +26,9 @@ from synapse.util import Clock
from tests import unittest
from tests.server import (
+ FakeSite,
ThreadedMemoryReactorClock,
make_request,
- render,
setup_test_homeserver,
)
@@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase):
self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
+ self.addCleanup,
+ federation_http_client=None,
+ clock=self.hs_clock,
+ reactor=self.reactor,
)
def test_handler_for_request(self):
@@ -61,12 +64,10 @@ class JsonResourceTests(unittest.TestCase):
"test_servlet",
)
- request, channel = make_request(
- self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
+ make_request(
+ self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
)
- render(request, res, self.reactor)
- self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]})
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
def test_callback_direct_exception(self):
@@ -83,8 +84,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -108,8 +108,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -127,8 +126,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -150,8 +148,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -173,8 +170,7 @@ class JsonResourceTests(unittest.TestCase):
)
# The path was registered as GET, but this is a HEAD request.
- request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
@@ -196,9 +192,6 @@ class OptionsResourceTests(unittest.TestCase):
def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response."""
- request, channel = make_request(self.reactor, method, path, shorthand=False)
- request.prepath = [] # This doesn't get set properly by make_request.
-
# Create a site and query for the resource.
site = SynapseSite(
"test",
@@ -207,11 +200,9 @@ class OptionsResourceTests(unittest.TestCase):
self.resource,
"1.0",
)
- request.site = site
- resource = site.getResourceFor(request)
- # Finally, render the resource and return the channel.
- render(request, resource, self.reactor)
+ # render the request and return the channel
+ channel = make_request(self.reactor, site, method, path, shorthand=False)
return channel
def test_unknown_options_request(self):
@@ -284,8 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"GET", b"/path")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
@@ -303,8 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"GET", b"/path")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
@@ -325,8 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"GET", b"/path")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
@@ -345,8 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"HEAD", b"/path")
- render(request, res, self.reactor)
+ channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 80b0ccbc40..6227a3ba95 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler",
"get_clock",
"get_state_resolution_handler",
+ "hostname",
]
)
hs.config = default_config("tesths", True)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index b89798336c..a743cdc3a9 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -53,8 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
# Do a UI auth request
request_data = json.dumps({"username": "kermit", "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -97,8 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler.check_username = Mock(return_value=True)
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
# We don't bother checking that the response is correct - we'll leave that to
# other tests. We just want to make sure we're on the right path.
@@ -115,8 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
},
}
)
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
+ channel = self.make_request(b"POST", self.url, request_data)
# We're interested in getting a response that looks like a successful
# registration, not so much that the details are exactly what we want.
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,17 @@
"""
Utilities for running the unit tests
"""
+import sys
+import warnings
from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
+
+from mock import Mock
+
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
TV = TypeVar("TV")
@@ -48,3 +57,65 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
future = Future() # type: ignore
future.set_result(result)
return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+ """
+ Convert warnings from a non-awaited coroutines into errors.
+ """
+ warnings.simplefilter("error", RuntimeWarning)
+
+ # unraisablehook was added in Python 3.8.
+ if not hasattr(sys, "unraisablehook"):
+ return lambda: None
+
+ # State shared between unraisablehook and check_for_unraisable_exceptions.
+ unraisable_exceptions = []
+ orig_unraisablehook = sys.unraisablehook # type: ignore
+
+ def unraisablehook(unraisable):
+ unraisable_exceptions.append(unraisable.exc_value)
+
+ def cleanup():
+ """
+ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
+ """
+ sys.unraisablehook = orig_unraisablehook # type: ignore
+ if unraisable_exceptions:
+ raise unraisable_exceptions.pop()
+
+ sys.unraisablehook = unraisablehook # type: ignore
+
+ return cleanup
+
+
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+@attr.s
+class FakeResponse:
+ """A fake twisted.web.IResponse object
+
+ there is a similar class at treq.test.test_response, but it lacks a `phrase`
+ attribute, and didn't support deliverBody until recently.
+ """
+
+ # HTTP response code
+ code = attr.ib(type=int)
+
+ # HTTP response phrase (eg b'OK' for a 200)
+ phrase = attr.ib(type=bytes)
+
+ # body of the response
+ body = attr.ib(type=bytes)
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index e93aa84405..c3c4a93e1f 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -50,7 +50,7 @@ async def inject_member_event(
sender=sender,
state_key=target,
content=content,
- **kwargs
+ **kwargs,
)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging():
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
- handler.addFilter(LoggingContextFilter(request=""))
+ handler.addFilter(LoggingContextFilter())
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..af7f752c5a 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union
+from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from mock import Mock, patch
@@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
+from twisted.web.resource import Resource
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
@@ -44,17 +45,12 @@ from synapse.logging.context import (
set_current_context,
)
from synapse.server import HomeServer
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
-from tests.server import (
- FakeChannel,
- get_clock,
- make_request,
- render,
- setup_test_homeserver,
-)
-from tests.test_utils import event_injection
+from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
+from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -119,6 +115,10 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
+ # Trial messes with the warnings configuration, thus this has to be
+ # done in the context of an individual TestCase.
+ self.addCleanup(setup_awaitable_errors())
+
return orig()
@around(self)
@@ -235,13 +235,11 @@ class HomeserverTestCase(TestCase):
if not isinstance(self.hs, HomeServer):
raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
- # Register the resources
- self.resource = self.create_test_json_resource()
-
- # create a site to wrap the resource.
+ # create the root resource, and a site to wrap it.
+ self.resource = self.create_test_resource()
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
- site_tag="test",
+ site_tag=self.hs.config.server.server_name,
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
@@ -249,22 +247,29 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.v1.utils import RestHelper
- self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+ self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
if hasattr(self, "user_id"):
if self.hijack_auth:
+ # We need a valid token ID to satisfy foreign key constraints.
+ token_id = self.get_success(
+ self.hs.get_datastore().add_access_token_to_user(
+ self.helper.auth_user_id, "some_fake_token", None, None,
+ )
+ )
+
async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
+ "token_id": token_id,
"is_guest": False,
}
async def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester(
UserID.from_string(self.helper.auth_user_id),
- 1,
+ token_id,
False,
False,
None,
@@ -312,22 +317,32 @@ class HomeserverTestCase(TestCase):
hs = self.setup_test_homeserver()
return hs
- def create_test_json_resource(self):
+ def create_test_resource(self) -> Resource:
"""
- Create a test JsonResource, with the relevant servlets registerd to it
+ Create a the root resource for the test server.
- The default implementation calls each function in `servlets` to do the
- registration.
-
- Returns:
- JsonResource:
+ The default calls `self.create_resource_dict` and builds the resultant dict
+ into a tree.
"""
- resource = JsonResource(self.hs)
+ root_resource = Resource()
+ create_resource_tree(self.create_resource_dict(), root_resource)
+ return root_resource
- for servlet in self.servlets:
- servlet(self.hs, resource)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
+
+ A resource tree is a mapping from path to twisted.web.resource.
- return resource
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ servlet_resource = JsonResource(self.hs)
+ for servlet in self.servlets:
+ servlet(self.hs, servlet_resource)
+ return {
+ "/_matrix/client": servlet_resource,
+ "/_synapse/admin": servlet_resource,
+ }
def default_config(self):
"""
@@ -367,7 +382,11 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
- ) -> Tuple[T, FakeChannel]:
+ await_result: bool = True,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
given content.
@@ -385,14 +404,18 @@ class HomeserverTestCase(TestCase):
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
+ await_result: whether to wait for the request to complete rendering. If
+ true (the default), will pump the test reactor until the the renderer
+ tells the channel the request is finished.
+
+ custom_headers: (name, value) pairs to add as request headers
+
Returns:
- Tuple[synapse.http.site.SynapseRequest, channel]
+ The FakeChannel object which stores the result of the request.
"""
- if isinstance(content, dict):
- content = json.dumps(content).encode("utf8")
-
return make_request(
self.reactor,
+ self.site,
method,
path,
content,
@@ -401,18 +424,10 @@ class HomeserverTestCase(TestCase):
shorthand,
federation_auth_origin,
content_is_form,
+ await_result,
+ custom_headers,
)
- def render(self, request):
- """
- Render a request against the resources registered by the test class's
- servlets.
-
- Args:
- request (synapse.http.site.SynapseRequest): The request to render.
- """
- render(request, self.resource, self.reactor)
-
def setup_test_homeserver(self, *args, **kwargs):
"""
Set up the test homeserver, meant to be called by the overridable
@@ -505,24 +520,29 @@ class HomeserverTestCase(TestCase):
return result
- def register_user(self, username, password, admin=False):
+ def register_user(
+ self,
+ username: str,
+ password: str,
+ admin: Optional[bool] = False,
+ displayname: Optional[str] = None,
+ ) -> str:
"""
Register a user. Requires the Admin API be registered.
Args:
- username (bytes/unicode): The user part of the new user.
- password (bytes/unicode): The password of the new user.
- admin (bool): Whether the user should be created as an admin
- or not.
+ username: The user part of the new user.
+ password: The password of the new user.
+ admin: Whether the user should be created as an admin or not.
+ displayname: The displayname of the new user.
Returns:
- The MXID of the new user (unicode).
+ The MXID of the new user.
"""
self.hs.config.registration_shared_secret = "shared"
# Create the user
- request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
- self.render(request)
+ channel = self.make_request("GET", "/_synapse/admin/v1/register")
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
@@ -540,16 +560,16 @@ class HomeserverTestCase(TestCase):
{
"nonce": nonce,
"username": username,
+ "displayname": displayname,
"password": password,
"admin": admin,
"mac": want_mac,
"inhibit_login": True,
}
)
- request, channel = self.make_request(
- "POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
+ channel = self.make_request(
+ "POST", "/_synapse/admin/v1/register", body.encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
@@ -565,10 +585,9 @@ class HomeserverTestCase(TestCase):
if device_id:
body["device_id"] = device_id
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
@@ -590,7 +609,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
event, context = self.get_success(
event_creator.create_event(
@@ -608,7 +627,9 @@ class HomeserverTestCase(TestCase):
if soft_failed:
event.internal_metadata.soft_failed = True
- self.get_success(event_creator.send_nonmember_event(requester, event, context))
+ self.get_success(
+ event_creator.handle_new_client_event(requester, event, context)
+ )
return event.event_id
@@ -632,10 +653,9 @@ class HomeserverTestCase(TestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 403, channel.result)
def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
@@ -659,13 +679,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
- def prepare(self, reactor, clock, homeserver):
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+ return d
+
+
+class TestTransportLayerServer(JsonResource):
+ """A test implementation of TransportLayerServer
+
+ authenticates incoming requests as `other.example.com`.
+ """
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
+ authenticator = Authenticator()
+
ratelimiter = FederationRateLimiter(
- clock,
+ hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -674,11 +710,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
- federation_server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
- return super().prepare(reactor, clock, homeserver)
+ federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
new file mode 100644
index 0000000000..dadfabd46d
--- /dev/null
+++ b/tests/util/caches/test_deferred_cache.py
@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+from twisted.internet import defer
+
+from synapse.util.caches.deferred_cache import DeferredCache
+
+from tests.unittest import TestCase
+
+
+class DeferredCacheTestCase(TestCase):
+ def test_empty(self):
+ cache = DeferredCache("test")
+ failed = False
+ try:
+ cache.get("foo")
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_hit(self):
+ cache = DeferredCache("test")
+ cache.prefill("foo", 123)
+
+ self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+
+ def test_hit_deferred(self):
+ cache = DeferredCache("test")
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d)
+
+ # get should return an incomplete deferred
+ get_d = cache.get("k1")
+ self.assertFalse(get_d.called)
+
+ # add a callback that will make sure that the set_d gets called before the get_d
+ def check1(r):
+ self.assertTrue(set_d.called)
+ return r
+
+ # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
+ # maybe we should fix that?
+ # get_d.addCallback(check1)
+
+ # now fire off all the deferreds
+ origin_d.callback(99)
+ self.assertEqual(self.successResultOf(origin_d), 99)
+ self.assertEqual(self.successResultOf(set_d), 99)
+ self.assertEqual(self.successResultOf(get_d), 99)
+
+ def test_callbacks(self):
+ """Invalidation callbacks are called at the right time"""
+ cache = DeferredCache("test")
+ callbacks = set()
+
+ # start with an entry, with a callback
+ cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+ # now replace that entry with a pending result
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+ # ... and also make a get request
+ get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+ # we don't expect the invalidation callback for the original value to have
+ # been called yet, even though get() will now return a different result.
+ # I'm not sure if that is by design or not.
+ self.assertEqual(callbacks, set())
+
+ # now fire off all the deferreds
+ origin_d.callback(20)
+ self.assertEqual(self.successResultOf(set_d), 20)
+ self.assertEqual(self.successResultOf(get_d), 20)
+
+ # now the original invalidation callback should have been called, but none of
+ # the others
+ self.assertEqual(callbacks, {"prefill"})
+ callbacks.clear()
+
+ # another update should invalidate both the previous results
+ cache.prefill("k1", 30)
+ self.assertEqual(callbacks, {"set", "get"})
+
+ def test_set_fail(self):
+ cache = DeferredCache("test")
+ callbacks = set()
+
+ # start with an entry, with a callback
+ cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+ # now replace that entry with a pending result
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+ # ... and also make a get request
+ get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+ # none of the callbacks should have been called yet
+ self.assertEqual(callbacks, set())
+
+ # oh noes! fails!
+ e = Exception("oops")
+ origin_d.errback(e)
+ self.assertIs(self.failureResultOf(set_d, Exception).value, e)
+ self.assertIs(self.failureResultOf(get_d, Exception).value, e)
+
+ # the callbacks for the failed requests should have been called.
+ # I'm not sure if this is deliberate or not.
+ self.assertEqual(callbacks, {"get", "set"})
+ callbacks.clear()
+
+ # the old value should still be returned now?
+ get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
+ self.assertEqual(self.successResultOf(get_d2), 10)
+
+ # replacing the value now should run the callbacks for those requests
+ # which got the original result
+ cache.prefill("k1", 30)
+ self.assertEqual(callbacks, {"prefill", "get2"})
+
+ def test_get_immediate(self):
+ cache = DeferredCache("test")
+ d1 = defer.Deferred()
+ cache.set("key1", d1)
+
+ # get_immediate should return default
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 1)
+
+ # now complete the set
+ d1.callback(2)
+
+ # get_immediate should return result
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 2)
+
+ def test_invalidate(self):
+ cache = DeferredCache("test")
+ cache.prefill(("foo",), 123)
+ cache.invalidate(("foo",))
+
+ failed = False
+ try:
+ cache.get(("foo",))
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_invalidate_all(self):
+ cache = DeferredCache("testcache")
+
+ callback_record = [False, False]
+
+ def record_callback(idx):
+ callback_record[idx] = True
+
+ # add a couple of pending entries
+ d1 = defer.Deferred()
+ cache.set("key1", d1, partial(record_callback, 0))
+
+ d2 = defer.Deferred()
+ cache.set("key2", d2, partial(record_callback, 1))
+
+ # lookup should return pending deferreds
+ self.assertFalse(cache.get("key1").called)
+ self.assertFalse(cache.get("key2").called)
+
+ # let one of the lookups complete
+ d2.callback("result2")
+
+ # now the cache will return a completed deferred
+ self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
+
+ # now do the invalidation
+ cache.invalidate_all()
+
+ # lookup should fail
+ with self.assertRaises(KeyError):
+ cache.get("key1")
+ with self.assertRaises(KeyError):
+ cache.get("key2")
+
+ # both callbacks should have been callbacked
+ self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
+ self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
+
+ # letting the other lookup complete should do nothing
+ d1.callback("result1")
+ with self.assertRaises(KeyError):
+ cache.get("key1", None)
+
+ def test_eviction(self):
+ cache = DeferredCache(
+ "test", max_entries=2, apply_cache_factor_from_config=False
+ )
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+ cache.prefill(3, "three") # 1 will be evicted
+
+ failed = False
+ try:
+ cache.get(1)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(2)
+ cache.get(3)
+
+ def test_eviction_lru(self):
+ cache = DeferredCache(
+ "test", max_entries=2, apply_cache_factor_from_config=False
+ )
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+
+ # Now access 1 again, thus causing 2 to be least-recently used
+ cache.get(1)
+
+ cache.prefill(3, "three")
+
+ failed = False
+ try:
+ cache.get(2)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(1)
+ cache.get(3)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 677e925477..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from functools import partial
+from typing import Set
import mock
@@ -29,60 +29,50 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
from tests import unittest
+from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
-def run_on_reactor():
- d = defer.Deferred()
- reactor.callLater(0, d.callback, 0)
- return make_deferred_yieldable(d)
-
-
-class CacheTestCase(unittest.TestCase):
- def test_invalidate_all(self):
- cache = descriptors.Cache("testcache")
-
- callback_record = [False, False]
-
- def record_callback(idx):
- callback_record[idx] = True
-
- # add a couple of pending entries
- d1 = defer.Deferred()
- cache.set("key1", d1, partial(record_callback, 0))
-
- d2 = defer.Deferred()
- cache.set("key2", d2, partial(record_callback, 1))
-
- # lookup should return observable deferreds
- self.assertFalse(cache.get("key1").has_called())
- self.assertFalse(cache.get("key2").has_called())
+class LruCacheDecoratorTestCase(unittest.TestCase):
+ def test_base(self):
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
- # let one of the lookups complete
- d2.callback("result2")
+ @lru_cache()
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
- # for now at least, the cache will return real results rather than an
- # observabledeferred
- self.assertEqual(cache.get("key2"), "result2")
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
- # now do the invalidation
- cache.invalidate_all()
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
- # lookup should return none
- self.assertIsNone(cache.get("key1", None))
- self.assertIsNone(cache.get("key2", None))
+ # the two values should now be cached
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
- # both callbacks should have been callbacked
- self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
- self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
- # letting the other lookup complete should do nothing
- d1.callback("result1")
- self.assertIsNone(cache.get("key1", None))
+def run_on_reactor():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, 0)
+ return make_deferred_yieldable(d)
class DescriptorTestCase(unittest.TestCase):
@@ -174,6 +164,57 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_cache_with_async_exception(self):
+ """The wrapped function returns a failure
+ """
+
+ class Cls:
+ result = None
+ call_count = 0
+
+ @cached()
+ def fn(self, arg1):
+ self.call_count += 1
+ return self.result
+
+ obj = Cls()
+ callbacks = set() # type: Set[str]
+
+ # set off an asynchronous request
+ obj.result = origin_d = defer.Deferred()
+
+ d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+ self.assertFalse(d1.called)
+
+ # a second request should also return a deferred, but should not call the
+ # function itself.
+ d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+ self.assertFalse(d2.called)
+ self.assertEqual(obj.call_count, 1)
+
+ # no callbacks yet
+ self.assertEqual(callbacks, set())
+
+ # the original request fails
+ e = Exception("bzz")
+ origin_d.errback(e)
+
+ # ... which should cause the lookups to fail similarly
+ self.assertIs(self.failureResultOf(d1, Exception).value, e)
+ self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+ # ... and the callbacks to have been, uh, called.
+ self.assertEqual(callbacks, {"d1", "d2"})
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should work as normal
+ obj.result = defer.succeed(100)
+ d3 = obj.fn(1)
+ self.assertEqual(self.successResultOf(d3), 100)
+ self.assertEqual(obj.call_count, 2)
+
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@@ -354,6 +395,260 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_invalidate_cascade(self):
+ """Invalidations should cascade up through cache contexts"""
+
+ class Cls:
+ @cached(cache_context=True)
+ async def func1(self, key, cache_context):
+ return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+ @cached(cache_context=True)
+ async def func2(self, key, cache_context):
+ return self.func3(key, on_invalidate=cache_context.invalidate)
+
+ @lru_cache(cache_context=True)
+ def func3(self, key, cache_context):
+ self.invalidate = cache_context.invalidate
+ return 42
+
+ obj = Cls()
+
+ top_invalidate = mock.Mock()
+ r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+ self.assertEqual(r, 42)
+ obj.invalidate()
+ top_invalidate.assert_called_once()
+
+
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+ """More tests for @cached
+
+ The following is a set of tests that got lost in a different file for a while.
+
+ There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+ duplicates would be removed and the two sets of classes combined.
+ """
+
+ @defer.inlineCallbacks
+ def test_passthrough(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ a = A()
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals((yield a.func("bar")), "bar")
+
+ @defer.inlineCallbacks
+ def test_hit(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals(callcount[0], 1)
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ a.func.invalidate(("foo",))
+
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+
+ def test_invalidate_missing(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ A().func.invalidate(("what",))
+
+ @defer.inlineCallbacks
+ def test_max_entries(self):
+ callcount = [0]
+
+ class A:
+ @cached(max_entries=10)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertEquals(callcount[0], 12)
+
+ # There must have been at least 2 evictions, meaning if we calculate
+ # all 12 values again, we must get called at least 2 more times
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertTrue(
+ callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+ )
+
+ def test_prefill(self):
+ callcount = [0]
+
+ d = defer.succeed(123)
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return d
+
+ a = A()
+
+ a.func.prefill(("foo",), 456)
+
+ self.assertEquals(a.func("foo").result, 456)
+ self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
+
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 0adb2174af..a739a6aaaf 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,7 +19,8 @@ from mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
-from .. import unittest
+from tests import unittest
+from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
@@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEquals(cache.pop("key"), None)
def test_del_multi(self):
- cache = LruCache(4, 2, cache_type=TreeCache)
+ cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache.clear()
self.assertEquals(len(cache), 0)
+ @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
+ def test_special_size(self):
+ cache = LruCache(10, "mycache")
+ self.assertEqual(cache.max_size, 100)
+
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
@@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m2 = Mock()
m3 = Mock()
m4 = Mock()
- cache = LruCache(4, 2, cache_type=TreeCache)
+ cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
diff --git a/tests/utils.py b/tests/utils.py
index 4673872f88..977eeaf6ee 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,12 +20,12 @@ import os
import time
import uuid
import warnings
-from inspect import getcallargs
+from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
@@ -33,14 +33,13 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
#
@@ -88,6 +87,7 @@ def setupdb():
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
+ db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
@@ -190,11 +190,10 @@ class TestHomeServer(HomeServer):
def setup_test_homeserver(
cleanup_func,
name="test",
- datastore=None,
config=None,
reactor=None,
- homeserverToUse=TestHomeServer,
- **kargs
+ homeserver_to_use: Type[HomeServer] = TestHomeServer,
+ **kwargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -217,8 +216,8 @@ def setup_test_homeserver(
config.ldap_enabled = False
- if "clock" not in kargs:
- kargs["clock"] = MockClock()
+ if "clock" not in kwargs:
+ kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
@@ -247,7 +246,7 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
- if datastore is None and isinstance(db_engine, PostgresEngine):
+ if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -263,79 +262,68 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- if datastore is None:
- hs = homeserverToUse(
- name,
- config=config,
- version_string="Synapse/tests",
- tls_server_context_factory=Mock(),
- tls_client_options_factory=Mock(),
- reactor=reactor,
- **kargs
- )
+ hs = homeserver_to_use(
+ name, config=config, version_string="Synapse/tests", reactor=reactor,
+ )
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
-
- if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
-
- # We need to do cleanup on PostgreSQL
- def cleanup():
- import psycopg2
-
- # Close all the db pools
- database._db_pool.close()
-
- dropped = False
-
- # Drop the test database
- db_conn = db_engine.module.connect(
- database=POSTGRES_BASE_DB,
- user=POSTGRES_USER,
- host=POSTGRES_HOST,
- password=POSTGRES_PASSWORD,
- )
- db_conn.autocommit = True
- cur = db_conn.cursor()
-
- # Try a few times to drop the DB. Some things may hold on to the
- # database for a few more seconds due to flakiness, preventing
- # us from dropping it when the test is over. If we can't drop
- # it, warn and move on.
- for x in range(5):
- try:
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- db_conn.commit()
- dropped = True
- except psycopg2.OperationalError as e:
- warnings.warn(
- "Couldn't drop old db: " + str(e), category=UserWarning
- )
- time.sleep(0.5)
-
- cur.close()
- db_conn.close()
-
- if not dropped:
- warnings.warn("Failed to drop old DB.", category=UserWarning)
-
- if not LEAVE_DB:
- # Register the cleanup hook
- cleanup_func(cleanup)
+ # Install @cache_in_self attributes
+ for key, val in kwargs.items():
+ setattr(hs, "_" + key, val)
- else:
- hs = homeserverToUse(
- name,
- datastore=datastore,
- config=config,
- version_string="Synapse/tests",
- tls_server_context_factory=Mock(),
- tls_client_options_factory=Mock(),
- reactor=reactor,
- **kargs
- )
+ # Mock TLS
+ hs.tls_server_context_factory = Mock()
+ hs.tls_client_options_factory = Mock()
+
+ hs.setup()
+ if homeserver_to_use == TestHomeServer:
+ hs.setup_background_tasks()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
+
+ # We need to do cleanup on PostgreSQL
+ def cleanup():
+ import psycopg2
+
+ # Close all the db pools
+ database._db_pool.close()
+
+ dropped = False
+
+ # Drop the test database
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+
+ # Try a few times to drop the DB. Some things may hold on to the
+ # database for a few more seconds due to flakiness, preventing
+ # us from dropping it when the test is over. If we can't drop
+ # it, warn and move on.
+ for x in range(5):
+ try:
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ db_conn.commit()
+ dropped = True
+ except psycopg2.OperationalError as e:
+ warnings.warn(
+ "Couldn't drop old db: " + str(e), category=UserWarning
+ )
+ time.sleep(0.5)
+
+ cur.close()
+ db_conn.close()
+
+ if not dropped:
+ warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
# bcrypt is far too slow to be doing in unit tests
# Need to let the HS build an auth handler and then mess with it
@@ -351,32 +339,9 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
- fed = kargs.get("resource_for_federation", None)
- if fed:
- register_federation_servlets(hs, fed)
-
return hs
-def register_federation_servlets(hs, resource):
- federation_server.register_servlets(
- hs,
- resource=resource,
- authenticator=federation_server.Authenticator(hs),
- ratelimiter=FederationRateLimiter(
- hs.get_clock(), config=hs.config.rc_federation
- ),
- )
-
-
-def get_mock_call_args(pattern_func, mock_func):
- """ Return the arguments the mock function was called with interpreted
- by the pattern functions argument list.
- """
- invoked_args, invoked_kargs = mock_func.call_args
- return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
@@ -562,86 +527,6 @@ class MockClock:
return d
-def _format_call(args, kwargs):
- return ", ".join(
- ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
- )
-
-
-class DeferredMockCallable:
- """A callable instance that stores a set of pending call expectations and
- return values for them. It allows a unit test to assert that the given set
- of function calls are eventually made, by awaiting on them to be called.
- """
-
- def __init__(self):
- self.expectations = []
- self.calls = []
-
- def __call__(self, *args, **kwargs):
- self.calls.append((args, kwargs))
-
- if not self.expectations:
- raise ValueError(
- "%r has no pending calls to handle call(%s)"
- % (self, _format_call(args, kwargs))
- )
-
- for (call, result, d) in self.expectations:
- if args == call[1] and kwargs == call[2]:
- d.callback(None)
- return result
-
- failure = AssertionError(
- "Was not expecting call(%s)" % (_format_call(args, kwargs))
- )
-
- for _, _, d in self.expectations:
- try:
- d.errback(failure)
- except Exception:
- pass
-
- raise failure
-
- def expect_call_and_return(self, call, result):
- self.expectations.append((call, result, defer.Deferred()))
-
- @defer.inlineCallbacks
- def await_calls(self, timeout=1000):
- deferred = defer.DeferredList(
- [d for _, _, d in self.expectations], fireOnOneErrback=True
- )
-
- timer = reactor.callLater(
- timeout / 1000,
- deferred.errback,
- AssertionError(
- "%d pending calls left: %s"
- % (
- len([e for e in self.expectations if not e[2].called]),
- [e for e in self.expectations if not e[2].called],
- )
- ),
- )
-
- yield deferred
-
- timer.cancel()
-
- self.calls = []
-
- def assert_had_no_calls(self):
- if self.calls:
- calls = self.calls
- self.calls = []
-
- raise AssertionError(
- "Expected not to received any calls, got:\n"
- + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
- )
-
-
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
"""
|