diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index ba2a2bfd64..9bd6275e92 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, Namespace
from tests import unittest
+from tests.test_utils import simple_async_mock
def _regex(regex: str, exclusive: bool = True) -> Namespace:
@@ -39,13 +40,19 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.store = Mock()
+ self.store.get_aliases_for_room = simple_async_mock([])
+ self.store.get_users_in_room = simple_async_mock([])
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(
- (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
)
@defer.inlineCallbacks
@@ -53,7 +60,11 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
self.assertFalse(
- (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
)
@defer.inlineCallbacks
@@ -63,7 +74,11 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(
- (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
)
@defer.inlineCallbacks
@@ -73,7 +88,11 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(
- (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
)
@defer.inlineCallbacks
@@ -83,7 +102,11 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(
- (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
)
@defer.inlineCallbacks
@@ -91,10 +114,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = defer.succeed(
+ self.store.get_aliases_for_room = simple_async_mock(
["#irc_foobar:matrix.org", "#athing:matrix.org"]
)
- self.store.get_users_in_room.return_value = defer.succeed([])
+ self.store.get_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
@@ -144,10 +167,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = defer.succeed(
+ self.store.get_aliases_for_room = simple_async_mock(
["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
)
- self.store.get_users_in_room.return_value = defer.succeed([])
+ self.store.get_users_in_room = simple_async_mock([])
self.assertFalse(
(
yield defer.ensureDeferred(
@@ -163,10 +186,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.store.get_aliases_for_room.return_value = defer.succeed(
- ["#irc_barfoo:matrix.org"]
- )
- self.store.get_users_in_room.return_value = defer.succeed([])
+ self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
+ self.store.get_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
@@ -184,17 +205,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.content = {"membership": "invite"}
self.event.state_key = self.service.sender
self.assertTrue(
- (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
)
@defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
- self.store.get_users_in_room.return_value = defer.succeed(
+ self.store.get_users_in_room = simple_async_mock(
["@alice:here", "@irc_fo:here", "@bob:here"]
)
- self.store.get_aliases_for_room.return_value = defer.succeed([])
+ self.store.get_aliases_for_room = simple_async_mock([])
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 55f0899bae..8fb6687f89 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,23 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
from unittest.mock import Mock
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
from synapse.appservice.scheduler import (
+ ApplicationServiceScheduler,
_Recoverer,
- _ServiceQueuer,
_TransactionController,
)
from synapse.logging.context import make_deferred_yieldable
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import simple_async_mock
from ..utils import MockClock
+if TYPE_CHECKING:
+ from twisted.internet.testing import MemoryReactor
+
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self):
@@ -58,7 +64,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events, ephemeral=[] # txn made and saved
+ service=service,
+ events=events,
+ ephemeral=[],
+ to_device_messages=[], # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -79,7 +88,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events, ephemeral=[] # txn made and saved
+ service=service,
+ events=events,
+ ephemeral=[],
+ to_device_messages=[], # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@@ -102,7 +114,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events, ephemeral=[]
+ service=service, events=events, ephemeral=[], to_device_messages=[]
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@@ -189,38 +201,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.callback.assert_called_once_with(self.recoverer)
-class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
- def setUp(self):
+class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer):
+ self.scheduler = ApplicationServiceScheduler(hs)
self.txn_ctrl = Mock()
self.txn_ctrl.send = simple_async_mock()
- self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
+
+ # Replace instantiated _TransactionController instances with our Mock
+ self.scheduler.txn_ctrl = self.txn_ctrl
+ self.scheduler.queuer.txn_ctrl = self.txn_ctrl
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
- self.queuer.enqueue_event(service, event)
- self.txn_ctrl.send.assert_called_once_with(service, [event], [])
+ self.scheduler.enqueue_for_appservice(service, events=[event])
+ self.txn_ctrl.send.assert_called_once_with(service, [event], [], [])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(
- side_effect=lambda x, y, z: make_deferred_yieldable(d)
- )
+ self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
- self.queuer.enqueue_event(service, event)
+ self.scheduler.enqueue_for_appservice(service, events=[event])
# Send more events: expect send() to NOT be called multiple times.
- self.queuer.enqueue_event(service, event2)
- self.queuer.enqueue_event(service, event3)
- self.txn_ctrl.send.assert_called_with(service, [event], [])
+ # (call enqueue_for_appservice multiple times deliberately)
+ self.scheduler.enqueue_for_appservice(service, events=[event2])
+ self.scheduler.enqueue_for_appservice(service, events=[event3])
+ self.txn_ctrl.send.assert_called_with(service, [event], [], [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
- self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
+ self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@@ -238,23 +253,23 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer]
- def do_send(x, y, z):
+ def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
- self.queuer.enqueue_event(srv1, srv_1_event)
- self.queuer.enqueue_event(srv1, srv_1_event2)
- self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
- self.queuer.enqueue_event(srv2, srv_2_event)
- self.queuer.enqueue_event(srv2, srv_2_event2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
+ self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
+ self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
+ self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [])
+ self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
+ self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@@ -262,7 +277,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
srv_2_defer = defer.Deferred()
send_return_list = [srv_1_defer, srv_2_defer]
- def do_send(x, y, z):
+ def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
@@ -270,67 +285,65 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
service = Mock(id=4, name="service")
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
for event in event_list:
- self.queuer.enqueue_event(service, event)
+ self.scheduler.enqueue_for_appservice(service, [event], [])
# Expect the first event to be sent immediately.
- self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
+ self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], [])
srv_1_defer.callback(service)
# Then send the next 100 events
- self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
+ self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], [])
srv_2_defer.callback(service)
# Then the final 99 events
- self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
+ self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event")]
- self.queuer.enqueue_ephemeral(service, event_list)
- self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
- self.queuer.enqueue_ephemeral(service, event_list)
- self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(
- side_effect=lambda x, y, z: make_deferred_yieldable(d)
- )
+ self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4)
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
# Send an event and don't resolve it just yet.
- self.queuer.enqueue_ephemeral(service, event_list_1)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1)
# Send more events: expect send() to NOT be called multiple times.
- self.queuer.enqueue_ephemeral(service, event_list_2)
- self.queuer.enqueue_ephemeral(service, event_list_3)
- self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
+ self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
- self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
+ self.txn_ctrl.send.assert_called_with(
+ service, [], event_list_2 + event_list_3, []
+ )
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_send_large_txns_ephemeral(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(
- side_effect=lambda x, y, z: make_deferred_yieldable(d)
- )
+ self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk
- self.queuer.enqueue_ephemeral(service, event_list)
- self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, [])
d.callback(service)
- self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
+ self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 7b486aba4a..e40ef95874 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -47,7 +47,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Get the room complexity
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
@@ -59,7 +59,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 03e1e11f49..d084919ef7 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -16,12 +16,21 @@ import logging
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
+from tests.unittest import override_config
class FederationServerTests(unittest.FederatingHomeserverTestCase):
@@ -113,7 +122,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.inject_room_member(room_1, "@user:other.example.com", "join")
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(200, channel.code, channel.result)
@@ -145,13 +154,152 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(403, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ super().prepare(reactor, clock, hs)
+
+ # create the room
+ creator_user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ self._room_id = self.helper.create_room_as(
+ room_creator=creator_user_id, tok=tok
+ )
+
+ # a second member on the orgin HS
+ second_member_user_id = self.register_user("fozzie", "bear")
+ tok2 = self.login("fozzie", "bear")
+ self.helper.join(self._room_id, second_member_user_id, tok=tok2)
+
+ def _make_join(self, user_id) -> JsonDict:
+ channel = self.make_signed_federation_request(
+ "GET",
+ f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
+ f"?ver={DEFAULT_ROOM_VERSION}",
+ )
+ self.assertEquals(channel.code, 200, channel.json_body)
+ return channel.json_body
+
+ def test_send_join(self):
+ """happy-path test of send_join"""
+ joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
+ join_result = self._make_join(joining_user)
+
+ join_event_dict = join_result["event"]
+ add_hashes_and_signatures(
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ join_event_dict,
+ signature_name=self.OTHER_SERVER_NAME,
+ signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ )
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/send_join/{self._room_id}/x",
+ content=join_event_dict,
+ )
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # we should get complete room state back
+ returned_state = [
+ (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
+ ]
+ self.assertCountEqual(
+ returned_state,
+ [
+ ("m.room.create", ""),
+ ("m.room.power_levels", ""),
+ ("m.room.join_rules", ""),
+ ("m.room.history_visibility", ""),
+ ("m.room.member", "@kermit:test"),
+ ("m.room.member", "@fozzie:test"),
+ # nb: *not* the joining user
+ ],
+ )
+
+ # also check the auth chain
+ returned_auth_chain_events = [
+ (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
+ ]
+ self.assertCountEqual(
+ returned_auth_chain_events,
+ [
+ ("m.room.create", ""),
+ ("m.room.member", "@kermit:test"),
+ ("m.room.power_levels", ""),
+ ("m.room.join_rules", ""),
+ ],
+ )
+
+ # the room should show that the new user is a member
+ r = self.get_success(
+ self.hs.get_state_handler().get_current_state(self._room_id)
+ )
+ self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
+
+ @override_config({"experimental_features": {"msc3706_enabled": True}})
+ def test_send_join_partial_state(self):
+ """When MSC3706 support is enabled, /send_join should return partial state"""
+ joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
+ join_result = self._make_join(joining_user)
+
+ join_event_dict = join_result["event"]
+ add_hashes_and_signatures(
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ join_event_dict,
+ signature_name=self.OTHER_SERVER_NAME,
+ signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ )
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
+ content=join_event_dict,
+ )
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # expect a reduced room state
+ returned_state = [
+ (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
+ ]
+ self.assertCountEqual(
+ returned_state,
+ [
+ ("m.room.create", ""),
+ ("m.room.power_levels", ""),
+ ("m.room.join_rules", ""),
+ ("m.room.history_visibility", ""),
+ ],
+ )
+
+ # the auth chain should not include anything already in "state"
+ returned_auth_chain_events = [
+ (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
+ ]
+ self.assertCountEqual(
+ returned_auth_chain_events,
+ [
+ ("m.room.member", "@kermit:test"),
+ ],
+ )
+
+ # the room should show that the new user is a member
+ r = self.get_success(
+ self.hs.get_state_handler().get_current_state(self._room_id)
+ )
+ self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
+
+
def _create_acl_event(content):
return make_event_from_dict(
{
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index bfa156eebb..686f42ab48 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -245,7 +245,7 @@ class FederationKnockingTestCase(
self.hs, room_id, user_id
)
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET",
"/_matrix/federation/v1/make_knock/%s/%s?ver=%s"
% (
@@ -288,7 +288,7 @@ class FederationKnockingTestCase(
)
# Send the signed knock event into the room
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"PUT",
"/_matrix/federation/v1/send_knock/%s/%s"
% (room_id, signed_knock_event.event_id),
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 84fa72b9ff..eb62addda8 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -22,10 +22,9 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"""Test that unauthenticated requests to the public rooms directory 403 when
allow_public_rooms_over_federation is False.
"""
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET",
"/_matrix/federation/v1/publicRooms",
- federation_auth_origin=b"example.com",
)
self.assertEquals(403, channel.code)
@@ -34,9 +33,8 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"""Test that unauthenticated requests to the public rooms directory 200 when
allow_public_rooms_over_federation is True.
"""
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET",
"/_matrix/federation/v1/publicRooms",
- federation_auth_origin=b"example.com",
)
self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d6f14e2dba..fe57ff2671 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -1,4 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,18 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from twisted.internet import defer
+import synapse.rest.admin
+import synapse.storage
+from synapse.appservice import ApplicationService
from synapse.handlers.appservice import ApplicationServicesHandler
+from synapse.rest.client import login, receipts, room, sendtodevice
from synapse.types import RoomStreamToken
+from synapse.util.stringutils import random_string
-from tests.test_utils import make_awaitable
+from tests import unittest
+from tests.test_utils import make_awaitable, simple_async_mock
from tests.utils import MockClock
-from .. import unittest
-
class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler."""
@@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore.return_value = self.mock_store
self.mock_store.get_received_ts.return_value = make_awaitable(0)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
+ self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
+ None
+ )
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
- self.mock_scheduler.submit_event_for_as.assert_called_once_with(
- interested_service, event
+ self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
+ interested_service, events=[event]
)
def test_query_user_exists_unknown_user(self):
@@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"""
interested_service = self._mkservice(is_interested=True)
services = [interested_service]
-
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
579
@@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
- self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
- interested_service, [event]
+ self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
+ interested_service, ephemeral=[event]
)
- self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with(
+ self.mock_store.set_appservice_stream_type_pos.assert_called_once_with(
interested_service,
"read_receipt",
580,
@@ -305,7 +312,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
- self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
+ # This method will be called, but with an empty list of events
+ self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
+ interested_service, ephemeral=[]
+ )
def _mkservice(self, is_interested, protocols=None):
service = Mock()
@@ -321,3 +331,252 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
+
+
+class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
+ """
+ Tests that the ApplicationServicesHandler sends events to application
+ services correctly.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ sendtodevice.register_servlets,
+ receipts.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
+ # we can track any outgoing ephemeral events
+ self.send_mock = simple_async_mock()
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
+
+ # Mock out application services, and allow defining our own in tests
+ self._services: List[ApplicationService] = []
+ self.hs.get_datastore().get_app_services = Mock(return_value=self._services)
+
+ # A user on the homeserver.
+ self.local_user_device_id = "local_device"
+ self.local_user = self.register_user("local_user", "password")
+ self.local_user_token = self.login(
+ "local_user", "password", self.local_user_device_id
+ )
+
+ # A user on the homeserver which lies within an appservice's exclusive user namespace.
+ self.exclusive_as_user_device_id = "exclusive_as_device"
+ self.exclusive_as_user = self.register_user("exclusive_as_user", "password")
+ self.exclusive_as_user_token = self.login(
+ "exclusive_as_user", "password", self.exclusive_as_user_device_id
+ )
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
+ )
+ def test_application_services_receive_local_to_device(self):
+ """
+ Test that when a user sends a to-device message to another user
+ that is an application service's user namespace, the
+ application service will receive it.
+ """
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ }
+ ],
+ },
+ )
+
+ # Have local_user send a to-device message to exclusive_as_user
+ message_content = {"some_key": "some really interesting value"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/3",
+ content={
+ "messages": {
+ self.exclusive_as_user: {
+ self.exclusive_as_user_device_id: message_content
+ }
+ }
+ },
+ access_token=self.local_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ # Have exclusive_as_user send a to-device message to local_user
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/4",
+ content={
+ "messages": {
+ self.local_user: {self.local_user_device_id: message_content}
+ }
+ },
+ access_token=self.exclusive_as_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ # Check if our application service - that is interested in exclusive_as_user - received
+ # the to-device message as part of an AS transaction.
+ # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS.
+ #
+ # The uninterested application service should not have been notified at all.
+ self.send_mock.assert_called_once()
+ service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0]
+
+ # Assert that this was the same to-device message that local_user sent
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(to_device_messages[0]["type"], "m.room_key_request")
+ self.assertEqual(to_device_messages[0]["sender"], self.local_user)
+
+ # Additional fields 'to_user_id' and 'to_device_id' specifically for
+ # to-device messages via the AS API
+ self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
+ self.assertEqual(
+ to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id
+ )
+ self.assertEqual(to_device_messages[0]["content"], message_content)
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
+ )
+ def test_application_services_receive_bursts_of_to_device(self):
+ """
+ Test that when a user sends >100 to-device messages at once, any
+ interested AS's will receive them in separate transactions.
+
+ Also tests that uninterested application services do not receive messages.
+ """
+ # Register two application services with exclusive interest in a user
+ interested_appservices = []
+ for _ in range(2):
+ appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ }
+ ],
+ },
+ )
+ interested_appservices.append(appservice)
+
+ # ...and an application service which does not have any user interest.
+ self._register_application_service()
+
+ to_device_message_content = {
+ "some key": "some interesting value",
+ }
+
+ # We need to send a large burst of to-device messages. We also would like to
+ # include them all in the same application service transaction so that we can
+ # test large transactions.
+ #
+ # To do this, we can send a single to-device message to many user devices at
+ # once.
+ #
+ # We insert number_of_messages - 1 messages into the database directly. We'll then
+ # send a final to-device message to the real device, which will also kick off
+ # an AS transaction (as just inserting messages into the DB won't).
+ number_of_messages = 150
+ fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
+ messages = {
+ self.exclusive_as_user: {
+ device_id: to_device_message_content for device_id in fake_device_ids
+ }
+ }
+
+ # Create a fake device per message. We can't send to-device messages to
+ # a device that doesn't exist.
+ self.get_success(
+ self.hs.get_datastore().db_pool.simple_insert_many(
+ desc="test_application_services_receive_burst_of_to_device",
+ table="devices",
+ keys=("user_id", "device_id"),
+ values=[
+ (
+ self.exclusive_as_user,
+ device_id,
+ )
+ for device_id in fake_device_ids
+ ],
+ )
+ )
+
+ # Seed the device_inbox table with our fake messages
+ self.get_success(
+ self.hs.get_datastore().add_messages_to_device_inbox(messages, {})
+ )
+
+ # Now have local_user send a final to-device message to exclusive_as_user. All unsent
+ # to-device messages should be sent to any application services
+ # interested in exclusive_as_user.
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/4",
+ content={
+ "messages": {
+ self.exclusive_as_user: {
+ self.exclusive_as_user_device_id: to_device_message_content
+ }
+ }
+ },
+ access_token=self.local_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ self.send_mock.assert_called()
+
+ # Count the total number of to-device messages that were sent out per-service.
+ # Ensure that we only sent to-device messages to interested services, and that
+ # each interested service received the full count of to-device messages.
+ service_id_to_message_count: Dict[str, int] = {}
+
+ for call in self.send_mock.call_args_list:
+ service, _events, _ephemeral, to_device_messages = call[0]
+
+ # Check that this was made to an interested service
+ self.assertIn(service, interested_appservices)
+
+ # Add to the count of messages for this application service
+ service_id_to_message_count.setdefault(service.id, 0)
+ service_id_to_message_count[service.id] += len(to_device_messages)
+
+ # Assert that each interested service received the full count of messages
+ for count in service_id_to_message_count.values():
+ self.assertEqual(count, number_of_messages)
+
+ def _register_application_service(
+ self,
+ namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
+ ) -> ApplicationService:
+ """
+ Register a new application service, with the given namespaces of interest.
+
+ Args:
+ namespaces: A dictionary containing any user, room or alias namespaces that
+ the application service is interested in.
+
+ Returns:
+ The registered application service.
+ """
+ # Create an application service
+ appservice = ApplicationService(
+ token=random_string(10),
+ hostname="example.com",
+ id=random_string(10),
+ sender="@as:example.com",
+ rate_limited=False,
+ namespaces=namespaces,
+ supports_ephemeral=True,
+ )
+
+ # Register the application service
+ self._services.append(appservice)
+
+ return appservice
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 3da597768c..01096a1581 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -217,3 +217,109 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self.assertEqual(
self.get_success(self._store.ignored_by("@sheltie:test")), set()
)
+
+ def _rerun_retroactive_account_data_deletion_update(self) -> None:
+ # Reset the 'all done' flag
+ self._store.db_pool.updates._all_done = False
+
+ self.get_success(
+ self._store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "delete_account_data_for_deactivated_users",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ self.wait_for_background_updates()
+
+ def test_account_data_deleted_retroactively_by_background_update_if_deactivated(
+ self,
+ ) -> None:
+ """
+ Tests that a user, who deactivated their account before account data was
+ deleted automatically upon deactivation, has their account data retroactively
+ scrubbed by the background update.
+ """
+
+ # Request the deactivation of our account
+ self._deactivate_my_account()
+
+ # Add some account data
+ # (we do this after the deactivation so that the act of deactivating doesn't
+ # clear it out. This emulates a user that was deactivated before this was cleared
+ # upon deactivation.)
+ self.get_success(
+ self._store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ {"@someone:remote": ["!somewhere:remote"]},
+ )
+ )
+
+ # Check that the account data is there.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
+
+ # Re-run the retroactive deletion update
+ self._rerun_retroactive_account_data_deletion_update()
+
+ # Check that the account data was cleared.
+ self.assertIsNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
+
+ def test_account_data_preserved_by_background_update_if_not_deactivated(
+ self,
+ ) -> None:
+ """
+ Tests that the background update does not scrub account data for users that have
+ not been deactivated.
+ """
+
+ # Add some account data
+ # (we do this after the deactivation so that the act of deactivating doesn't
+ # clear it out. This emulates a user that was deactivated before this was cleared
+ # upon deactivation.)
+ self.get_success(
+ self._store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ {"@someone:remote": ["!somewhere:remote"]},
+ )
+ )
+
+ # Check that the account data is there.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
+
+ # Re-run the retroactive deletion update
+ self._rerun_retroactive_account_data_deletion_update()
+
+ # Check that the account data was NOT cleared.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cfe3de5266..a552d8182e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -155,7 +155,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = "Synapse Test"
+ self.http_client.user_agent = b"Synapse Test"
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
@@ -438,12 +438,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
- user_agent = "Browser"
ip_address = "10.0.0.1"
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
- request = _build_callback_request(
- code, state, session, user_agent=user_agent, ip_address=ip_address
- )
+ request = _build_callback_request(code, state, session, ip_address=ip_address)
self.get_success(self.handler.handle_oidc_callback(request))
@@ -1274,7 +1271,6 @@ def _build_callback_request(
code: str,
state: str,
session: str,
- user_agent: str = "Browser",
ip_address: str = "10.0.0.1",
):
"""Builds a fake SynapseRequest to mock the browser callback
@@ -1289,7 +1285,6 @@ def _build_callback_request(
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
- user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 94809cb8be..4740dd0a65 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,13 +21,15 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
# (possibly experimental) login flows we expect to appear in the list after the normal
@@ -158,6 +160,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
devices.register_servlets,
logout.register_servlets,
register.register_servlets,
+ account.register_servlets,
]
def setUp(self):
@@ -803,6 +806,77 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
+ # Set some email configuration so the test doesn't fail because of its absence.
+ @override_config({"email": {"notif_from": "noreply@test"}})
+ def test_3pid_allowed(self):
+ """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
+ to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
+ the 3PID. Also checks that the module is passed a boolean indicating whether the
+ user to bind this 3PID to is currently registering.
+ """
+ self._test_3pid_allowed("rin", False)
+ self._test_3pid_allowed("kitay", True)
+
+ def _test_3pid_allowed(self, username: str, registration: bool):
+ """Tests that the "is_3pid_allowed" module callback is called correctly, using
+ either /register or /account URLs depending on the arguments.
+
+ Args:
+ username: The username to use for the test.
+ registration: Whether to test with registration URLs.
+ """
+ self.hs.get_identity_handler().send_threepid_validation = Mock(
+ return_value=make_awaitable(0),
+ )
+
+ m = Mock(return_value=make_awaitable(False))
+ self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+ self.register_user(username, "password")
+ tok = self.login(username, "password")
+
+ if registration:
+ url = "/register/email/requestToken"
+ else:
+ url = "/account/3pid/email/requestToken"
+
+ channel = self.make_request(
+ "POST",
+ url,
+ {
+ "client_secret": "foo",
+ "email": "foo@test.com",
+ "send_attempt": 0,
+ },
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.THREEPID_DENIED,
+ channel.json_body,
+ )
+
+ m.assert_called_once_with("email", "foo@test.com", registration)
+
+ m = Mock(return_value=make_awaitable(True))
+ self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+ channel = self.make_request(
+ "POST",
+ url,
+ {
+ "client_secret": "foo",
+ "email": "bar@test.com",
+ "send_attempt": 0,
+ },
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertIn("sid", channel.json_body)
+
+ m.assert_called_once_with("email", "bar@test.com", registration)
+
def _setup_get_username_for_registration(self) -> Mock:
"""Registers a get_username_for_registration callback that appends "-foo" to the
username the client is trying to register.
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 70c621b825..482c90ef68 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -169,7 +169,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Register an AS user.
user = self.register_user("user", "pass")
token = self.login(user, "pass")
- as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+ as_user, _ = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
# Join the AS user to rooms owned by the normal user.
public, private = self._create_rooms_and_inject_memberships(
@@ -388,7 +390,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_local_profile_change_with_appservice_user(self) -> None:
# create user
- as_user_id = self.register_appservice_user(
+ as_user_id, _ = self.register_appservice_user(
"as_user_alice", self.appservice.token
)
diff --git a/tests/http/test_webclient.py b/tests/http/test_webclient.py
deleted file mode 100644
index ee5cf299f6..0000000000
--- a/tests/http/test_webclient.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from http import HTTPStatus
-from typing import Dict
-
-from twisted.web.resource import Resource
-
-from synapse.app.homeserver import SynapseHomeServer
-from synapse.config.server import HttpListenerConfig, HttpResourceConfig, ListenerConfig
-from synapse.http.site import SynapseSite
-
-from tests.server import make_request
-from tests.unittest import HomeserverTestCase, create_resource_tree, override_config
-
-
-class WebClientTests(HomeserverTestCase):
- @override_config(
- {
- "web_client_location": "https://example.org",
- }
- )
- def test_webclient_resolves_with_client_resource(self):
- """
- Tests that both client and webclient resources can be accessed simultaneously.
-
- This is a regression test created in response to https://github.com/matrix-org/synapse/issues/11763.
- """
- for resource_name_order_list in [
- ["webclient", "client"],
- ["client", "webclient"],
- ]:
- # Create a dictionary from path regex -> resource
- resource_dict: Dict[str, Resource] = {}
-
- for resource_name in resource_name_order_list:
- resource_dict.update(
- SynapseHomeServer._configure_named_resource(self.hs, resource_name)
- )
-
- # Create a root resource which ties the above resources together into one
- root_resource = Resource()
- create_resource_tree(resource_dict, root_resource)
-
- # Create a site configured with this resource to make HTTP requests against
- listener_config = ListenerConfig(
- port=8008,
- bind_addresses=["127.0.0.1"],
- type="http",
- http_options=HttpListenerConfig(
- resources=[HttpResourceConfig(names=resource_name_order_list)]
- ),
- )
- test_site = SynapseSite(
- logger_name="synapse.access.http.fake",
- site_tag=self.hs.config.server.server_name,
- config=listener_config,
- resource=root_resource,
- server_version_string="1",
- max_request_body_size=1234,
- reactor=self.reactor,
- )
-
- # Attempt to make requests to endpoints on both the webclient and client resources
- # on test_site.
- self._request_client_and_webclient_resources(test_site)
-
- def _request_client_and_webclient_resources(self, test_site: SynapseSite) -> None:
- """Make a request to an endpoint on both the webclient and client-server resources
- of the given SynapseSite.
-
- Args:
- test_site: The SynapseSite object to make requests against.
- """
-
- # Ensure that the *webclient* resource is behaving as expected (we get redirected to
- # the configured web_client_location)
- channel = make_request(
- self.reactor,
- site=test_site,
- method="GET",
- path="/_matrix/client",
- )
- # Check that we are being redirected to the webclient location URI.
- self.assertEqual(channel.code, HTTPStatus.FOUND)
- self.assertEqual(
- channel.headers.getRawHeaders("Location"), ["https://example.org"]
- )
-
- # Ensure that a request to the *client* resource works.
- channel = make_request(
- self.reactor,
- site=test_site,
- method="GET",
- path="/_matrix/client/v3/login",
- )
- self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertIn("flows", channel.json_body)
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
new file mode 100644
index 0000000000..e430941d27
--- /dev/null
+++ b/tests/logging/test_opentracing.py
@@ -0,0 +1,184 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.logging.context import (
+ LoggingContext,
+ make_deferred_yieldable,
+ run_in_background,
+)
+from synapse.logging.opentracing import (
+ start_active_span,
+ start_active_span_follows_from,
+)
+from synapse.util import Clock
+
+try:
+ from synapse.logging.scopecontextmanager import LogContextScopeManager
+except ImportError:
+ LogContextScopeManager = None # type: ignore
+
+try:
+ import jaeger_client
+except ImportError:
+ jaeger_client = None # type: ignore
+
+from tests.unittest import TestCase
+
+
+class LogContextScopeManagerTestCase(TestCase):
+ if LogContextScopeManager is None:
+ skip = "Requires opentracing" # type: ignore[unreachable]
+ if jaeger_client is None:
+ skip = "Requires jaeger_client" # type: ignore[unreachable]
+
+ def setUp(self) -> None:
+ # since this is a unit test, we don't really want to mess around with the
+ # global variables that power opentracing. We create our own tracer instance
+ # and test with it.
+
+ scope_manager = LogContextScopeManager({})
+ config = jaeger_client.config.Config(
+ config={}, service_name="test", scope_manager=scope_manager
+ )
+
+ self._reporter = jaeger_client.reporter.InMemoryReporter()
+
+ self._tracer = config.create_tracer(
+ sampler=jaeger_client.ConstSampler(True),
+ reporter=self._reporter,
+ )
+
+ def test_start_active_span(self) -> None:
+ # the scope manager assumes a logging context of some sort.
+ with LoggingContext("root context"):
+ self.assertIsNone(self._tracer.active_span)
+
+ # start_active_span should start and activate a span.
+ scope = start_active_span("span", tracer=self._tracer)
+ span = scope.span
+ self.assertEqual(self._tracer.active_span, span)
+ self.assertIsNotNone(span.start_time)
+
+ # entering the context doesn't actually do a whole lot.
+ with scope as ctx:
+ self.assertIs(ctx, scope)
+ self.assertEqual(self._tracer.active_span, span)
+
+ # ... but leaving it unsets the active span, and finishes the span.
+ self.assertIsNone(self._tracer.active_span)
+ self.assertIsNotNone(span.end_time)
+
+ # the span should have been reported
+ self.assertEqual(self._reporter.get_spans(), [span])
+
+ def test_nested_spans(self) -> None:
+ """Starting two spans off inside each other should work"""
+
+ with LoggingContext("root context"):
+ with start_active_span("root span", tracer=self._tracer) as root_scope:
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+
+ scope1 = start_active_span(
+ "child1",
+ tracer=self._tracer,
+ )
+ self.assertEqual(
+ self._tracer.active_span, scope1.span, "child1 was not activated"
+ )
+ self.assertEqual(
+ scope1.span.context.parent_id, root_scope.span.context.span_id
+ )
+
+ scope2 = start_active_span_follows_from(
+ "child2",
+ contexts=(scope1,),
+ tracer=self._tracer,
+ )
+ self.assertEqual(self._tracer.active_span, scope2.span)
+ self.assertEqual(
+ scope2.span.context.parent_id, scope1.span.context.span_id
+ )
+
+ with scope1, scope2:
+ pass
+
+ # the root scope should be restored
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+ self.assertIsNotNone(scope2.span.end_time)
+ self.assertIsNotNone(scope1.span.end_time)
+
+ self.assertIsNone(self._tracer.active_span)
+
+ # the spans should be reported in order of their finishing.
+ self.assertEqual(
+ self._reporter.get_spans(), [scope2.span, scope1.span, root_scope.span]
+ )
+
+ def test_overlapping_spans(self) -> None:
+ """Overlapping spans which are not neatly nested should work"""
+ reactor = MemoryReactorClock()
+ clock = Clock(reactor)
+
+ scopes = []
+
+ async def task(i: int):
+ scope = start_active_span(
+ f"task{i}",
+ tracer=self._tracer,
+ )
+ scopes.append(scope)
+
+ self.assertEqual(self._tracer.active_span, scope.span)
+ await clock.sleep(4)
+ self.assertEqual(self._tracer.active_span, scope.span)
+ scope.close()
+
+ async def root():
+ with start_active_span("root span", tracer=self._tracer) as root_scope:
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+ scopes.append(root_scope)
+
+ d1 = run_in_background(task, 1)
+ await clock.sleep(2)
+ d2 = run_in_background(task, 2)
+
+ # because we did run_in_background, the active span should still be the
+ # root.
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+
+ await make_deferred_yieldable(
+ defer.gatherResults([d1, d2], consumeErrors=True)
+ )
+
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+
+ with LoggingContext("root context"):
+ # start the test off
+ d1 = defer.ensureDeferred(root())
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.successResultOf(d1)
+ self.assertIsNone(self._tracer.active_span)
+
+ # the spans should be reported in order of their finishing: task 1, task 2,
+ # root.
+ self.assertEqual(
+ self._reporter.get_spans(),
+ [scopes[1].span, scopes[2].span, scopes[0].span],
+ )
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index cb02eddf07..9fc50f8852 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,6 +14,7 @@
import logging
from typing import Any, Dict, List, Optional, Tuple
+from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource
@@ -53,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
- None
+ IPv4Address("TCP", "127.0.0.1", 0)
)
# Make a new HomeServer object for the worker
@@ -345,7 +346,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.clock,
repl_handler,
)
- server = self.server_factory.buildProtocol(None)
+ server = self.server_factory.buildProtocol(
+ IPv4Address("TCP", "127.0.0.1", 0)
+ )
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 262c35cef3..545f11acd1 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -14,6 +14,7 @@
from typing import Tuple
+from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IProtocol
from twisted.test.proto_helpers import StringTransport
@@ -29,7 +30,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
"""Create a new direct TCP replication connection"""
- proto = self.factory.buildProtocol(("127.0.0.1", 0))
+ proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0))
transport = StringTransport()
proto.makeConnection(transport)
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 249808b031..989e801768 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+
import synapse.rest.admin
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.rest.client import capabilities, login
@@ -28,7 +30,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- self.url = b"/_matrix/client/r0/capabilities"
+ self.url = b"/capabilities"
hs = self.setup_test_homeserver()
self.config = hs.config
self.auth_handler = hs.get_auth_handler()
@@ -96,39 +98,20 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"])
- def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
- """Test that per default msc3283 is disabled server returns `m.change_password`."""
+ def test_get_change_users_attributes_capabilities(self):
+ """Test that server returns capabilities by default."""
access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
- self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
- self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
-
- @override_config({"experimental_features": {"msc3283_enabled": True}})
- def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
- """Test if msc3283 is enabled server returns capabilities."""
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
+ self.assertTrue(capabilities["m.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["m.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
- self.assertEqual(channel.code, 200)
- self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
- self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
- self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
-
- @override_config(
- {
- "enable_set_displayname": False,
- "experimental_features": {"msc3283_enabled": True},
- }
- )
+ @override_config({"enable_set_displayname": False})
def test_get_set_displayname_capabilities_displayname_disabled(self):
"""Test if set displayname is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
@@ -136,15 +119,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
-
- @override_config(
- {
- "enable_set_avatar_url": False,
- "experimental_features": {"msc3283_enabled": True},
- }
- )
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.set_displayname"]["enabled"])
+
+ @override_config({"enable_set_avatar_url": False})
def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
"""Test if set avatar_url is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
@@ -152,24 +130,19 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
-
- @override_config(
- {
- "enable_3pid_changes": False,
- "experimental_features": {"msc3283_enabled": True},
- }
- )
- def test_change_3pid_capabilities_3pid_disabled(self):
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
+
+ @override_config({"enable_3pid_changes": False})
+ def test_get_change_3pid_capabilities_3pid_disabled(self):
"""Test if change 3pid is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
@override_config({"experimental_features": {"msc3244_enabled": False}})
def test_get_does_not_include_msc3244_fields_when_disabled(self):
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 407dd32a73..0f1c47dcbb 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -1154,7 +1154,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
- url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+ url = "/_matrix/client/v1/register/m.login.registration_token/validity"
def default_config(self):
config = super().default_config()
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 96ae7790bb..de80aca037 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,7 +21,8 @@ from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
-from synapse.types import JsonDict
+from synapse.storage.relations import RelationPaginationToken
+from synapse.types import JsonDict, StreamToken
from tests import unittest
from tests.server import FakeChannel
@@ -168,24 +169,28 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
+ first_annotation_id = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
+ second_annotation_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
- % (self.room, self.parent_id),
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- # We expect to get back a single pagination result, which is the full
- # relation event we sent above.
+ # We expect to get back a single pagination result, which is the latest
+ # full relation event we sent above.
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
self.assert_dict(
- {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
+ {
+ "event_id": second_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
channel.json_body["chunk"][0],
)
@@ -200,6 +205,36 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body.get("next_batch"), str, channel.json_body
)
+ # Request the relations again, but with a different direction.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations"
+ f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the earliest
+ # full relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": first_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ def _stream_token_to_relation_token(self, token: str) -> str:
+ """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
+ room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
+ return self.get_success(
+ RelationPaginationToken(
+ topological=room_key.topological, stream=room_key.stream
+ ).to_string(self.store)
+ )
+
def test_repeated_paginate_relations(self):
"""Test that if we paginate using a limit and tokens then we get the
expected events.
@@ -213,7 +248,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
- prev_token: Optional[str] = None
+ prev_token = ""
found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
@@ -222,8 +257,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
- % (self.room, self.parent_id, from_token),
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
@@ -241,6 +275,93 @@ class RelationsTestCase(unittest.HomeserverTestCase):
found_event_ids.reverse()
self.assertEquals(found_event_ids, expected_event_ids)
+ # Reset and try again, but convert the tokens to the legacy format.
+ prev_token = ""
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ def test_pagination_from_sync_and_messages(self):
+ """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_id = channel.json_body["event_id"]
+ # Send an event after the relation events.
+ self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+ # Request /sync, limiting it such that only the latest event is returned
+ # (and not the relation).
+ filter = urllib.parse.quote_plus(
+ '{"room": {"timeline": {"limit": 1}}}'.encode()
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ sync_prev_batch = room_timeline["prev_batch"]
+ self.assertIsNotNone(sync_prev_batch)
+ # Ensure the relation event is not in the batch returned from /sync.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+ )
+
+ # Request /messages, limiting it such that only the latest event is
+ # returned (and not the relation).
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b&limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ messages_end = channel.json_body["end"]
+ self.assertIsNotNone(messages_end)
+ # Ensure the relation event is not in the chunk returned from /messages.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ # Request /relations with the pagination tokens received from both the
+ # /sync and /messages responses above, in turn.
+ #
+ # This is a tiny bit silly since the client wouldn't know the parent ID
+ # from the requests above; consider the parent ID to be known from a
+ # previous /sync.
+ for from_token in (sync_prev_batch, messages_end):
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # The relation should be in the returned chunk.
+ self.assertIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
def test_aggregation_pagination_groups(self):
"""Test that we can paginate annotation groups correctly."""
@@ -337,7 +458,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
- prev_token: Optional[str] = None
+ prev_token = ""
found_event_ids: List[str] = []
encoded_key = urllib.parse.quote_plus("👍".encode())
for _ in range(20):
@@ -347,15 +468,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s"
- "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
- % (
- self.room,
- self.parent_id,
- RelationTypes.ANNOTATION,
- encoded_key,
- from_token,
- ),
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ # Reset and try again, but convert the tokens to the legacy format.
+ prev_token = ""
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
@@ -453,7 +601,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(400, channel.code, channel.json_body)
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ @unittest.override_config(
+ {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
+ )
def test_bundled_aggregations(self):
"""
Test that annotations, references, and threads get correctly bundled.
@@ -579,6 +729,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_timeline["limited"])
assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included
when directly requested.
@@ -759,6 +926,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
def test_edit(self):
"""Test that a simple edit works."""
@@ -825,6 +993,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_timeline["limited"])
assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled.
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 721454c187..e9f8704035 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -89,7 +89,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
self.clock = clock
self.storage = hs.get_storage()
- self.virtual_user_id = self.register_appservice_user(
+ self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 10a4a4dc5e..b7f086927b 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Dict, Iterable, List, Optional
+from typing import Iterable, List
from unittest.mock import Mock, call
from urllib import parse as urlparse
@@ -35,7 +35,7 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
-from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -674,121 +674,6 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
- def test_spamchecker_invites(self):
- """Tests the user_may_create_room_with_invites spam checker callback."""
-
- # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an
- # IS.
- async def do_3pid_invite(
- room_id: str,
- inviter: UserID,
- medium: str,
- address: str,
- id_server: str,
- requester: Requester,
- txn_id: Optional[str],
- id_access_token: Optional[str] = None,
- ) -> int:
- return 0
-
- do_3pid_invite_mock = Mock(side_effect=do_3pid_invite)
- self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock
-
- # Add a mock callback for user_may_create_room_with_invites. Make it allow any
- # room creation request for now.
- return_value = True
-
- async def user_may_create_room_with_invites(
- user: str,
- invites: List[str],
- threepid_invites: List[Dict[str, str]],
- ) -> bool:
- return return_value
-
- callback_mock = Mock(side_effect=user_may_create_room_with_invites)
- self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append(
- callback_mock,
- )
-
- # The MXIDs we'll try to invite.
- invited_mxids = [
- "@alice1:red",
- "@alice2:red",
- "@alice3:red",
- "@alice4:red",
- ]
-
- # The 3PIDs we'll try to invite.
- invited_3pids = [
- {
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": "alice1@example.com",
- },
- {
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": "alice2@example.com",
- },
- {
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": "alice3@example.com",
- },
- ]
-
- # Create a room and invite the Matrix users, and check that it succeeded.
- channel = self.make_request(
- "POST",
- "/createRoom",
- json.dumps({"invite": invited_mxids}).encode("utf8"),
- )
- self.assertEqual(200, channel.code)
-
- # Check that the callback was called with the right arguments.
- expected_call_args = ((self.user_id, invited_mxids, []),)
- self.assertEquals(
- callback_mock.call_args,
- expected_call_args,
- callback_mock.call_args,
- )
-
- # Create a room and invite the 3PIDs, and check that it succeeded.
- channel = self.make_request(
- "POST",
- "/createRoom",
- json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
- )
- self.assertEqual(200, channel.code)
-
- # Check that do_3pid_invite was called the right amount of time
- self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
-
- # Check that the callback was called with the right arguments.
- expected_call_args = ((self.user_id, [], invited_3pids),)
- self.assertEquals(
- callback_mock.call_args,
- expected_call_args,
- callback_mock.call_args,
- )
-
- # Now deny any room creation.
- return_value = False
-
- # Create a room and invite the 3PIDs, and check that it failed.
- channel = self.make_request(
- "POST",
- "/createRoom",
- json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
- )
- self.assertEqual(403, channel.code)
-
- # Check that do_3pid_invite wasn't called this time.
- self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
-
def test_spam_checker_may_join_room(self):
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 6db7062a8e..e2ed14457f 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -198,3 +198,43 @@ class SendToDeviceTestCase(HomeserverTestCase):
"content": {"idx": 3},
},
)
+
+ def test_limited_sync(self):
+ """If a limited sync for to-devices happens the next /sync should respond immediately."""
+
+ self.register_user("u1", "pass")
+ user1_tok = self.login("u1", "pass", "d1")
+
+ user2 = self.register_user("u2", "pass")
+ user2_tok = self.login("u2", "pass", "d2")
+
+ # Do an initial sync
+ channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ self.assertEqual(channel.code, 200, channel.result)
+ sync_token = channel.json_body["next_batch"]
+
+ # Send 150 to-device messages. We limit to 100 in `/sync`
+ for i in range(150):
+ test_msg = {"foo": "bar"}
+ chan = self.make_request(
+ "PUT",
+ f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}",
+ content={"messages": {user2: {"d2": test_msg}}},
+ access_token=user1_tok,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ channel = self.make_request(
+ "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ messages = channel.json_body.get("to_device", {}).get("events", [])
+ self.assertEqual(len(messages), 100)
+ sync_token = channel.json_body["next_batch"]
+
+ channel = self.make_request(
+ "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ messages = channel.json_body.get("to_device", {}).get("events", [])
+ self.assertEqual(len(messages), 50)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index c427686376..cd4af2b1f3 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -23,7 +23,7 @@ from synapse.api.constants import (
ReadReceiptEventFields,
RelationTypes,
)
-from synapse.rest.client import knock, login, read_marker, receipts, room, sync
+from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
from tests import unittest
from tests.federation.transport.test_knocking import (
@@ -710,3 +710,58 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
channel.await_result(timeout_ms=9900)
channel.await_result(timeout_ms=200)
self.assertEqual(channel.code, 200, channel.json_body)
+
+
+class DeviceListSyncTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def test_user_with_no_rooms_receives_self_device_list_updates(self):
+ """Tests that a user with no rooms still receives their own device list updates"""
+ device_id = "TESTDEVICE"
+
+ # Register a user and login, creating a device
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey", device_id=device_id)
+
+ # Request an initial sync
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch = channel.json_body["next_batch"]
+
+ # Now, make an incremental sync request.
+ # It won't return until something has happened
+ incremental_sync_channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch}&timeout=30000",
+ access_token=self.tok,
+ await_result=False,
+ )
+
+ # Change our device's display name
+ channel = self.make_request(
+ "PUT",
+ f"devices/{device_id}",
+ {
+ "display_name": "freeze ray",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # The sync should now have returned
+ incremental_sync_channel.await_result(timeout_ms=20000)
+ self.assertEqual(incremental_sync_channel.code, 200, channel.json_body)
+
+ # We should have received notification that the (user's) device has changed
+ device_list_changes = incremental_sync_channel.json_body.get(
+ "device_lists", {}
+ ).get("changed", [])
+
+ self.assertIn(
+ self.user_id, device_list_changes, incremental_sync_channel.json_body
+ )
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 4e71b6ec12..ac6b86ff6b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -107,6 +107,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return hs
def prepare(self, reactor, clock, homeserver):
+ super().prepare(reactor, clock, homeserver)
# Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme")
@@ -473,8 +474,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def _send_event_over_federation(self) -> None:
"""Send a dummy event over federation and check that the request succeeds."""
body = {
- "origin": self.hs.config.server.server_name,
- "origin_server_ts": self.clock.time_msec(),
"pdus": [
{
"sender": self.user_id,
@@ -492,11 +491,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
],
}
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
method="PUT",
path="/_matrix/federation/v1/send/1",
content=body,
- federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
)
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8424383580..1c0cb0cf4f 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,6 +31,7 @@ from typing import (
overload,
)
from unittest.mock import patch
+from urllib.parse import urlencode
import attr
from typing_extensions import Literal
@@ -147,12 +148,20 @@ class RestHelper:
expect_code=expect_code,
)
- def join(self, room=None, user=None, expect_code=200, tok=None):
+ def join(
+ self,
+ room: str,
+ user: Optional[str] = None,
+ expect_code: int = 200,
+ tok: Optional[str] = None,
+ appservice_user_id: Optional[str] = None,
+ ) -> None:
self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
+ appservice_user_id=appservice_user_id,
membership=Membership.JOIN,
expect_code=expect_code,
)
@@ -209,11 +218,12 @@ class RestHelper:
def change_membership(
self,
room: str,
- src: str,
- targ: str,
+ src: Optional[str],
+ targ: Optional[str],
membership: str,
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
+ appservice_user_id: Optional[str] = None,
expect_code: int = 200,
expect_errcode: Optional[str] = None,
) -> None:
@@ -227,15 +237,26 @@ class RestHelper:
membership: The type of membership event
extra_data: Extra information to include in the content of the event
tok: The user access token to use
+ appservice_user_id: The `user_id` URL parameter to pass.
+ This allows driving an application service user
+ using an application service access token in `tok`.
expect_code: The expected HTTP response code
expect_errcode: The expected Matrix error code
"""
temp_id = self.auth_user_id
self.auth_user_id = src
- path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
+ path = f"/_matrix/client/r0/rooms/{room}/state/m.room.member/{targ}"
+ url_params: Dict[str, str] = {}
+
if tok:
- path = path + "?access_token=%s" % tok
+ url_params["access_token"] = tok
+
+ if appservice_user_id:
+ url_params["user_id"] = appservice_user_id
+
+ if url_params:
+ path += "?" + urlencode(url_params)
data = {"membership": membership}
data.update(extra_data or {})
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 53f6186213..da2c533260 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -243,6 +243,78 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
+ def test_video_rejected(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = b"anything"
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: video/mp4\r\n\r\n"
+ )
+ % (len(end_content))
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "Requested file's content type not allowed for this operation: video/mp4",
+ },
+ )
+
+ def test_audio_rejected(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = b"anything"
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: audio/aac\r\n\r\n"
+ )
+ % (len(end_content))
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "Requested file's content type not allowed for this operation: audio/aac",
+ },
+ )
+
def test_non_ascii_preview_content_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 329490caad..ddcb7f5549 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -266,7 +266,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
- defer.ensureDeferred(self.store.create_appservice_txn(service, events, []))
+ defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@@ -280,7 +282,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ txn = self.get_success(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -291,7 +295,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ txn = self.get_success(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -313,7 +319,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ txn = self.get_success(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -481,10 +489,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
ValueError,
)
- def test_set_type_stream_id_for_appservice(self) -> None:
+ def test_set_appservice_stream_type_pos(self) -> None:
read_receipt_value = 1024
self.get_success(
- self.store.set_type_stream_id_for_appservice(
+ self.store.set_appservice_stream_type_pos(
self.service, "read_receipt", read_receipt_value
)
)
@@ -494,7 +502,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
self.assertEqual(result, read_receipt_value)
self.get_success(
- self.store.set_type_stream_id_for_appservice(
+ self.store.set_appservice_stream_type_pos(
self.service, "presence", read_receipt_value
)
)
@@ -503,9 +511,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
)
self.assertEqual(result, read_receipt_value)
- def test_set_type_stream_id_for_appservice_invalid_type(self) -> None:
+ def test_set_appservice_stream_type_pos_invalid_type(self) -> None:
self.get_failure(
- self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
+ self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024),
ValueError,
)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2bc89512f8..667ca90a4d 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -260,16 +260,16 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
- self.assertEqual(auth_chain_ids, ["k"])
+ self.assertEqual(auth_chain_ids, {"k"})
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
- self.assertEqual(auth_chain_ids, ["j"])
+ self.assertEqual(auth_chain_ids, {"j"})
# j and k have no parents.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
- self.assertEqual(auth_chain_ids, [])
+ self.assertEqual(auth_chain_ids, set())
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
- self.assertEqual(auth_chain_ids, [])
+ self.assertEqual(auth_chain_ids, set())
# More complex input sequences.
auth_chain_ids = self.get_success(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7f5b28aed8..48f1e9d841 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -341,7 +341,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Register an AS user.
user = self.register_user("user", "pass")
token = self.login(user, "pass")
- as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+ as_user, _ = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
# Join the AS user to rooms owned by the normal user.
public, private = self._create_rooms_and_inject_memberships(
diff --git a/tests/unittest.py b/tests/unittest.py
index 1431848367..a71892cb9d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -17,6 +17,7 @@ import gc
import hashlib
import hmac
import inspect
+import json
import logging
import secrets
import time
@@ -36,9 +37,11 @@ from typing import (
)
from unittest.mock import Mock, patch
-from canonicaljson import json
+import canonicaljson
+import signedjson.key
+import unpaddedbase64
-from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor
@@ -49,8 +52,7 @@ from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server as federation_server
+from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import (
@@ -61,10 +63,10 @@ from synapse.logging.context import (
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
from tests.test_utils import event_injection, setup_awaitable_errors
@@ -620,18 +622,19 @@ class HomeserverTestCase(TestCase):
self,
username: str,
appservice_token: str,
- ) -> str:
+ ) -> Tuple[str, str]:
"""Register an appservice user as an application service.
Requires the client-facing registration API be registered.
Args:
username: the user to be registered by an application service.
- Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart"
+ Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname"
appservice_token: the acccess token for that application service.
Raises: if the request to '/register' does not return 200 OK.
- Returns: the MXID of the new user.
+ Returns:
+ The MXID of the new user, the device ID of the new user's first device.
"""
channel = self.make_request(
"POST",
@@ -643,7 +646,7 @@ class HomeserverTestCase(TestCase):
access_token=appservice_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
- return channel.json_body["user_id"]
+ return channel.json_body["user_id"], channel.json_body["device_id"]
def login(
self,
@@ -754,42 +757,116 @@ class HomeserverTestCase(TestCase):
class FederatingHomeserverTestCase(HomeserverTestCase):
"""
- A federating homeserver that authenticates incoming requests as `other.example.com`.
+ A federating homeserver, set up to validate incoming federation requests
"""
- def create_resource_dict(self) -> Dict[str, Resource]:
- d = super().create_resource_dict()
- d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
- return d
+ OTHER_SERVER_NAME = "other.example.com"
+ OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ super().prepare(reactor, clock, hs)
-class TestTransportLayerServer(JsonResource):
- """A test implementation of TransportLayerServer
+ # poke the other server's signing key into the key store, so that we don't
+ # make requests for it
+ verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
+ verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
- authenticates incoming requests as `other.example.com`.
- """
+ self.get_success(
+ hs.get_datastore().store_server_verify_keys(
+ from_server=self.OTHER_SERVER_NAME,
+ ts_added_ms=clock.time_msec(),
+ verify_keys=[
+ (
+ self.OTHER_SERVER_NAME,
+ verify_key_id,
+ FetchKeyResult(
+ verify_key=verify_key,
+ valid_until_ts=clock.time_msec() + 1000,
+ ),
+ )
+ ],
+ )
+ )
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
- def __init__(self, hs):
- super().__init__(hs)
+ def make_signed_federation_request(
+ self,
+ method: str,
+ path: str,
+ content: Optional[JsonDict] = None,
+ await_result: bool = True,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ client_ip: str = "127.0.0.1",
+ ) -> FakeChannel:
+ """Make an inbound signed federation request to this server
- class Authenticator:
- def authenticate_request(self, request, content):
- return succeed("other.example.com")
+ The request is signed as if it came from "other.example.com", which our HS
+ already has the keys for.
+ """
- authenticator = Authenticator()
+ if custom_headers is None:
+ custom_headers = []
+ else:
+ custom_headers = list(custom_headers)
+
+ custom_headers.append(
+ (
+ "Authorization",
+ _auth_header_for_request(
+ origin=self.OTHER_SERVER_NAME,
+ destination=self.hs.hostname,
+ signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ method=method,
+ path=path,
+ content=content,
+ ),
+ )
+ )
- ratelimiter = FederationRateLimiter(
- hs.get_clock(),
- FederationRateLimitConfig(
- window_size=1,
- sleep_limit=1,
- sleep_delay=1,
- reject_limit=1000,
- concurrent=1000,
- ),
+ return make_request(
+ self.reactor,
+ self.site,
+ method=method,
+ path=path,
+ content=content,
+ shorthand=False,
+ await_result=await_result,
+ custom_headers=custom_headers,
+ client_ip=client_ip,
)
- federation_server.register_servlets(hs, self, authenticator, ratelimiter)
+
+def _auth_header_for_request(
+ origin: str,
+ destination: str,
+ signing_key: signedjson.key.SigningKey,
+ method: str,
+ path: str,
+ content: Optional[JsonDict],
+) -> str:
+ """Build a suitable Authorization header for an outgoing federation request"""
+ request_description: JsonDict = {
+ "method": method,
+ "uri": path,
+ "destination": destination,
+ "origin": origin,
+ }
+ if content is not None:
+ request_description["content"] = content
+ signature_base64 = unpaddedbase64.encode_base64(
+ signing_key.sign(
+ canonicaljson.encode_canonical_json(request_description)
+ ).signature
+ )
+ return (
+ f"X-Matrix origin={origin},"
+ f"key={signing_key.alg}:{signing_key.version},"
+ f"sig={signature_base64}"
+ )
def override_config(extra_config):
|