diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 745c295d3b..cbecc1c20f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -197,6 +197,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# self.assertFalse(d.called)
self.get_success(d)
+ def test_verify_for_server_locally(self):
+ """Ensure that locally signed JSON can be verified without fetching keys
+ over federation
+ """
+ kr = keyring.Keyring(self.hs)
+ json1 = {}
+ signedjson.sign.sign_json(json1, self.hs.hostname, self.hs.signing_key)
+
+ # Test that verify_json_for_server succeeds on a object signed by ourselves
+ d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+ self.get_success(d)
+
def test_verify_json_for_server_with_null_valid_until_ms(self):
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 59de1142b1..abf2a0fe0d 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -17,8 +17,9 @@ from unittest.mock import Mock
import synapse.rest.admin
import synapse.storage
-from synapse.api.constants import EventTypes
-from synapse.rest.client import login, room
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.room_versions import RoomVersions
+from synapse.rest.client import knock, login, room
from tests import unittest
@@ -28,6 +29,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
+ knock.register_servlets,
]
def prepare(self, reactor, clock, hs):
@@ -201,3 +203,32 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0], room_id)
self.assertEqual(args[1].content["membership"], "invite")
self.assertTrue(args[2]) # Assert there is at least one bit of state
+
+ def test_knock(self):
+ """Tests that knock get handled correctly."""
+ # create a knockable v7 room
+ room_id = self.helper.create_room_as(
+ self.user1, room_version=RoomVersions.V7.identifier, tok=self.token1
+ )
+ self.helper.send_state(
+ room_id,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=self.token1,
+ )
+
+ self.helper.send(room_id, body="Hello!", tok=self.token1)
+ self.helper.knock(room_id, self.user2, tok=self.token2)
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_not_called()
+ writer.write_state.assert_not_called()
+ writer.write_knock.assert_called_once()
+
+ args = writer.write_knock.call_args[0]
+ self.assertEqual(args[0], room_id)
+ self.assertEqual(args[1].content["membership"], "knock")
+ self.assertTrue(args[2]) # Assert there is at least one bit of state
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 43998020b2..1f6a924452 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -40,6 +40,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
+ self.event_source = hs.get_event_sources()
def test_notify_interested_services(self):
interested_service = self._mkservice(is_interested=True)
@@ -252,6 +253,56 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)
+ def test_notify_interested_services_ephemeral(self):
+ """
+ Test sending ephemeral events to the appservice handler are scheduled
+ to be pushed out to interested appservices, and that the stream ID is
+ updated accordingly.
+ """
+ interested_service = self._mkservice(is_interested=True)
+ services = [interested_service]
+
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
+ 579
+ )
+
+ event = Mock(event_id="event_1")
+ self.event_source.sources.receipt.get_new_events_as.return_value = (
+ make_awaitable(([event], None))
+ )
+
+ self.handler.notify_interested_services_ephemeral("receipt_key", 580)
+ self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
+ interested_service, [event]
+ )
+ self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with(
+ interested_service,
+ "read_receipt",
+ 580,
+ )
+
+ def test_notify_interested_services_ephemeral_out_of_order(self):
+ """
+ Test sending out of order ephemeral events to the appservice handler
+ are ignored.
+ """
+ interested_service = self._mkservice(is_interested=True)
+ services = [interested_service]
+
+ self.mock_store.get_app_services.return_value = services
+ self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
+ 580
+ )
+
+ event = Mock(event_id="event_1")
+ self.event_source.sources.receipt.get_new_events_as.return_value = (
+ make_awaitable(([event], None))
+ )
+
+ self.handler.notify_interested_services_ephemeral("receipt_key", 579)
+ self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
+
def _mkservice(self, is_interested, protocols=None):
service = Mock()
service.is_interested.return_value = make_awaitable(is_interested)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 3ac48e5e95..43031e07ea 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -160,6 +160,37 @@ class DeviceTestCase(unittest.HomeserverTestCase):
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
+ def test_delete_device_and_device_inbox(self):
+ self._record_users()
+
+ # add an device_inbox
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "device_inbox",
+ {
+ "user_id": user1,
+ "device_id": "abc",
+ "stream_id": 1,
+ "message_json": "{}",
+ },
+ )
+ )
+
+ # delete the device
+ self.get_success(self.handler.delete_device(user1, "abc"))
+
+ # check that the device_inbox was deleted
+ res = self.get_success(
+ self.store.db_pool.simple_select_one(
+ table="device_inbox",
+ keyvalues={"user_id": user1, "device_id": "abc"},
+ retcols=("user_id", "device_id"),
+ allow_none=True,
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertIsNone(res)
+
def test_update_device(self):
self._record_users()
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 39e7b1ab25..0c3b86fda9 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -17,6 +17,8 @@ from unittest import mock
from signedjson import key as key, sign as sign
+from twisted.internet import defer
+
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
@@ -630,3 +632,152 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
],
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
+
+ def test_query_devices_remote_no_sync(self):
+ """Tests that querying keys for a remote user that we don't share a room
+ with returns the cross signing keys correctly.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_client_keys = mock.Mock(
+ return_value=defer.succeed(
+ {
+ "device_keys": {remote_user_id: {}},
+ "master_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ },
+ "self_signing_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ }
+ )
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(query_result["failures"], {})
+ self.assertEqual(
+ query_result["master_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ },
+ )
+ self.assertEqual(
+ query_result["self_signing_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ )
+
+ def test_query_devices_remote_sync(self):
+ """Tests that querying keys for a remote user that we share a room with,
+ but haven't yet fetched the keys for, returns the cross signing keys
+ correctly.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ self.store.get_rooms_for_user = mock.Mock(
+ return_value=defer.succeed({"some_room_id"})
+ )
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_user_devices = mock.Mock(
+ return_value=defer.succeed(
+ {
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(query_result["failures"], {})
+ self.assertEqual(
+ query_result["master_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ }
+ },
+ )
+ self.assertEqual(
+ query_result["self_signing_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ )
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index e915dd5c7c..d16cd141a7 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -15,12 +15,12 @@ from unittest.mock import Mock
from twisted.internet import defer
-from synapse.api.constants import EduTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.rest import admin
-from synapse.rest.client import login, presence, room
+from synapse.rest.client import login, presence, profile, room
from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
@@ -37,6 +37,7 @@ class ModuleApiTestCase(HomeserverTestCase):
login.register_servlets,
room.register_servlets,
presence.register_servlets,
+ profile.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
@@ -115,7 +116,6 @@ class ModuleApiTestCase(HomeserverTestCase):
# Insert a second ip, agent at a later date. We should be able to retrieve it.
last_seen_2 = last_seen_1 + 10000
- print("%s => %s" % (last_seen_1, last_seen_2))
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip_2", "user_agent_2", "device_2", last_seen_2
@@ -385,6 +385,152 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertTrue(found_update)
+ def test_update_membership(self):
+ """Tests that the module API can update the membership of a user in a room."""
+ peter = self.register_user("peter", "hackme")
+ lesley = self.register_user("lesley", "hackme")
+ tok = self.login("peter", "hackme")
+ lesley_tok = self.login("lesley", "hackme")
+
+ # Make peter create a public room.
+ room_id = self.helper.create_room_as(
+ room_creator=peter, is_public=True, tok=tok
+ )
+
+ # Set a profile for lesley.
+ channel = self.make_request(
+ method="PUT",
+ path="/_matrix/client/r0/profile/%s/displayname" % lesley,
+ content={"displayname": "Lesley May"},
+ access_token=lesley_tok,
+ )
+
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ method="PUT",
+ path="/_matrix/client/r0/profile/%s/avatar_url" % lesley,
+ content={"avatar_url": "some_url"},
+ access_token=lesley_tok,
+ )
+
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Make Peter invite Lesley to the room.
+ self.get_success(
+ defer.ensureDeferred(
+ self.module_api.update_room_membership(peter, lesley, room_id, "invite")
+ )
+ )
+
+ res = self.helper.get_state(
+ room_id=room_id,
+ event_type="m.room.member",
+ state_key=lesley,
+ tok=tok,
+ )
+
+ # Check the membership is correct.
+ self.assertEqual(res["membership"], "invite")
+
+ # Also check that the profile was correctly filled out, and that it's not
+ # Peter's.
+ self.assertEqual(res["displayname"], "Lesley May")
+ self.assertEqual(res["avatar_url"], "some_url")
+
+ # Make lesley join it.
+ self.get_success(
+ defer.ensureDeferred(
+ self.module_api.update_room_membership(lesley, lesley, room_id, "join")
+ )
+ )
+
+ # Check that the membership of lesley in the room is "join".
+ res = self.helper.get_state(
+ room_id=room_id,
+ event_type="m.room.member",
+ state_key=lesley,
+ tok=tok,
+ )
+
+ self.assertEqual(res["membership"], "join")
+
+ # Also check that the profile was correctly filled out.
+ self.assertEqual(res["displayname"], "Lesley May")
+ self.assertEqual(res["avatar_url"], "some_url")
+
+ # Make peter kick lesley from the room.
+ self.get_success(
+ defer.ensureDeferred(
+ self.module_api.update_room_membership(peter, lesley, room_id, "leave")
+ )
+ )
+
+ # Check that the membership of lesley in the room is "leave".
+ res = self.helper.get_state(
+ room_id=room_id,
+ event_type="m.room.member",
+ state_key=lesley,
+ tok=tok,
+ )
+
+ self.assertEqual(res["membership"], "leave")
+
+ # Try to send a membership update from a non-local user and check that it fails.
+ d = defer.ensureDeferred(
+ self.module_api.update_room_membership(
+ "@nicolas:otherserver.com",
+ lesley,
+ room_id,
+ "invite",
+ )
+ )
+
+ self.get_failure(d, RuntimeError)
+
+ # Check that inviting a user that doesn't have a profile falls back to using a
+ # default (localpart + no avatar) profile.
+ simone = "@simone:" + self.hs.config.server.server_name
+ self.get_success(
+ defer.ensureDeferred(
+ self.module_api.update_room_membership(peter, simone, room_id, "invite")
+ )
+ )
+
+ res = self.helper.get_state(
+ room_id=room_id,
+ event_type="m.room.member",
+ state_key=simone,
+ tok=tok,
+ )
+
+ self.assertEqual(res["membership"], "invite")
+ self.assertEqual(res["displayname"], "simone")
+ self.assertIsNone(res["avatar_url"])
+
+ def test_get_room_state(self):
+ """Tests that a module can retrieve the state of a room through the module API."""
+ user_id = self.register_user("peter", "hackme")
+ tok = self.login("peter", "hackme")
+
+ # Create a room and send some custom state in it.
+ room_id = self.helper.create_room_as(tok=tok)
+ self.helper.send_state(room_id, "org.matrix.test", {}, tok=tok)
+
+ # Check that the module API can successfully fetch state for the room.
+ state = self.get_success(
+ defer.ensureDeferred(self.module_api.get_room_state(room_id))
+ )
+
+ # Check that a few standard events are in the returned state.
+ self.assertIn((EventTypes.Create, ""), state)
+ self.assertIn((EventTypes.Member, user_id), state)
+
+ # Check that our custom state event is in the returned state.
+ self.assertEqual(state[("org.matrix.test", "")].sender, user_id)
+ self.assertEqual(state[("org.matrix.test", "")].state_key, "")
+ self.assertEqual(state[("org.matrix.test", "")].content, {})
+
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index fa8018e5a7..90f800e564 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -65,7 +65,7 @@ class EmailPusherTests(HomeserverTestCase):
"notif_from": "test@example.com",
"riot_base_url": None,
}
- config["public_baseurl"] = "aaa"
+ config["public_baseurl"] = "http://aaa"
config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config)
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
new file mode 100644
index 0000000000..78c48db552
--- /dev/null
+++ b/tests/rest/admin/test_background_updates.py
@@ -0,0 +1,218 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse.rest.admin
+from synapse.rest.client import login
+from synapse.server import HomeServer
+
+from tests import unittest
+
+
+class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def _register_bg_update(self):
+ "Adds a bg update but doesn't start it"
+
+ async def _fake_update(progress, batch_size) -> int:
+ await self.clock.sleep(0.2)
+ return batch_size
+
+ self.store.db_pool.updates.register_background_update_handler(
+ "test_update",
+ _fake_update,
+ )
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "test_update",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ def test_status_empty(self):
+ """Test the status API works."""
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/background_updates/status",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Background updates should be enabled, but none should be running.
+ self.assertDictEqual(
+ channel.json_body, {"current_updates": {}, "enabled": True}
+ )
+
+ def test_status_bg_update(self):
+ """Test the status API works with a background update."""
+
+ # Create a new background update
+
+ self._register_bg_update()
+
+ self.store.db_pool.updates.start_doing_background_updates()
+ self.reactor.pump([1.0, 1.0])
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/background_updates/status",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Background updates should be enabled, and one should be running.
+ self.assertDictEqual(
+ channel.json_body,
+ {
+ "current_updates": {
+ "master": {
+ "name": "test_update",
+ "average_items_per_ms": 0.1,
+ "total_duration_ms": 1000.0,
+ "total_item_count": 100,
+ }
+ },
+ "enabled": True,
+ },
+ )
+
+ def test_enabled(self):
+ """Test the enabled API works."""
+
+ # Create a new background update
+
+ self._register_bg_update()
+ self.store.db_pool.updates.start_doing_background_updates()
+
+ # Test that GET works and returns enabled is True.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/background_updates/enabled",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertDictEqual(channel.json_body, {"enabled": True})
+
+ # Disable the BG updates
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/enabled",
+ content={"enabled": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertDictEqual(channel.json_body, {"enabled": False})
+
+ # Advance a bit and get the current status, note this will finish the in
+ # flight background update so we call it the status API twice and check
+ # there was no change.
+ self.reactor.pump([1.0, 1.0])
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/background_updates/status",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertDictEqual(
+ channel.json_body,
+ {
+ "current_updates": {
+ "master": {
+ "name": "test_update",
+ "average_items_per_ms": 0.1,
+ "total_duration_ms": 1000.0,
+ "total_item_count": 100,
+ }
+ },
+ "enabled": False,
+ },
+ )
+
+ # Run the reactor for a bit so the BG updates would have a chance to run
+ # if they were to.
+ self.reactor.pump([1.0, 1.0])
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/background_updates/status",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # There should be no change from the previous /status response.
+ self.assertDictEqual(
+ channel.json_body,
+ {
+ "current_updates": {
+ "master": {
+ "name": "test_update",
+ "average_items_per_ms": 0.1,
+ "total_duration_ms": 1000.0,
+ "total_item_count": 100,
+ }
+ },
+ "enabled": False,
+ },
+ )
+
+ # Re-enable the background updates.
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/enabled",
+ content={"enabled": True},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ self.assertDictEqual(channel.json_body, {"enabled": True})
+
+ self.reactor.pump([1.0, 1.0])
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/background_updates/status",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Background updates should be enabled and making progress.
+ self.assertDictEqual(
+ channel.json_body,
+ {
+ "current_updates": {
+ "master": {
+ "name": "test_update",
+ "average_items_per_ms": 0.1,
+ "total_duration_ms": 2000.0,
+ "total_item_count": 200,
+ }
+ },
+ "enabled": True,
+ },
+ )
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 0fa55e03b4..46116644ce 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -17,8 +17,6 @@ import urllib.parse
from typing import List, Optional
from unittest.mock import Mock
-from parameterized import parameterized_class
-
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
@@ -29,13 +27,6 @@ from tests import unittest
"""Tests admin REST events for /rooms paths."""
-@parameterized_class(
- ("method", "url_template"),
- [
- ("POST", "/_synapse/admin/v1/rooms/%s/delete"),
- ("DELETE", "/_synapse/admin/v1/rooms/%s"),
- ],
-)
class DeleteRoomTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -67,7 +58,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(
self.other_user, tok=self.other_user_tok
)
- self.url = self.url_template % self.room_id
+ self.url = "/_synapse/admin/v1/rooms/%s" % self.room_id
def test_requester_is_no_admin(self):
"""
@@ -75,7 +66,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
- self.method,
+ "DELETE",
self.url,
json.dumps({}),
access_token=self.other_user_tok,
@@ -88,10 +79,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
Check that unknown rooms/server return error 404.
"""
- url = self.url_template % "!unknown:test"
+ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test"
channel = self.make_request(
- self.method,
+ "DELETE",
url,
json.dumps({}),
access_token=self.admin_user_tok,
@@ -104,10 +95,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
Check that invalid room names, return an error 400.
"""
- url = self.url_template % "invalidroom"
+ url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
channel = self.make_request(
- self.method,
+ "DELETE",
url,
json.dumps({}),
access_token=self.admin_user_tok,
@@ -126,7 +117,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"new_room_user_id": "@unknown:test"})
channel = self.make_request(
- self.method,
+ "DELETE",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
@@ -145,7 +136,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"new_room_user_id": "@not:exist.bla"})
channel = self.make_request(
- self.method,
+ "DELETE",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
@@ -164,7 +155,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": "NotBool"})
channel = self.make_request(
- self.method,
+ "DELETE",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
@@ -180,7 +171,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"purge": "NotBool"})
channel = self.make_request(
- self.method,
+ "DELETE",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
@@ -206,7 +197,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": True, "purge": True})
channel = self.make_request(
- self.method,
+ "DELETE",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
@@ -239,7 +230,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
body = json.dumps({"block": False, "purge": True})
channel = self.make_request(
- self.method,
+ "DELETE",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
@@ -270,10 +261,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
- body = json.dumps({"block": False, "purge": False})
+ body = json.dumps({"block": True, "purge": False})
channel = self.make_request(
- self.method,
+ "DELETE",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
@@ -287,7 +278,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
with self.assertRaises(AssertionError):
self._is_purged(self.room_id)
- self._is_blocked(self.room_id, expect=False)
+ self._is_blocked(self.room_id, expect=True)
self._has_no_members(self.room_id)
def test_shutdown_room_consent(self):
@@ -319,7 +310,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
channel = self.make_request(
- self.method,
+ "DELETE",
self.url,
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
@@ -365,7 +356,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Test that the admin can still send shutdown
channel = self.make_request(
- self.method,
+ "DELETE",
self.url,
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
@@ -689,36 +680,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
reversing the order, etc.
"""
- def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (
- urllib.parse.quote(test_alias),
- )
- channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=admin_user_tok,
- )
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=admin_user_tok,
- )
-
def _order_test(
order_type: str,
expected_room_list: List[str],
@@ -790,9 +751,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Set room canonical room aliases
- _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+ self._set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ self._set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ self._set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
# Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
user_1 = self.register_user("bob1", "pass")
@@ -859,7 +820,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
room_name_1 = "something"
- room_name_2 = "else"
+ room_name_2 = "LoremIpsum"
# Set the name for each room
self.helper.send_state(
@@ -875,6 +836,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok,
)
+ self._set_canonical_alias(room_id_1, "#Room_Alias1:test", self.admin_user_tok)
+
def _search_test(
expected_room_id: Optional[str],
search_term: str,
@@ -923,24 +886,36 @@ class RoomTestCase(unittest.HomeserverTestCase):
r = rooms[0]
self.assertEqual(expected_room_id, r["room_id"])
- # Perform search tests
+ # Test searching by room name
_search_test(room_id_1, "something")
_search_test(room_id_1, "thing")
- _search_test(room_id_2, "else")
- _search_test(room_id_2, "se")
+ _search_test(room_id_2, "LoremIpsum")
+ _search_test(room_id_2, "lorem")
# Test case insensitive
_search_test(room_id_1, "SOMETHING")
_search_test(room_id_1, "THING")
- _search_test(room_id_2, "ELSE")
- _search_test(room_id_2, "SE")
+ _search_test(room_id_2, "LOREMIPSUM")
+ _search_test(room_id_2, "LOREM")
_search_test(None, "foo")
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
+ # Test that the whole room id returns the room
+ _search_test(room_id_1, room_id_1)
+ # Test that the search by room_id is case sensitive
+ _search_test(None, room_id_1.lower())
+ # Test search part of local part of room id do not match
+ _search_test(None, room_id_1[1:10])
+
+ # Test that whole room alias return no result, because of domain
+ _search_test(None, "#Room_Alias1:test")
+ # Test search local part of alias
+ _search_test(room_id_1, "alias1")
+
def test_search_term_non_ascii(self):
"""Test that searching for a room with non-ASCII characters works correctly"""
@@ -1123,6 +1098,32 @@ class RoomTestCase(unittest.HomeserverTestCase):
# the create_room already does the right thing, so no need to verify that we got
# the state events it created.
+ def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 84d092ca82..fcdc565814 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -35,7 +35,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config["public_baseurl"] = "aaaa"
config["form_secret"] = "123abc"
# Make some temporary templates...
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 66dcfc9f88..6e7c0f11df 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -891,7 +891,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"smtp_pass": None,
"notif_from": "test@example.com",
}
- config["public_baseurl"] = "aaa"
self.hs = self.setup_test_homeserver(config=config)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 95be369d4b..c427686376 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -14,6 +14,8 @@
# limitations under the License.
import json
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.constants import (
EventContentFields,
@@ -417,7 +419,30 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's hidden read receipt
self.assertEqual(self._get_read_receipt(), None)
- def test_read_receipt_with_empty_body(self):
+ @parameterized.expand(
+ [
+ # Old Element version, expected to send an empty body
+ (
+ "agent1",
+ "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
+ 200,
+ ),
+ # Old SchildiChat version, expected to send an empty body
+ ("agent2", "SchildiChat/1.2.1 (Android 10)", 200),
+ # Expected 400: Denies empty body starting at version 1.3+
+ ("agent3", "Element/1.3.6 (Android 10)", 400),
+ ("agent4", "SchildiChat/1.3.6 (Android 11)", 400),
+ # Contains "Riot": Receipts with empty bodies expected
+ ("agent5", "Element (Riot.im) (Android 9)", 200),
+ # Expected 400: Does not contain "Android"
+ ("agent6", "Element/1.2.1", 400),
+ # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage
+ ("agent7", "Element dbg/1.1.8-dev (Android)", 400),
+ ]
+ )
+ def test_read_receipt_with_empty_body(
+ self, name, user_agent: str, expected_status_code: int
+ ):
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -426,8 +451,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
"POST",
"/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]),
access_token=self.tok2,
+ custom_headers=[("User-Agent", user_agent)],
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, expected_status_code)
def _get_read_receipt(self):
"""Syncs and returns the read receipt."""
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 531f09c48b..4e71b6ec12 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -15,7 +15,7 @@ import threading
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from unittest.mock import Mock
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -25,6 +25,7 @@ from synapse.types import JsonDict, Requester, StateMap
from synapse.util.frozenutils import unfreeze
from tests import unittest
+from tests.test_utils import make_awaitable
if TYPE_CHECKING:
from synapse.module_api import ModuleApi
@@ -74,7 +75,7 @@ class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
return d
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -86,11 +87,29 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
load_legacy_third_party_event_rules(hs)
+ # We're not going to be properly signing events as our remote homeserver is fake,
+ # therefore disable event signature checks.
+ # Note that these checks are not relevant to this test case.
+
+ # Have this homeserver auto-approve all event signature checking.
+ async def approve_all_signature_checking(_, pdu):
+ return pdu
+
+ hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
+
+ # Have this homeserver skip event auth checks. This is necessary due to
+ # event auth checks ensuring that events were signed by the sender's homeserver.
+ async def _check_event_auth(origin, event, context, *args, **kwargs):
+ return context
+
+ hs.get_federation_event_handler()._check_event_auth = _check_event_auth
+
return hs
def prepare(self, reactor, clock, homeserver):
- # Create a user and room to play with during the tests
+ # Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
+ self.invitee = self.register_user("invitee", "hackme")
self.tok = self.login("kermit", "monkey")
# Some tests might prevent room creation on purpose.
@@ -197,19 +216,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- # check_event_allowed has some error handling, so it shouldn't 500 just because a
- # module did something bad.
- self.assertEqual(channel.code, 200, channel.result)
- event_id = channel.json_body["event_id"]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 200, channel.result)
- ev = channel.json_body
- self.assertEqual(ev["content"]["x"], "x")
+ # Because check_event_allowed raises an exception, it leads to a
+ # 500 Internal Server Error
+ self.assertEqual(channel.code, 500, channel.result)
def test_modify_event(self):
"""The module can return a modified version of the event"""
@@ -424,6 +433,74 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["i"], i)
+ def test_on_new_event(self):
+ """Test that the on_new_event callback is called on new events"""
+ on_new_event = Mock(make_awaitable(None))
+ self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
+ on_new_event
+ )
+
+ # Send a message event to the room and check that the callback is called.
+ self.helper.send(room_id=self.room_id, tok=self.tok)
+ self.assertEqual(on_new_event.call_count, 1)
+
+ # Check that the callback is also called on membership updates.
+ self.helper.invite(
+ room=self.room_id,
+ src=self.user_id,
+ targ=self.invitee,
+ tok=self.tok,
+ )
+
+ self.assertEqual(on_new_event.call_count, 2)
+
+ args, _ = on_new_event.call_args
+
+ self.assertEqual(args[0].membership, Membership.INVITE)
+ self.assertEqual(args[0].state_key, self.invitee)
+
+ # Check that the invitee's membership is correct in the state that's passed down
+ # to the callback.
+ self.assertEqual(
+ args[1][(EventTypes.Member, self.invitee)].membership,
+ Membership.INVITE,
+ )
+
+ # Send an event over federation and check that the callback is also called.
+ self._send_event_over_federation()
+ self.assertEqual(on_new_event.call_count, 3)
+
+ def _send_event_over_federation(self) -> None:
+ """Send a dummy event over federation and check that the request succeeds."""
+ body = {
+ "origin": self.hs.config.server.server_name,
+ "origin_server_ts": self.clock.time_msec(),
+ "pdus": [
+ {
+ "sender": self.user_id,
+ "type": EventTypes.Message,
+ "state_key": "",
+ "content": {"body": "hello world", "msgtype": "m.text"},
+ "room_id": self.room_id,
+ "depth": 0,
+ "origin_server_ts": self.clock.time_msec(),
+ "prev_events": [],
+ "auth_events": [],
+ "signatures": {},
+ "unsigned": {},
+ }
+ ],
+ }
+
+ channel = self.make_request(
+ method="PUT",
+ path="/_matrix/federation/v1/send/1",
+ content=body,
+ federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
+ )
+
+ self.assertEqual(channel.code, 200, channel.result)
+
def _update_power_levels(self, event_default: int = 0):
"""Updates the room's power levels.
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 71fa87ce92..ec0979850b 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -120,6 +120,35 @@ class RestHelper:
expect_code=expect_code,
)
+ def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None):
+ temp_id = self.auth_user_id
+ self.auth_user_id = user
+ path = "/knock/%s" % room
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ data = {}
+ if reason:
+ data["reason"] = reason
+
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ path,
+ json.dumps(data).encode("utf8"),
+ )
+
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
+ )
+
+ self.auth_user_id = temp_id
+
def leave(self, room=None, user=None, expect_code=200, tok=None):
self.change_membership(
room=room,
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 4ae00755c9..4cf1ed5ddf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -248,7 +248,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.media_id = "example.com/12345"
- def _req(self, content_disposition):
+ def _req(self, content_disposition, include_content_type=True):
channel = make_request(
self.reactor,
@@ -271,8 +271,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
- b"Content-Type": [self.test_image.content_type],
}
+
+ if include_content_type:
+ headers[b"Content-Type"] = [self.test_image.content_type]
+
if content_disposition:
headers[b"Content-Disposition"] = [content_disposition]
@@ -285,6 +288,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return channel
+ def test_handle_missing_content_type(self):
+ channel = self._req(
+ b"inline; filename=out" + self.test_image.extension,
+ include_content_type=False,
+ )
+ headers = channel.headers
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
+ )
+
def test_disposition_filename_ascii(self):
"""
If the filename is filename=<ascii> then Synapse will decode it as an
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index b2c0279ba0..118aa93a32 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,17 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.web.resource import Resource
-
-from synapse.rest.well_known import WellKnownResource
+from synapse.rest.well_known import well_known_resource
from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
def create_test_resource(self):
- # replace the JsonResource with a WellKnownResource
- return WellKnownResource(self.hs)
+ # replace the JsonResource with a Resource wrapping the WellKnownResource
+ res = Resource()
+ res.putChild(b".well-known", well_known_resource(self.hs))
+ return res
@unittest.override_config(
{
@@ -29,7 +31,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"default_identity_server": "https://testis",
}
)
- def test_well_known(self):
+ def test_client_well_known(self):
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
@@ -48,9 +50,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
"public_baseurl": None,
}
)
- def test_well_known_no_public_baseurl(self):
+ def test_client_well_known_no_public_baseurl(self):
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
self.assertEqual(channel.code, 404)
+
+ @unittest.override_config({"serve_server_wellknown": True})
+ def test_server_well_known(self):
+ channel = self.make_request(
+ "GET", "/.well-known/matrix/server", shorthand=False
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"m.server": "test:443"},
+ )
+
+ def test_server_well_known_disabled(self):
+ channel = self.make_request(
+ "GET", "/.well-known/matrix/server", shorthand=False
+ )
+ self.assertEqual(channel.code, 404)
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
new file mode 100644
index 0000000000..4b67bd15b7
--- /dev/null
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -0,0 +1,164 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client import devices
+
+from tests.unittest import HomeserverTestCase
+
+
+class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.user_id = self.register_user("foo", "pass")
+
+ def test_background_remove_deleted_devices_from_device_inbox(self):
+ """Test that the background task to delete old device_inboxes works properly."""
+
+ # create a valid device
+ self.get_success(
+ self.store.store_device(self.user_id, "cur_device", "display_name")
+ )
+
+ # Add device_inbox to devices
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "device_inbox",
+ {
+ "user_id": self.user_id,
+ "device_id": "cur_device",
+ "stream_id": 1,
+ "message_json": "{}",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "device_inbox",
+ {
+ "user_id": self.user_id,
+ "device_id": "old_device",
+ "stream_id": 2,
+ "message_json": "{}",
+ },
+ )
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "remove_deleted_devices_from_device_inbox",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ self.wait_for_background_updates()
+
+ # Make sure the background task deleted old device_inbox
+ res = self.get_success(
+ self.store.db_pool.simple_select_onecol(
+ table="device_inbox",
+ keyvalues={},
+ retcol="device_id",
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(1, len(res))
+ self.assertEqual(res[0], "cur_device")
+
+ def test_background_remove_hidden_devices_from_device_inbox(self):
+ """Test that the background task to delete hidden devices
+ from device_inboxes works properly."""
+
+ # create a valid device
+ self.get_success(
+ self.store.store_device(self.user_id, "cur_device", "display_name")
+ )
+
+ # create a hidden device
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "devices",
+ values={
+ "user_id": self.user_id,
+ "device_id": "hidden_device",
+ "display_name": "hidden_display_name",
+ "hidden": True,
+ },
+ )
+ )
+
+ # Add device_inbox to devices
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "device_inbox",
+ {
+ "user_id": self.user_id,
+ "device_id": "cur_device",
+ "stream_id": 1,
+ "message_json": "{}",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "device_inbox",
+ {
+ "user_id": self.user_id,
+ "device_id": "hidden_device",
+ "stream_id": 2,
+ "message_json": "{}",
+ },
+ )
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "remove_hidden_devices_from_device_inbox",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ self.wait_for_background_updates()
+
+ # Make sure the background task deleted hidden devices from device_inbox
+ res = self.get_success(
+ self.store.db_pool.simple_select_onecol(
+ table="device_inbox",
+ keyvalues={},
+ retcol="device_id",
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(1, len(res))
+ self.assertEqual(res[0], "cur_device")
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 0e4013ebea..c8ac67e35b 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -20,6 +20,7 @@ from parameterized import parameterized
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID
from tests import unittest
@@ -171,6 +172,27 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
if after_persisting:
# Trigger the storage loop
self.reactor.advance(10)
+ else:
+ # Check that the new IP and user agent has not been stored yet
+ db_result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={},
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
+ )
+ self.assertEqual(
+ db_result,
+ [
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": None,
+ "user_agent": None,
+ "last_seen": None,
+ },
+ ],
+ )
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
@@ -189,6 +211,104 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
},
)
+ def test_get_last_client_ip_by_device_combined_data(self):
+ """Test that `get_last_client_ip_by_device` combines persisted and unpersisted
+ data together correctly
+ """
+ self.reactor.advance(12345678)
+
+ user_id = "@user:id"
+ device_id_1 = "MY_DEVICE_1"
+ device_id_2 = "MY_DEVICE_2"
+
+ # Insert user IPs
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id_1,
+ "display name",
+ )
+ )
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id_2,
+ "display name",
+ )
+ )
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token_1", "ip_1", "user_agent_1", device_id_1
+ )
+ )
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token_2", "ip_2", "user_agent_2", device_id_2
+ )
+ )
+
+ # Trigger the storage loop and wait for the rate limiting period to be over
+ self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000)
+
+ # Update the user agent for the second device, without running the storage loop
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token_2", "ip_2", "user_agent_3", device_id_2
+ )
+ )
+
+ # Check that the new IP and user agent has not been stored yet
+ db_result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={},
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
+ )
+ self.assertCountEqual(
+ db_result,
+ [
+ {
+ "user_id": user_id,
+ "device_id": device_id_1,
+ "ip": "ip_1",
+ "user_agent": "user_agent_1",
+ "last_seen": 12345678000,
+ },
+ {
+ "user_id": user_id,
+ "device_id": device_id_2,
+ "ip": "ip_2",
+ "user_agent": "user_agent_2",
+ "last_seen": 12345678000,
+ },
+ ],
+ )
+
+ # Check that data from the database and memory are combined together correctly
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, None)
+ )
+ self.assertEqual(
+ result,
+ {
+ (user_id, device_id_1): {
+ "user_id": user_id,
+ "device_id": device_id_1,
+ "ip": "ip_1",
+ "user_agent": "user_agent_1",
+ "last_seen": 12345678000,
+ },
+ (user_id, device_id_2): {
+ "user_id": user_id,
+ "device_id": device_id_2,
+ "ip": "ip_2",
+ "user_agent": "user_agent_3",
+ "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
+ },
+ },
+ )
+
@parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool):
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
@@ -207,6 +327,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
if after_persisting:
# Trigger the storage loop
self.reactor.advance(10)
+ else:
+ # Check that the new IP and user agent has not been stored yet
+ db_result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={},
+ retcols=("access_token", "ip", "user_agent", "last_seen"),
+ ),
+ )
+ self.assertEqual(db_result, [])
self.assertEqual(
self.get_success(self.store.get_user_ip_and_agents(user)),
@@ -220,6 +350,82 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
+ def test_get_user_ip_and_agents_combined_data(self):
+ """Test that `get_user_ip_and_agents` combines persisted and unpersisted data
+ together correctly
+ """
+ self.reactor.advance(12345678)
+
+ user_id = "@user:id"
+ user = UserID.from_string(user_id)
+
+ # Insert user IPs
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip_1", "user_agent_1", "MY_DEVICE_1"
+ )
+ )
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip_2", "user_agent_2", "MY_DEVICE_2"
+ )
+ )
+
+ # Trigger the storage loop and wait for the rate limiting period to be over
+ self.reactor.advance(10 + LAST_SEEN_GRANULARITY / 1000)
+
+ # Update the user agent for the second device, without running the storage loop
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip_2", "user_agent_3", "MY_DEVICE_2"
+ )
+ )
+
+ # Check that the new IP and user agent has not been stored yet
+ db_result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={},
+ retcols=("access_token", "ip", "user_agent", "last_seen"),
+ ),
+ )
+ self.assertEqual(
+ db_result,
+ [
+ {
+ "access_token": "access_token",
+ "ip": "ip_1",
+ "user_agent": "user_agent_1",
+ "last_seen": 12345678000,
+ },
+ {
+ "access_token": "access_token",
+ "ip": "ip_2",
+ "user_agent": "user_agent_2",
+ "last_seen": 12345678000,
+ },
+ ],
+ )
+
+ # Check that data from the database and memory are combined together correctly
+ self.assertCountEqual(
+ self.get_success(self.store.get_user_ip_and_agents(user)),
+ [
+ {
+ "access_token": "access_token",
+ "ip": "ip_1",
+ "user_agent": "user_agent_1",
+ "last_seen": 12345678000,
+ },
+ {
+ "access_token": "access_token",
+ "ip": "ip_2",
+ "user_agent": "user_agent_3",
+ "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
+ },
+ ],
+ )
+
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
user_id = "@user:server"
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
new file mode 100644
index 0000000000..a6be9a1bb1
--- /dev/null
+++ b/tests/storage/test_rollback_worker.py
@@ -0,0 +1,69 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.app.generic_worker import GenericWorkerServer
+from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
+from synapse.storage.schema import SCHEMA_VERSION
+
+from tests.unittest import HomeserverTestCase
+
+
+class WorkerSchemaTests(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(
+ federation_http_client=None, homeserver_to_use=GenericWorkerServer
+ )
+ return hs
+
+ def default_config(self):
+ conf = super().default_config()
+
+ # Mark this as a worker app.
+ conf["worker_app"] = "yes"
+
+ return conf
+
+ def test_rolling_back(self):
+ """Test that workers can start if the DB is a newer schema version"""
+
+ db_pool = self.hs.get_datastore().db_pool
+ db_conn = LoggingDatabaseConnection(
+ db_pool._db_pool.connect(),
+ db_pool.engine,
+ "tests",
+ )
+
+ cur = db_conn.cursor()
+ cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION + 1,))
+
+ db_conn.commit()
+
+ prepare_database(db_conn, db_pool.engine, self.hs.config)
+
+ def test_not_upgraded(self):
+ """Test that workers don't start if the DB has an older schema version"""
+ db_pool = self.hs.get_datastore().db_pool
+ db_conn = LoggingDatabaseConnection(
+ db_pool._db_pool.connect(),
+ db_pool.engine,
+ "tests",
+ )
+
+ cur = db_conn.cursor()
+ cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION - 1,))
+
+ db_conn.commit()
+
+ with self.assertRaises(PrepareDatabaseException):
+ prepare_database(db_conn, db_pool.engine, self.hs.config)
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 54a88a8325..c613ce3f10 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -47,9 +47,7 @@ class DeferredCacheTestCase(TestCase):
self.assertTrue(set_d.called)
return r
- # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
- # maybe we should fix that?
- # get_d.addCallback(check1)
+ get_d.addCallback(check1)
# now fire off all the deferreds
origin_d.callback(99)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 39947a166b..ced3efd93f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -17,6 +17,7 @@ from typing import Set
from unittest import mock
from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -703,6 +704,48 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.mock.assert_called_once_with((40,), 2)
self.assertEqual(r, {10: "fish", 40: "gravy"})
+ def test_concurrent_lookups(self):
+ """All concurrent lookups should get the same result"""
+
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1):
+ pass
+
+ @descriptors.cachedList("fn", "args1")
+ def list_fn(self, args1) -> "Deferred[dict]":
+ return self.mock(args1)
+
+ obj = Cls()
+ deferred_result = Deferred()
+ obj.mock.return_value = deferred_result
+
+ # start off several concurrent lookups of the same key
+ d1 = obj.list_fn([10])
+ d2 = obj.list_fn([10])
+ d3 = obj.list_fn([10])
+
+ # the mock should have been called exactly once
+ obj.mock.assert_called_once_with((10,))
+ obj.mock.reset_mock()
+
+ # ... and none of the calls should yet be complete
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+ self.assertFalse(d3.called)
+
+ # complete the lookup. @cachedList functions need to complete with a map
+ # of input->result
+ deferred_result.callback({10: "peas"})
+
+ # ... which should give the right result to all the callers
+ self.assertEqual(self.successResultOf(d1), {10: "peas"})
+ self.assertEqual(self.successResultOf(d2), {10: "peas"})
+ self.assertEqual(self.successResultOf(d3), {10: "peas"})
+
@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_helpers.py
index 069f875962..ab89cab812 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_helpers.py
@@ -21,11 +21,78 @@ from synapse.logging.context import (
PreserveLoggingContext,
current_context,
)
-from synapse.util.async_helpers import timeout_deferred
+from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from tests.unittest import TestCase
+class ObservableDeferredTest(TestCase):
+ def test_succeed(self):
+ origin_d = Deferred()
+ observable = ObservableDeferred(origin_d)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+
+ # check the first observer is called first
+ def check_called_first(res):
+ self.assertFalse(observer2.called)
+ return res
+
+ observer1.addBoth(check_called_first)
+
+ # store the results
+ results = [None, None]
+
+ def check_val(res, idx):
+ results[idx] = res
+ return res
+
+ observer1.addCallback(check_val, 0)
+ observer2.addCallback(check_val, 1)
+
+ origin_d.callback(123)
+ self.assertEqual(results[0], 123, "observer 1 callback result")
+ self.assertEqual(results[1], 123, "observer 2 callback result")
+
+ def test_failure(self):
+ origin_d = Deferred()
+ observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+
+ # check the first observer is called first
+ def check_called_first(res):
+ self.assertFalse(observer2.called)
+ return res
+
+ observer1.addBoth(check_called_first)
+
+ # store the results
+ results = [None, None]
+
+ def check_val(res, idx):
+ results[idx] = res
+ return None
+
+ observer1.addErrback(check_val, 0)
+ observer2.addErrback(check_val, 1)
+
+ try:
+ raise Exception("gah!")
+ except Exception as e:
+ origin_d.errback(e)
+ self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
+ self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
+
+
class TimeoutDeferredTest(TestCase):
def setUp(self):
self.clock = Clock()
|