diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index a2dfa1ed05..4b53b6d40b 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -274,6 +274,39 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEquals(failure.value.code, 400)
self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(
+ user_id="@baldrick:matrix.org",
+ device_id="device",
+ token_owner="@admin:matrix.org",
+ )
+ )
+ self.store.insert_client_ip = simple_async_mock(None)
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_success(self.auth.get_user_by_req(request))
+ self.store.insert_client_ip.assert_called_once()
+
+ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
+ self.auth._track_puppeted_user_ips = True
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(
+ user_id="@baldrick:matrix.org",
+ device_id="device",
+ token_owner="@admin:matrix.org",
+ )
+ )
+ self.store.insert_client_ip = simple_async_mock(None)
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_success(self.auth.get_user_by_req(request))
+ self.assertEquals(self.store.insert_client_ip.call_count, 2)
+
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ddcf3ee348..734ed84d78 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Iterable
from unittest import mock
+from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
@@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from tests import unittest
+from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_user_id = "@test:other"
local_user_id = "@test:test"
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
@@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
},
)
+
+ @parameterized.expand(
+ [
+ # The remote homeserver's response indicates that this user has 0/1/2 devices.
+ ([],),
+ (["device_1"],),
+ (["device_1", "device_2"],),
+ ]
+ )
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ """Test that requests for all of a remote user's devices are cached.
+
+ We do this by asserting that only one call over federation was made, and that
+ the two queries to the local homeserver produce the same response.
+ """
+ local_user_id = "@test:test"
+ remote_user_id = "@test:other"
+ request_body = {"device_keys": {remote_user_id: []}}
+
+ response_devices = [
+ {
+ "device_id": device_id,
+ "keys": {
+ "algorithms": ["dummy"],
+ "device_id": device_id,
+ "keys": {f"dummy:{device_id}": "dummy"},
+ "signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
+ "unsigned": {},
+ "user_id": "@test:other",
+ },
+ }
+ for device_id in device_ids
+ ]
+
+ response_body = {
+ "devices": response_devices,
+ "user_id": remote_user_id,
+ "stream_id": 12345, # an integer, according to the spec
+ }
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
+ mock_get_rooms = mock.patch.object(
+ self.store,
+ "get_rooms_for_user",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(["some_room_id"]),
+ )
+ mock_request = mock.patch.object(
+ self.hs.get_federation_client(),
+ "query_user_devices",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(response_body),
+ )
+
+ with mock_get_rooms, mock_request as mocked_federation_request:
+ # Make the first query and sanity check it succeeds.
+ response_1 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_1["failures"], {})
+
+ # We should have made a federation request to do so.
+ mocked_federation_request.assert_called_once()
+
+ # Reset the mock so we can prove we don't make a second federation request.
+ mocked_federation_request.reset_mock()
+
+ # Repeat the query.
+ response_2 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_2["failures"], {})
+
+ # We should not have made a second federation request.
+ mocked_federation_request.assert_not_called()
+
+ # The two requests to the local homeserver should be identical.
+ self.assertEqual(response_1, response_2)
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 08e9730d4d..2add72b28a 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
import synapse
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login
+from synapse.rest.client import devices, login, logout
from synapse.types import JsonDict
from tests import unittest
@@ -155,6 +155,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
+ logout.register_servlets,
]
def setUp(self):
@@ -719,6 +720,31 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
+ def test_on_logged_out(self):
+ """Tests that the on_logged_out callback is called when the user logs out."""
+ self.register_user("rin", "password")
+ tok = self.login("rin", "password")
+
+ self.called = False
+
+ async def on_logged_out(user_id, device_id, access_token):
+ self.called = True
+
+ on_logged_out = Mock(side_effect=on_logged_out)
+ self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
+ on_logged_out
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/logout",
+ {},
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200)
+ on_logged_out.assert_called_once()
+ self.assertTrue(self.called)
+
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e5a6a6c747..51b22d2998 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -28,6 +28,7 @@ from synapse.api.constants import (
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
+from synapse.federation.transport.client import TransportLayerClient
from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -134,10 +135,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._add_child(self.space, self.room, self.token)
def _add_child(
- self, space_id: str, room_id: str, token: str, order: Optional[str] = None
+ self,
+ space_id: str,
+ room_id: str,
+ token: str,
+ order: Optional[str] = None,
+ via: Optional[List[str]] = None,
) -> None:
"""Add a child room to a space."""
- content: JsonDict = {"via": [self.hs.hostname]}
+ if via is None:
+ via = [self.hs.hostname]
+
+ content: JsonDict = {"via": via}
if order is not None:
content["order"] = order
self.helper.send_state(
@@ -253,6 +262,38 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_large_space(self):
+ """Test a space with a large number of rooms."""
+ rooms = [self.room]
+ # Make at least 51 rooms that are part of the space.
+ for _ in range(55):
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, room, self.token)
+ rooms.append(room)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+ # The spaces result should have the space and the first 50 rooms in it,
+ # along with the links from space -> room for those 50 rooms.
+ expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]]
+ self._assert_rooms(result, expected)
+
+ # The result should have the space and the rooms in it, along with the links
+ # from space -> room.
+ expected = [(self.space, rooms)] + [(room, []) for room in rooms]
+
+ # Make two requests to fully paginate the results.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ result2 = self.get_success(
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, from_token=result["next_batch"]
+ )
+ )
+ # Combine the results.
+ result["rooms"] += result2["rooms"]
+ self._assert_hierarchy(result, expected)
+
def test_visibility(self):
"""A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass")
@@ -1004,6 +1045,85 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_fed_caching(self):
+ """
+ Federation `/hierarchy` responses should be cached.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ fed_subspace = "#space:" + fed_hostname
+ fed_room = "#room:" + fed_hostname
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, fed_subspace, self.token, via=[fed_hostname])
+
+ federation_requests = 0
+
+ async def get_room_hierarchy(
+ _self: TransportLayerClient,
+ destination: str,
+ room_id: str,
+ suggested_only: bool,
+ ) -> JsonDict:
+ nonlocal federation_requests
+ federation_requests += 1
+
+ return {
+ "room": {
+ "room_id": fed_subspace,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ "children_state": [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": fed_subspace,
+ "state_key": fed_room,
+ "content": {"via": [fed_hostname]},
+ },
+ ],
+ },
+ "children": [
+ {
+ "room_id": fed_room,
+ "world_readable": True,
+ },
+ ],
+ "inaccessible_children": [],
+ }
+
+ expected = [
+ (self.space, [self.room, fed_subspace]),
+ (self.room, ()),
+ (fed_subspace, [fed_room]),
+ (fed_room, ()),
+ ]
+
+ with mock.patch(
+ "synapse.federation.transport.client.TransportLayerClient.get_room_hierarchy",
+ new=get_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # The previous federation response should be reused.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # Expire the response cache
+ self.reactor.advance(5 * 60 + 1)
+
+ # A new federation request should be made.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 2)
+ self._assert_hierarchy(result, expected)
+
class RoomSummaryTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 638186f173..07a760e91a 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
from synapse.api.room_versions import RoomVersions
-from synapse.handlers.sync import SyncConfig
+from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
@@ -27,6 +26,7 @@ from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
+from tests.test_utils import make_awaitable
class SyncTestCase(tests.unittest.HomeserverTestCase):
@@ -186,6 +186,97 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+ def test_ban_wins_race_with_join(self):
+ """Rooms shouldn't appear under "joined" if a join loses a race to a ban.
+
+ A complicated edge case. Imagine the following scenario:
+
+ * you attempt to join a room
+ * racing with that is a ban which comes in over federation, which ends up with
+ an earlier stream_ordering than the join.
+ * you get a sync response with a sync token which is _after_ the ban, but before
+ the join
+ * now your join lands; it is a valid event because its `prev_event`s predate the
+ ban, but will not make it into current_state_events (because bans win over
+ joins in state res, essentially).
+ * When we do a sync from the incremental sync, the only event in the timeline
+ is your join ... and yet you aren't joined.
+
+ The ban coming in over federation isn't crucial for this behaviour; the key
+ requirements are:
+ 1. the homeserver generates a join event with prev_events that precede the ban
+ (so that it passes the "are you banned" test)
+ 2. the join event has a stream_ordering after that of the ban.
+
+ We use monkeypatching to artificially trigger condition (1).
+ """
+ # A local user Alice creates a room.
+ owner = self.register_user("alice", "password")
+ owner_tok = self.login(owner, "password")
+ room_id = self.helper.create_room_as(owner, is_public=True, tok=owner_tok)
+
+ # Do a sync as Alice to get the latest event in the room.
+ alice_sync_result: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(owner), generate_sync_config(owner)
+ )
+ )
+ self.assertEqual(len(alice_sync_result.joined), 1)
+ self.assertEqual(alice_sync_result.joined[0].room_id, room_id)
+ last_room_creation_event_id = (
+ alice_sync_result.joined[0].timeline.events[-1].event_id
+ )
+
+ # Eve, a ne'er-do-well, registers.
+ eve = self.register_user("eve", "password")
+ eve_token = self.login(eve, "password")
+
+ # Alice preemptively bans Eve.
+ self.helper.ban(room_id, owner, eve, tok=owner_tok)
+
+ # Eve syncs.
+ eve_requester = create_requester(eve)
+ eve_sync_config = generate_sync_config(eve)
+ eve_sync_after_ban: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config)
+ )
+
+ # Sanity check this sync result. We shouldn't be joined to the room.
+ self.assertEqual(eve_sync_after_ban.joined, [])
+
+ # Eve tries to join the room. We monkey patch the internal logic which selects
+ # the prev_events used when creating the join event, such that the ban does not
+ # precede the join.
+ mocked_get_prev_events = patch.object(
+ self.hs.get_datastore(),
+ "get_prev_events_for_room",
+ new_callable=MagicMock,
+ return_value=make_awaitable([last_room_creation_event_id]),
+ )
+ with mocked_get_prev_events:
+ self.helper.join(room_id, eve, tok=eve_token)
+
+ # Eve makes a second, incremental sync.
+ eve_incremental_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=eve_sync_after_ban.next_batch,
+ )
+ )
+ # Eve should not see herself as joined to the room.
+ self.assertEqual(eve_incremental_sync_after_join.joined, [])
+
+ # If we did a third initial sync, we should _still_ see eve is not joined to the room.
+ eve_initial_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=None,
+ )
+ )
+ self.assertEqual(eve_initial_sync_after_join.joined, [])
+
_request_key = 0
diff --git a/tests/http/test_webclient.py b/tests/http/test_webclient.py
new file mode 100644
index 0000000000..ee5cf299f6
--- /dev/null
+++ b/tests/http/test_webclient.py
@@ -0,0 +1,108 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import Dict
+
+from twisted.web.resource import Resource
+
+from synapse.app.homeserver import SynapseHomeServer
+from synapse.config.server import HttpListenerConfig, HttpResourceConfig, ListenerConfig
+from synapse.http.site import SynapseSite
+
+from tests.server import make_request
+from tests.unittest import HomeserverTestCase, create_resource_tree, override_config
+
+
+class WebClientTests(HomeserverTestCase):
+ @override_config(
+ {
+ "web_client_location": "https://example.org",
+ }
+ )
+ def test_webclient_resolves_with_client_resource(self):
+ """
+ Tests that both client and webclient resources can be accessed simultaneously.
+
+ This is a regression test created in response to https://github.com/matrix-org/synapse/issues/11763.
+ """
+ for resource_name_order_list in [
+ ["webclient", "client"],
+ ["client", "webclient"],
+ ]:
+ # Create a dictionary from path regex -> resource
+ resource_dict: Dict[str, Resource] = {}
+
+ for resource_name in resource_name_order_list:
+ resource_dict.update(
+ SynapseHomeServer._configure_named_resource(self.hs, resource_name)
+ )
+
+ # Create a root resource which ties the above resources together into one
+ root_resource = Resource()
+ create_resource_tree(resource_dict, root_resource)
+
+ # Create a site configured with this resource to make HTTP requests against
+ listener_config = ListenerConfig(
+ port=8008,
+ bind_addresses=["127.0.0.1"],
+ type="http",
+ http_options=HttpListenerConfig(
+ resources=[HttpResourceConfig(names=resource_name_order_list)]
+ ),
+ )
+ test_site = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag=self.hs.config.server.server_name,
+ config=listener_config,
+ resource=root_resource,
+ server_version_string="1",
+ max_request_body_size=1234,
+ reactor=self.reactor,
+ )
+
+ # Attempt to make requests to endpoints on both the webclient and client resources
+ # on test_site.
+ self._request_client_and_webclient_resources(test_site)
+
+ def _request_client_and_webclient_resources(self, test_site: SynapseSite) -> None:
+ """Make a request to an endpoint on both the webclient and client-server resources
+ of the given SynapseSite.
+
+ Args:
+ test_site: The SynapseSite object to make requests against.
+ """
+
+ # Ensure that the *webclient* resource is behaving as expected (we get redirected to
+ # the configured web_client_location)
+ channel = make_request(
+ self.reactor,
+ site=test_site,
+ method="GET",
+ path="/_matrix/client",
+ )
+ # Check that we are being redirected to the webclient location URI.
+ self.assertEqual(channel.code, HTTPStatus.FOUND)
+ self.assertEqual(
+ channel.headers.getRawHeaders("Location"), ["https://example.org"]
+ )
+
+ # Ensure that a request to the *client* resource works.
+ channel = make_request(
+ self.reactor,
+ site=test_site,
+ method="GET",
+ path="/_matrix/client/v3/login",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertIn("flows", channel.json_body)
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 742f194257..b70350b6f1 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -314,15 +314,12 @@ class FederationTestCase(unittest.HomeserverTestCase):
retry_interval,
last_successful_stream_ordering,
) in dest:
- self.get_success(
- self.store.set_destination_retry_timings(
- destination, failure_ts, retry_last_ts, retry_interval
- )
- )
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(
- destination, last_successful_stream_ordering
- )
+ self._create_destination(
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ last_successful_stream_ordering,
)
# order by default (destination)
@@ -413,11 +410,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- def test_get_single_destination(self) -> None:
- """
- Get one specific destinations.
- """
- self._create_destinations(5)
+ def test_get_single_destination_with_retry_timings(self) -> None:
+ """Get one specific destination which has retry timings."""
+ self._create_destinations(1)
channel = self.make_request(
"GET",
@@ -432,6 +427,53 @@ class FederationTestCase(unittest.HomeserverTestCase):
# convert channel.json_body into a List
self._check_fields([channel.json_body])
+ def test_get_single_destination_no_retry_timings(self) -> None:
+ """Get one specific destination which has no retry timings."""
+ self._create_destination("sub0.example.com")
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/sub0.example.com",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual("sub0.example.com", channel.json_body["destination"])
+ self.assertEqual(0, channel.json_body["retry_last_ts"])
+ self.assertEqual(0, channel.json_body["retry_interval"])
+ self.assertIsNone(channel.json_body["failure_ts"])
+ self.assertIsNone(channel.json_body["last_successful_stream_ordering"])
+
+ def _create_destination(
+ self,
+ destination: str,
+ failure_ts: Optional[int] = None,
+ retry_last_ts: int = 0,
+ retry_interval: int = 0,
+ last_successful_stream_ordering: Optional[int] = None,
+ ) -> None:
+ """Create one specific destination
+
+ Args:
+ destination: the destination we have successfully sent to
+ failure_ts: when the server started failing (ms since epoch)
+ retry_last_ts: time of last retry attempt in unix epoch ms
+ retry_interval: how long until next retry in ms
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ self.get_success(
+ self.store.set_destination_retry_timings(
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ )
+ if last_successful_stream_ordering is not None:
+ self.get_success(
+ self.store.set_destination_last_successful_stream_ordering(
+ destination, last_successful_stream_ordering
+ )
+ )
+
def _create_destinations(self, number_destinations: int) -> None:
"""Create a number of destinations
@@ -440,10 +482,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"""
for i in range(0, number_destinations):
dest = f"sub{i}.example.com"
- self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50))
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(dest, 100)
- )
+ self._create_destination(dest, 50, 50, 50, 100)
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected destination attributes are present in content
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 81f3ac7f04..8513b1d2df 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -223,20 +223,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# Create all possible single character tokens
tokens = []
for c in string.ascii_letters + string.digits + "._~-":
- tokens.append(
- {
- "token": c,
- "uses_allowed": None,
- "pending": 0,
- "completed": 0,
- "expiry_time": None,
- }
- )
+ tokens.append((c, None, 0, 0, None))
self.get_success(
self.store.db_pool.simple_insert_many(
"registration_tokens",
- tokens,
- "create_all_registration_tokens",
+ keys=("token", "uses_allowed", "pending", "completed", "expiry_time"),
+ values=tokens,
+ desc="create_all_registration_tokens",
)
)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index d2c8781cd4..3495a0366a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1089,6 +1089,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
room_ids.append(room_id)
+ room_ids.sort()
+
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
channel = self.make_request(
@@ -1360,6 +1362,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ # Also create a list sorted by IDs for properties that are equal (and thus sorted by room_id)
+ sorted_by_room_id_asc = [room_id_1, room_id_2, room_id_3]
+ sorted_by_room_id_asc.sort()
+ sorted_by_room_id_desc = sorted_by_room_id_asc.copy()
+ sorted_by_room_id_desc.reverse()
+
# Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
self.helper.send_state(
room_id_1,
@@ -1405,41 +1413,42 @@ class RoomTestCase(unittest.HomeserverTestCase):
_order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
_order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+ # Note: joined_member counts are sorted in descending order when dir=f
_order_test("joined_members", [room_id_3, room_id_2, room_id_1])
_order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+ # Note: joined_local_member counts are sorted in descending order when dir=f
_order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
_order_test(
"joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
)
- _order_test("version", [room_id_1, room_id_2, room_id_3])
- _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+ # Note: versions are sorted in descending order when dir=f
+ _order_test("version", sorted_by_room_id_asc, reverse=True)
+ _order_test("version", sorted_by_room_id_desc)
- _order_test("creator", [room_id_1, room_id_2, room_id_3])
- _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("creator", sorted_by_room_id_asc)
+ _order_test("creator", sorted_by_room_id_desc, reverse=True)
- _order_test("encryption", [room_id_1, room_id_2, room_id_3])
- _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("encryption", sorted_by_room_id_asc)
+ _order_test("encryption", sorted_by_room_id_desc, reverse=True)
- _order_test("federatable", [room_id_1, room_id_2, room_id_3])
- _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("federatable", sorted_by_room_id_asc)
+ _order_test("federatable", sorted_by_room_id_desc, reverse=True)
- _order_test("public", [room_id_1, room_id_2, room_id_3])
- # Different sort order of SQlite and PostreSQL
- # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+ _order_test("public", sorted_by_room_id_asc)
+ _order_test("public", sorted_by_room_id_desc, reverse=True)
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("join_rules", sorted_by_room_id_asc)
+ _order_test("join_rules", sorted_by_room_id_desc, reverse=True)
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("guest_access", sorted_by_room_id_asc)
+ _order_test("guest_access", sorted_by_room_id_desc, reverse=True)
- _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
- _order_test(
- "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
- )
+ _order_test("history_visibility", sorted_by_room_id_asc)
+ _order_test("history_visibility", sorted_by_room_id_desc, reverse=True)
+ # Note: state_event counts are sorted in descending order when dir=f
_order_test("state_events", [room_id_3, room_id_2, room_id_1])
_order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e0b9fe8e91..9711405735 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1181,6 +1181,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.other_user, device_id=None, valid_until_ms=None
)
)
+
self.url_prefix = "/_synapse/admin/v2/users/%s"
self.url_other_user = self.url_prefix % self.other_user
@@ -1188,7 +1189,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
channel = self.make_request(
"GET",
@@ -1216,7 +1217,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_synapse/admin/v2/users/@unknown_person:test",
+ self.url_prefix % "@unknown_person:test",
access_token=self.admin_user_tok,
)
@@ -1337,7 +1338,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new admin user is created successfully.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user (server admin)
body = {
@@ -1386,7 +1387,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new regular user is created successfully.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -1478,7 +1479,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Register new user with admin API
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -1515,7 +1516,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Register new user with admin API
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -1545,7 +1546,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Check that a new regular user is created successfully and
got an email pusher.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -1588,7 +1589,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Check that a new regular user is created successfully and
got not an email pusher.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -2085,10 +2086,13 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
+
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
# the user is deactivated, the threepid will be deleted
# Get user
@@ -2101,11 +2105,13 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_change_name_deactivate_user_user_directory(self):
"""
@@ -2177,9 +2183,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"password_config": {"localdb_enabled": False}})
def test_reactivate_user_localdb_disabled(self):
"""
@@ -2209,9 +2217,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"password_config": {"enabled": False}})
def test_reactivate_user_password_disabled(self):
"""
@@ -2241,9 +2251,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
@@ -2328,7 +2340,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Ensure an account can't accidentally be deactivated by using a str value
for the deactivated body parameter
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -2392,18 +2404,20 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Deactivate the user.
channel = self.make_request(
"PUT",
- "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+ self.url_prefix % urllib.parse.quote(user_id),
access_token=self.admin_user_tok,
content={"deactivated": True},
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
self.assertIsNone(self.get_success(d))
self._is_erased(user_id, True)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
@@ -2416,13 +2430,15 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("admin", content)
self.assertIn("deactivated", content)
self.assertIn("shadow_banned", content)
- self.assertIn("password_hash", content)
self.assertIn("creation_ts", content)
self.assertIn("appservice_id", content)
self.assertIn("consent_server_notice_sent", content)
self.assertIn("consent_version", content)
self.assertIn("external_ids", content)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", content)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c026d526ef..c9b220e73d 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,6 +21,7 @@ from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
+from synapse.types import JsonDict
from tests import unittest
from tests.server import FakeChannel
@@ -93,11 +94,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body,
)
- def test_deny_membership(self):
- """Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
- self.assertEquals(400, channel.code, channel.json_body)
-
def test_deny_invalid_event(self):
"""Test that we deny relations on non-existant events"""
channel = self._send_relation(
@@ -459,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_bundled_aggregations(self):
- """Test that annotations, references, and threads get correctly bundled."""
+ """
+ Test that annotations, references, and threads get correctly bundled.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+
+ See test_edit for a similar test for edits.
+ """
# Setup by sending a variety of relations.
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -487,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"]
- def assert_bundle(actual):
+ def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
# Ensure the fields are as expected.
self.assertCountEqual(
- actual.keys(),
+ relations_dict.keys(),
(
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
@@ -508,17 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"type": "m.reaction", "key": "b", "count": 1},
]
},
- actual[RelationTypes.ANNOTATION],
+ relations_dict[RelationTypes.ANNOTATION],
)
self.assertEquals(
{"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- actual[RelationTypes.REFERENCE],
+ relations_dict[RelationTypes.REFERENCE],
)
self.assertEquals(
2,
- actual[RelationTypes.THREAD].get("count"),
+ relations_dict[RelationTypes.THREAD].get("count"),
+ )
+ self.assertTrue(
+ relations_dict[RelationTypes.THREAD].get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
@@ -535,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"type": "m.room.test",
"user_id": self.user_id,
},
- actual[RelationTypes.THREAD].get("latest_event"),
+ relations_dict[RelationTypes.THREAD].get("latest_event"),
)
- def _find_and_assert_event(events):
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- break
- else:
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
- assert_bundle(event["unsigned"].get("m.relations"))
-
# Request the event directly.
channel = self.make_request(
"GET",
@@ -556,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["unsigned"].get("m.relations"))
+ assert_bundle(channel.json_body)
# Request the room messages.
channel = self.make_request(
@@ -565,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- _find_and_assert_event(channel.json_body["chunk"])
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
# Request the room context.
channel = self.make_request(
@@ -574,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
+ assert_bundle(channel.json_body["event"])
# Request sync.
- # channel = self.make_request("GET", "/sync", access_token=self.user_token)
- # self.assertEquals(200, channel.code, channel.json_body)
- # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- # self.assertTrue(room_timeline["limited"])
- # _find_and_assert_event(room_timeline["events"])
-
- # Note that /relations is tested separately in test_aggregation_get_event_for_thread
- # since it needs different data configured.
+ channel = self.make_request("GET", "/sync", access_token=self.user_token)
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ self._find_event_in_chunk(room_timeline["events"])
def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included
@@ -779,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase):
edit_event_id = channel.json_body["event_id"]
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
-
self.assertEquals(channel.json_body["content"], new_body)
+ assert_bundle(channel.json_body)
- relations_dict = channel.json_body["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict)
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict)
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ # Request sync, but limit the timeline so it becomes limited (and includes
+ # bundled aggregations).
+ filter = urllib.parse.quote_plus(
+ '{"room": {"timeline": {"limit": 2}}}'.encode()
)
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who
@@ -1104,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
def _send_relation(
self,
relation_type: str,
@@ -1119,7 +1155,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
relation_type: One of `RelationTypes`
event_type: The type of the event to create
key: The aggregation key used for m.annotation relation type.
- content: The content of the created event.
+ content: The content of the created event. Will be modified to configure
+ the m.relates_to key based on the other provided parameters.
access_token: The access token used to send the relation, defaults
to `self.user_token`
parent_id: The event_id this relation relates to. If None, then self.parent_id
@@ -1130,17 +1167,21 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if not access_token:
access_token = self.user_token
- query = ""
- if key:
- query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
-
original_id = parent_id if parent_id else self.parent_id
+ if content is None:
+ content = {}
+ content["m.relates_to"] = {
+ "event_id": original_id,
+ "rel_type": relation_type,
+ }
+ if key is not None:
+ content["m.relates_to"]["key"] = key
+
channel = self.make_request(
"POST",
- "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
- % (self.room, original_id, relation_type, event_type, query),
- content or {},
+ f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
+ content,
access_token=access_token,
)
return channel
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index b58452195a..fe5b536d97 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -228,7 +228,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(event)
time_now = self.clock.time_msec()
- serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+ serialized = self.serializer.serialize_event(event, time_now)
return serialized
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 1af5e5cee5..8424383580 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -196,6 +196,16 @@ class RestHelper:
expect_code=expect_code,
)
+ def ban(self, room: str, src: str, targ: str, **kwargs: object):
+ """A convenience helper: `change_membership` with `membership` preset to "ban"."""
+ self.change_membership(
+ room=room,
+ src=src,
+ targ=targ,
+ membership=Membership.BAN,
+ **kwargs,
+ )
+
def change_membership(
self,
room: str,
diff --git a/tests/server.py b/tests/server.py
index ca2b7a5b97..a0cd14ea45 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -14,6 +14,8 @@
import hashlib
import json
import logging
+import os
+import os.path
import time
import uuid
import warnings
@@ -71,6 +73,7 @@ from tests.utils import (
POSTGRES_HOST,
POSTGRES_PASSWORD,
POSTGRES_USER,
+ SQLITE_PERSIST_DB,
USE_POSTGRES_FOR_TESTS,
MockClock,
default_config,
@@ -739,9 +742,23 @@ def setup_test_homeserver(
},
}
else:
+ if SQLITE_PERSIST_DB:
+ # The current working directory is in _trial_temp, so this gets created within that directory.
+ test_db_location = os.path.abspath("test.db")
+ logger.debug("Will persist db to %s", test_db_location)
+ # Ensure each test gets a clean database.
+ try:
+ os.remove(test_db_location)
+ except FileNotFoundError:
+ pass
+ else:
+ logger.debug("Removed existing DB at %s", test_db_location)
+ else:
+ test_db_location = ":memory:"
+
database_config = {
"name": "sqlite3",
- "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
+ "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
}
if "db_txn_limit" in kwargs:
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index bf78084869..2bc89512f8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -531,17 +531,25 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.get_success(
self.store.db_pool.simple_insert_many(
table="federation_inbound_events_staging",
+ keys=(
+ "origin",
+ "room_id",
+ "received_ts",
+ "event_id",
+ "event_json",
+ "internal_metadata",
+ ),
values=[
- {
- "origin": "some_origin",
- "room_id": room_id,
- "received_ts": 0,
- "event_id": f"$fake_event_id_{i + 1}",
- "event_json": json_encoder.encode(
+ (
+ "some_origin",
+ room_id,
+ 0,
+ f"$fake_event_id_{i + 1}",
+ json_encoder.encode(
{"prev_events": [prev_event_format(f"$fake_event_id_{i}")]}
),
- "internal_metadata": "{}",
- }
+ "{}",
+ )
for i in range(500)
],
desc="test_prune_inbound_federation_queue",
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3eef1c4c05..2b9804aba0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -17,7 +17,9 @@ from unittest.mock import Mock
from twisted.internet.defer import succeed
from synapse.api.errors import FederationError
+from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
+from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext
from synapse.types import UserID, create_requester
from synapse.util import Clock
@@ -276,3 +278,73 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
)
self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
+
+
+class StripUnsignedFromEventsTestCase(unittest.TestCase):
+ def test_strip_unauthorized_unsigned_values(self):
+ event1 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event1:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "content": {"membership": "join"},
+ "auth_events": [],
+ "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
+ }
+ filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
+ # Make sure unauthorized fields are stripped from unsigned
+ self.assertNotIn("more warez", filtered_event.unsigned)
+
+ def test_strip_event_maintains_allowed_fields(self):
+ event2 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event2:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "auth_events": [],
+ "content": {"membership": "join"},
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+
+ filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
+ self.assertIn("age", filtered_event2.unsigned)
+ self.assertEqual(14, filtered_event2.unsigned["age"])
+ self.assertNotIn("more warez", filtered_event2.unsigned)
+ # Invite_room_state is allowed in events of type m.room.member
+ self.assertIn("invite_room_state", filtered_event2.unsigned)
+ self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
+
+ def test_strip_event_removes_fields_based_on_event_type(self):
+ event3 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event3:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.power_levels",
+ "origin": "test.servx",
+ "content": {},
+ "auth_events": [],
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+ filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
+ self.assertIn("age", filtered_event3.unsigned)
+ # Invite_room_state field is only permitted in event type m.room.member
+ self.assertNotIn("invite_room_state", filtered_event3.unsigned)
+ self.assertNotIn("more warez", filtered_event3.unsigned)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 15ac2bfeba..f05a373aa0 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -19,7 +19,7 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Any, Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, TypeVar
from unittest.mock import Mock
import attr
@@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-def make_awaitable(result: Any) -> Awaitable[Any]:
+def make_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
diff --git a/tests/utils.py b/tests/utils.py
index 6d013e8518..c06fc320f3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -42,6 +42,10 @@ POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
+# When debugging a specific test, it's occasionally useful to write the
+# DB to disk and query it with the sqlite CLI.
+SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
+
# the dbname we will connect to in order to create the base database.
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
|