diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
new file mode 100644
index 0000000000..121aa88cfa
--- /dev/null
+++ b/tests/federation/transport/test_knocking.py
@@ -0,0 +1,302 @@
+# Copyright 2020 Matrix.org Federation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+from typing import Dict, List
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import builder
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.server import HomeServer
+from synapse.types import RoomAlias
+
+from tests.test_utils import event_injection
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+# An identifier to use while MSC2304 is not in a stable release of the spec
+KNOCK_UNSTABLE_IDENTIFIER = "xyz.amorgan.knock"
+
+
+class KnockingStrippedStateEventHelperMixin(TestCase):
+ def send_example_state_events_to_room(
+ self,
+ hs: "HomeServer",
+ room_id: str,
+ sender: str,
+ ) -> OrderedDict:
+ """Adds some state to a room. State events are those that should be sent to a knocking
+ user after they knock on the room, as well as some state that *shouldn't* be sent
+ to the knocking user.
+
+ Args:
+ hs: The homeserver of the sender.
+ room_id: The ID of the room to send state into.
+ sender: The ID of the user to send state as. Must be in the room.
+
+ Returns:
+ The OrderedDict of event types and content that a user is expected to see
+ after knocking on a room.
+ """
+ # To set a canonical alias, we'll need to point an alias at the room first.
+ canonical_alias = "#fancy_alias:test"
+ self.get_success(
+ self.store.create_room_alias_association(
+ RoomAlias.from_string(canonical_alias), room_id, ["test"]
+ )
+ )
+
+ # Send some state that we *don't* expect to be given to knocking users
+ self.get_success(
+ event_injection.inject_event(
+ hs,
+ room_version=RoomVersions.MSC2403.identifier,
+ room_id=room_id,
+ sender=sender,
+ type="com.example.secret",
+ state_key="",
+ content={"secret": "password"},
+ )
+ )
+
+ # We use an OrderedDict here to ensure that the knock membership appears last.
+ # Note that order only matters when sending stripped state to clients, not federated
+ # homeservers.
+ room_state = OrderedDict(
+ [
+ # We need to set the room's join rules to allow knocking
+ (
+ EventTypes.JoinRules,
+ {"content": {"join_rule": JoinRules.KNOCK}, "state_key": ""},
+ ),
+ # Below are state events that are to be stripped and sent to clients
+ (
+ EventTypes.Name,
+ {"content": {"name": "A cool room"}, "state_key": ""},
+ ),
+ (
+ EventTypes.RoomAvatar,
+ {
+ "content": {
+ "info": {
+ "h": 398,
+ "mimetype": "image/jpeg",
+ "size": 31037,
+ "w": 394,
+ },
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ },
+ "state_key": "",
+ },
+ ),
+ (
+ EventTypes.RoomEncryption,
+ {"content": {"algorithm": "m.megolm.v1.aes-sha2"}, "state_key": ""},
+ ),
+ (
+ EventTypes.CanonicalAlias,
+ {
+ "content": {"alias": canonical_alias, "alt_aliases": []},
+ "state_key": "",
+ },
+ ),
+ ]
+ )
+
+ for event_type, event_dict in room_state.items():
+ event_content = event_dict["content"]
+ state_key = event_dict["state_key"]
+
+ self.get_success(
+ event_injection.inject_event(
+ hs,
+ room_version=RoomVersions.MSC2403.identifier,
+ room_id=room_id,
+ sender=sender,
+ type=event_type,
+ state_key=state_key,
+ content=event_content,
+ )
+ )
+
+ # Finally, we expect to see the m.room.create event of the room as part of the
+ # stripped state. We don't need to inject this event though.
+ room_state[EventTypes.Create] = {
+ "content": {
+ "creator": sender,
+ "room_version": RoomVersions.MSC2403.identifier,
+ },
+ "state_key": "",
+ }
+
+ return room_state
+
+ def check_knock_room_state_against_room_state(
+ self,
+ knock_room_state: List[Dict],
+ expected_room_state: Dict,
+ ) -> None:
+ """Test a list of stripped room state events received over federation against a
+ dict of expected state events.
+
+ Args:
+ knock_room_state: The list of room state that was received over federation.
+ expected_room_state: A dict containing the room state we expect to see in
+ `knock_room_state`.
+ """
+ for event in knock_room_state:
+ event_type = event["type"]
+
+ # Check that this event type is one of those that we expected.
+ # Note: This will also check that no excess state was included
+ self.assertIn(event_type, expected_room_state)
+
+ # Check the state content matches
+ self.assertEquals(
+ expected_room_state[event_type]["content"], event["content"]
+ )
+
+ # Check the state key is correct
+ self.assertEqual(
+ expected_room_state[event_type]["state_key"], event["state_key"]
+ )
+
+ # Ensure the event has been stripped
+ self.assertNotIn("signatures", event)
+
+ # Pop once we've found and processed a state event
+ expected_room_state.pop(event_type)
+
+ # Check that all expected state events were accounted for
+ self.assertEqual(len(expected_room_state), 0)
+
+
+class FederationKnockingTestCase(
+ FederatingHomeserverTestCase, KnockingStrippedStateEventHelperMixin
+):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+
+ # We're not going to be properly signing events as our remote homeserver is fake,
+ # therefore disable event signature checks.
+ # Note that these checks are not relevant to this test case.
+
+ # Have this homeserver auto-approve all event signature checking.
+ async def approve_all_signature_checking(_, pdu):
+ return pdu
+
+ homeserver.get_federation_server()._check_sigs_and_hash = (
+ approve_all_signature_checking
+ )
+
+ # Have this homeserver skip event auth checks. This is necessary due to
+ # event auth checks ensuring that events were signed by the sender's homeserver.
+ async def _check_event_auth(
+ origin, event, context, state, auth_events, backfilled
+ ):
+ return context
+
+ homeserver.get_federation_handler()._check_event_auth = _check_event_auth
+
+ return super().prepare(reactor, clock, homeserver)
+
+ @override_config({"experimental_features": {"msc2403_enabled": True}})
+ def test_room_state_returned_when_knocking(self):
+ """
+ Tests that specific, stripped state events from a room are returned after
+ a remote homeserver successfully knocks on a local room.
+ """
+ user_id = self.register_user("u1", "you the one")
+ user_token = self.login("u1", "you the one")
+
+ fake_knocking_user_id = "@user:other.example.com"
+
+ # Create a room with a room version that includes knocking
+ room_id = self.helper.create_room_as(
+ "u1",
+ is_public=False,
+ room_version=RoomVersions.MSC2403.identifier,
+ tok=user_token,
+ )
+
+ # Update the join rules and add additional state to the room to check for later
+ expected_room_state = self.send_example_state_events_to_room(
+ self.hs, room_id, user_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/federation/unstable/%s/make_knock/%s/%s?ver=%s"
+ % (
+ KNOCK_UNSTABLE_IDENTIFIER,
+ room_id,
+ fake_knocking_user_id,
+ # Inform the remote that we support the room version of the room we're
+ # knocking on
+ RoomVersions.MSC2403.identifier,
+ ),
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Note: We don't expect the knock membership event to be sent over federation as
+ # part of the stripped room state, as the knocking homeserver already has that
+ # event. It is only done for clients during /sync
+
+ # Extract the generated knock event json
+ knock_event = channel.json_body["event"]
+
+ # Check that the event has things we expect in it
+ self.assertEquals(knock_event["room_id"], room_id)
+ self.assertEquals(knock_event["sender"], fake_knocking_user_id)
+ self.assertEquals(knock_event["state_key"], fake_knocking_user_id)
+ self.assertEquals(knock_event["type"], EventTypes.Member)
+ self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK)
+
+ # Turn the event json dict into a proper event.
+ # We won't sign it properly, but that's OK as we stub out event auth in `prepare`
+ signed_knock_event = builder.create_local_event_from_event_dict(
+ self.clock,
+ self.hs.hostname,
+ self.hs.signing_key,
+ room_version=RoomVersions.MSC2403,
+ event_dict=knock_event,
+ )
+
+ # Convert our proper event back to json dict format
+ signed_knock_event_json = signed_knock_event.get_pdu_json(
+ self.clock.time_msec()
+ )
+
+ # Send the signed knock event into the room
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/federation/unstable/%s/send_knock/%s/%s"
+ % (KNOCK_UNSTABLE_IDENTIFIER, room_id, signed_knock_event.event_id),
+ signed_knock_event_json,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check that we got the stripped room state in return
+ room_state_events = channel.json_body["knock_state_events"]
+
+ # Validate the stripped room state events
+ self.check_knock_room_state_against_room_state(
+ room_state_events, expected_room_state
+ )
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index dbcbdf159a..be5737e420 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -17,10 +17,14 @@ import json
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import read_marker, sync
+from synapse.rest.client.v2_alpha import knock, read_marker, sync
from tests import unittest
+from tests.federation.transport.test_knocking import (
+ KnockingStrippedStateEventHelperMixin,
+)
from tests.server import TimedOutException
+from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@@ -305,6 +309,93 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.make_request("GET", sync_url % (access_token, next_batch))
+class SyncKnockTestCase(
+ unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin
+):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ knock.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to create the room to knock on).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll knock on.
+ self.room_id = self.helper.create_room_as(
+ self.user_id,
+ is_public=False,
+ room_version="xyz.amorgan.knock",
+ tok=self.tok,
+ )
+
+ # Register the second user (used to knock on the room).
+ self.knocker = self.register_user("knocker", "monkey")
+ self.knocker_tok = self.login("knocker", "monkey")
+
+ # Perform an initial sync for the knocking user.
+ channel = self.make_request(
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
+
+ # Set up some room state to test with.
+ self.expected_room_state = self.send_example_state_events_to_room(
+ hs, self.room_id, self.user_id
+ )
+
+ @override_config({"experimental_features": {"msc2403_enabled": True}})
+ def test_knock_room_state(self):
+ """Tests that /sync returns state from a room after knocking on it."""
+ # Knock on a room
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/xyz.amorgan.knock/%s" % (self.room_id,),
+ b"{}",
+ self.knocker_tok,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # We expect to see the knock event in the stripped room state later
+ self.expected_room_state[EventTypes.Member] = {
+ "content": {"membership": "xyz.amorgan.knock", "displayname": "knocker"},
+ "state_key": "@knocker:test",
+ }
+
+ # Check that /sync includes stripped state from the room
+ channel = self.make_request(
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.knocker_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Extract the stripped room state events from /sync
+ knock_entry = channel.json_body["rooms"]["xyz.amorgan.knock"]
+ room_state_events = knock_entry[self.room_id]["knock_state"]["events"]
+
+ # Validate that the knock membership event came last
+ self.assertEqual(room_state_events[-1]["type"], EventTypes.Member)
+
+ # Validate the stripped room state events
+ self.check_knock_room_state_against_room_state(
+ room_state_events, self.expected_room_state
+ )
+
+
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -447,7 +538,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
)
self._check_unread_count(5)
- def _check_unread_count(self, expected_count: True):
+ def _check_unread_count(self, expected_count: int):
"""Syncs and compares the unread count with the expected value."""
channel = self.make_request(
|