diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index b84c74fc0e..c90635e0a0 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -13,12 +13,14 @@
# limitations under the License.
import json
+from typing import List, Optional
from unittest.mock import Mock
import ijson.common
from synapse.api.room_versions import RoomVersions
from synapse.federation.transport.client import SendJoinParser
+from synapse.types import JsonDict
from synapse.util import ExceptionBundle
from tests.unittest import TestCase
@@ -71,33 +73,68 @@ class SendJoinParserTestCase(TestCase):
def test_partial_state(self) -> None:
"""Check that the partial_state flag is correctly parsed"""
- parser = SendJoinParser(RoomVersions.V1, False)
- response = {
- "org.matrix.msc3706.partial_state": True,
- }
- serialised_response = json.dumps(response).encode()
+ def parse(response: JsonDict) -> bool:
+ parser = SendJoinParser(RoomVersions.V1, False)
+ serialised_response = json.dumps(response).encode()
- # Send data to the parser
- parser.write(serialised_response)
+ # Send data to the parser
+ parser.write(serialised_response)
- # Retrieve and check the parsed SendJoinResponse
- parsed_response = parser.finish()
- self.assertTrue(parsed_response.partial_state)
+ # Retrieve and check the parsed SendJoinResponse
+ parsed_response = parser.finish()
+ return parsed_response.partial_state
- def test_servers_in_room(self) -> None:
- """Check that the servers_in_room field is correctly parsed"""
- parser = SendJoinParser(RoomVersions.V1, False)
- response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
+ self.assertTrue(parse({"members_omitted": True}))
+ self.assertTrue(parse({"org.matrix.msc3706.partial_state": True}))
- serialised_response = json.dumps(response).encode()
+ self.assertFalse(parse({"members_omitted": False}))
+ self.assertFalse(parse({"org.matrix.msc3706.partial_state": False}))
- # Send data to the parser
- parser.write(serialised_response)
+ # If there's a conflict, the stable field wins.
+ self.assertTrue(
+ parse({"members_omitted": True, "org.matrix.msc3706.partial_state": False})
+ )
+ self.assertFalse(
+ parse({"members_omitted": False, "org.matrix.msc3706.partial_state": True})
+ )
- # Retrieve and check the parsed SendJoinResponse
- parsed_response = parser.finish()
- self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
+ def test_servers_in_room(self) -> None:
+ """Check that the servers_in_room field is correctly parsed"""
+
+ def parse(response: JsonDict) -> Optional[List[str]]:
+ parser = SendJoinParser(RoomVersions.V1, False)
+ serialised_response = json.dumps(response).encode()
+
+ # Send data to the parser
+ parser.write(serialised_response)
+
+ # Retrieve and check the parsed SendJoinResponse
+ parsed_response = parser.finish()
+ return parsed_response.servers_in_room
+
+ self.assertEqual(
+ parse({"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}),
+ ["hs1", "hs2"],
+ )
+ self.assertEqual(parse({"servers_in_room": ["example.com"]}), ["example.com"])
+
+ # If both are provided, the stable identifier should win
+ self.assertEqual(
+ parse(
+ {
+ "org.matrix.msc3706.servers_in_room": ["old"],
+ "servers_in_room": ["new"],
+ }
+ ),
+ ["new"],
+ )
+
+ # And lastly, we should be able to tell if neither field was present.
+ self.assertEqual(
+ parse({}),
+ None,
+ )
def test_errors_closing_coroutines(self) -> None:
"""Check we close all coroutines, even if closing the first raises an Exception.
|